jax.scipy.ndimage.map_coordinates
- jax.scipy.ndimage.map_coordinates(input, coordinates, order, mode='constant', cval=0.0)
Map the input array to new coordinates using interpolation.
JAX implementation of scipy.ndimage.map_coordinates().
Given an input array and a set of coordinates, this function returns the interpolated values of the input array at those coordinates.
- Parameters:
input (Array | ndarray | bool | number | bool | int | float | complex) – N-dimensional input array from which values are interpolated.
coordinates (Sequence[Array | ndarray | bool | number | bool | int | float | complex]) – length-N sequence of arrays, one per input dimension, specifying the coordinates at which to evaluate the interpolated values.
order (int) –
The order of interpolation. JAX supports the following:
0: Nearest-neighbor
1: Linear
mode (str) – Points outside the boundaries of the input are filled according to the given mode. JAX supports one of ('constant', 'nearest', 'mirror', 'wrap', 'reflect'). Note that the 'wrap' mode in JAX behaves as the 'grid-wrap' mode in SciPy, and the 'constant' mode in JAX behaves as the 'grid-constant' mode in SciPy. This discrepancy was caused by a former bug in those modes in SciPy (scipy/scipy#2640). The bug was first fixed in JAX by changing the behavior of the existing modes; SciPy later fixed it by adding modes with new names rather than changing the existing ones, for backwards-compatibility reasons. Default is 'constant'.
cval (Array | ndarray | bool | number | bool | int | float | complex) – Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0.
- Returns:
The interpolated values at the specified coordinates.
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> input = jnp.arange(12.0).reshape(3, 4)
>>> input
Array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.]], dtype=float32)
>>> coordinates = [jnp.array([0.5, 1.5]),
...                jnp.array([1.5, 2.5])]
>>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
Array([3.5, 8.5], dtype=float32)
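As a further illustration, the following sketch contrasts the two supported values of order on a single off-grid point. It is not part of the upstream example; the names image, coords, nearest and linear are chosen here for illustration only.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# 3x4 grid whose value at (row, col) is 4 * row + col.
image = jnp.arange(12.0).reshape(3, 4)

# One sample point at row 0.25, column 1.75 (one coordinate array per axis).
coords = [jnp.array([0.25]), jnp.array([1.75])]

# order=0 takes the value of the closest grid point, here (0, 2).
nearest = map_coordinates(image, coords, order=0)
print(nearest)  # [2.]

# order=1 blends the four surrounding grid points. Because this image is a
# linear function of (row, col), the result equals 4 * 0.25 + 1.75.
linear = map_coordinates(image, coords, order=1)
print(linear)  # [2.75]
```

Each array in coords holds the positions along one axis of the input, so the output has the same shape as those arrays.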
Note
Interpolation near boundaries differs from the SciPy function, because JAX fixed an outstanding bug; see jax-ml/jax#11097. This function interprets the mode argument as documented by SciPy, but not as implemented by SciPy.
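To make the boundary handling concrete, here is a minimal sketch of how each supported mode treats coordinates that fall outside a 1-D input. The array x and the names outside, results and half_out are illustrative assumptions, and only the JAX function is exercised; SciPy's own 'constant' and 'wrap' modes may give different values for the same inputs.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# Coordinates one step before the first element and one step past the last.
outside = [jnp.array([-1.0, 4.0])]

# order=0 keeps the arithmetic simple: each mode only decides which element
# (or cval) an out-of-range index maps to.
results = {
    mode: map_coordinates(x, outside, order=0, mode=mode)
    for mode in ("constant", "nearest", "mirror", "wrap", "reflect")
}
for mode, values in results.items():
    print(mode, values)
# Expected values:
#   constant [0. 0.]  filled with cval (0.0 by default)
#   nearest  [1. 4.]  clamped to the edge element
#   mirror   [2. 3.]  reflected about the centre of the edge element
#   wrap     [4. 1.]  periodic extension
#   reflect  [1. 4.]  reflected about the outer edge of the edge element

# With order=1 in 'constant' mode, a point half a step outside the array
# averages cval with the edge element: 0.5 * 0.0 + 0.5 * 1.0.
half_out = map_coordinates(x, [jnp.array([-0.5])], order=1, mode="constant")
print(half_out)  # [0.5]
```

The last call shows the behavior the note describes: in 'constant' mode the fill value takes part in the interpolation, so samples just outside the edge fade toward cval rather than jumping to it.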