jax.scipy.ndimage.map_coordinates

jax.scipy.ndimage.map_coordinates(input, coordinates, order, mode='constant', cval=0.0)

Map the input array to new coordinates using interpolation.

JAX implementation of scipy.ndimage.map_coordinates().

Given an input array and a set of coordinates, this function returns the interpolated values of the input array at those coordinates.

Parameters:
  • input (Array | ndarray | bool | number | bool | int | float | complex) – N-dimensional input array from which values are interpolated.

  • coordinates (Sequence[Array | ndarray | bool | number | bool | int | float | complex]) – Length-N sequence of arrays specifying the coordinates at which to evaluate the interpolated values.

  • order (int) –

    The order of interpolation. JAX supports the following:

    • 0: Nearest-neighbor

    • 1: Linear

  • mode (str) – Points outside the boundaries of the input are filled according to the given mode. JAX supports 'constant', 'nearest', 'mirror', 'wrap', and 'reflect'. Note that 'wrap' in JAX behaves like 'grid-wrap' in SciPy, and 'constant' in JAX behaves like 'grid-constant' in SciPy. The discrepancy stems from a former SciPy bug in those modes (scipy/scipy#2640). JAX fixed it first by changing the behavior of the existing modes; SciPy fixed it later by adding modes with new names rather than changing the existing ones, for backward compatibility. Default is 'constant'.

  • cval (Array | ndarray | bool | number | bool | int | float | complex) – Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0.

Returns:

The interpolated values at the specified coordinates.
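As a rough illustration of how the coordinate arrays relate to the result, the sketch below resamples a small 2D array onto a denser grid. The names image, row_grid, col_grid, and resampled are hypothetical, and the 5-by-7 grid is chosen only for the example. It assumes one coordinate array per input dimension, all of the same shape, in which case the result takes that shape.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# Source array whose value at (i, j) is 4*i + j.
image = jnp.arange(12.0).reshape(3, 4)

# Hypothetical target grid: 5 rows by 7 columns covering the same extent.
rows = jnp.linspace(0.0, 2.0, 5)
cols = jnp.linspace(0.0, 3.0, 7)
row_grid, col_grid = jnp.meshgrid(rows, cols, indexing="ij")

# One coordinate array per input dimension; each has shape (5, 7).
resampled = map_coordinates(image, [row_grid, col_grid], order=1)
print(resampled.shape)
```

Because the source values here vary linearly with the indices, linear interpolation should reproduce 4 * row + col at every sampled position, up to floating-point error.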

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> input = jnp.arange(12.0).reshape(3, 4)
>>> input
Array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.]], dtype=float32)
>>> coordinates = [jnp.array([0.5, 1.5]),
...                jnp.array([1.5, 2.5])]
>>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
Array([3.5, 8.5], dtype=float32)
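To contrast the two supported interpolation orders, here is a minimal sketch that evaluates the same points with order=0 and order=1. The variable names data, coords, nearest, and linear are illustrative only, and the coordinates are deliberately kept away from half-integer positions so that the nearest-neighbor choice is unambiguous.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# Same layout as the example above: the value at (i, j) is 4*i + j.
data = jnp.arange(12.0).reshape(3, 4)

# Two sample points: (0.2, 1.7) and (1.8, 0.4).
coords = [jnp.array([0.2, 1.8]), jnp.array([1.7, 0.4])]

# order=0 picks the value at the closest grid point.
nearest = map_coordinates(data, coords, order=0)

# order=1 blends the surrounding grid points linearly.
linear = map_coordinates(data, coords, order=1)

print("nearest:", nearest)
print("linear: ", linear)
```

With this input, order=0 should return the values stored at grid points (0, 2) and (2, 0), while order=1 should return approximately 2.5 and 7.6.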

Note

Interpolation near boundaries differs from the SciPy function because JAX fixed an outstanding bug; see jax-ml/jax#11097. This function interprets the mode argument as documented by SciPy, not as implemented by SciPy.
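The boundary modes can be compared on a one-dimensional input. The following is a sketch rather than a specification: signal, outside, and results are hypothetical names, cval=-1.0 is an arbitrary fill value chosen to stand out from the data, and order=0 with integer coordinates is used so that each result is a single input sample or the fill value.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

signal = jnp.array([1.0, 2.0, 3.0, 4.0])

# One position just before the first sample and one just past the last.
outside = [jnp.array([-1.0, 4.0])]

# Evaluate the same out-of-bounds points under every supported mode.
# cval only affects the 'constant' mode.
results = {
    mode: map_coordinates(signal, outside, order=0, mode=mode, cval=-1.0)
    for mode in ("constant", "nearest", "mirror", "wrap", "reflect")
}

for mode, values in results.items():
    print(f"{mode:>8}: {values}")
```

Under these assumptions, 'constant' returns cval for both points, 'nearest' and 'reflect' return the edge samples 1 and 4, 'mirror' returns 2 and 3, and 'wrap' returns 4 and 1.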