jax.scipy.ndimage.map_coordinates

jax.scipy.ndimage.map_coordinates(input, coordinates, order, mode='constant', cval=0.0)

Map the input array to new coordinates using interpolation.

JAX implementation of scipy.ndimage.map_coordinates().

Given an input array and a set of coordinates, this function returns the interpolated values of the input array at those coordinates.
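
For order=1 along a single axis, each output is a weighted average of the two neighboring samples: with lower = floor(c) and w = c - lower, the value is (1 - w) * x[lower] + w * x[lower + 1]. The sketch below checks that textbook formula against map_coordinates for in-bounds coordinates. It is an illustration, not a description of JAX's internal code; the helper name linear_1d is hypothetical, and boundary handling (governed by mode) is deliberately left out.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def linear_1d(x, c):
    """Hypothetical helper: 1-D linear interpolation, valid only for
    coordinates in [0, len(x) - 1). No boundary handling."""
    lower = jnp.floor(c).astype(jnp.int32)  # index of the left neighbor
    w = c - lower                           # fractional distance from it
    return (1 - w) * x[lower] + w * x[lower + 1]

x = jnp.array([0.0, 10.0, 20.0, 30.0])
c = jnp.array([0.25, 1.5, 2.75])

manual = linear_1d(x, c)
library = map_coordinates(x, [c], order=1)  # one coordinate array per axis
print(manual)
print(library)
```

Both calls should agree, giving 2.5, 15.0 and 27.5 for these coordinates.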

Parameters:
  • input (Array | ndarray | bool_ | number | bool | int | float | complex) – N-dimensional input array from which values are interpolated.

  • coordinates (Sequence[Array | ndarray | bool_ | number | bool | int | float | complex]) – Length-N sequence of arrays, one per input dimension, specifying the coordinates at which to evaluate the interpolated values.

  • order (int) –

    The order of interpolation. JAX supports only the following:

    • 0: Nearest-neighbor

    • 1: Linear

  • mode (str) – Points outside the boundaries of the input are filled according to the given mode. JAX supports 'constant', 'nearest', 'mirror', 'wrap', and 'reflect'. Default is 'constant'.

  • cval (Array | ndarray | bool_ | number | bool | int | float | complex) – Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0.
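
To make the mode options concrete, the sketch below samples a three-element array one step past each edge using nearest-neighbor interpolation (order=0). The same cval is passed to every call but only 'constant' uses it. The expected values in the comments follow the SciPy-documented extension patterns for [a, b, c]: for example 'mirror' extends as (c b | a b c | b a) and 'reflect' as (b a | a b c | c b).

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

x = jnp.array([10.0, 20.0, 30.0])
coords = [jnp.array([-1.0, 3.0])]  # one step before index 0 and one after index 2

results = {}
for mode in ("constant", "nearest", "mirror", "wrap", "reflect"):
    # cval only matters for mode="constant"; other modes ignore it
    results[mode] = map_coordinates(x, coords, order=0, mode=mode, cval=-1.0)
    print(mode, results[mode])

# Expected values for x = [a, b, c] = [10, 20, 30]:
#   constant -> -1, -1   (cval on both sides)
#   nearest  -> 10, 30   (edge values repeated)
#   mirror   -> 20, 20   (reflect about the edge sample, which is not repeated)
#   wrap     -> 30, 10   (index taken modulo the length)
#   reflect  -> 10, 30   (reflect about the edge, which is repeated)
```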

Returns:

The interpolated values at the specified coordinates.
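
The returned array has the shape of the coordinate arrays, not the shape of input. A minimal sketch, assuming the coordinates are built with jnp.meshgrid(..., indexing="ij") so that the first array holds row positions, samples a 3x4 array on a 2x2 grid of points. Because this particular array is linear in its indices (value 4 * row + col), linear interpolation reproduces 4 * row + col exactly, which makes the result easy to check.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

image = jnp.arange(12.0).reshape(3, 4)  # image[i, j] == 4 * i + j

# A 2x2 grid of (row, col) sample points; indexing="ij" keeps rows first.
rows, cols = jnp.meshgrid(jnp.array([0.5, 1.5]), jnp.array([1.0, 2.5]), indexing="ij")

values = map_coordinates(image, [rows, cols], order=1)
print(values.shape)  # (2, 2): same shape as rows and cols, not image
```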

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> input = jnp.arange(12.0).reshape(3, 4)
>>> input
Array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.]], dtype=float32)
>>> coordinates = [jnp.array([0.5, 1.5]),
...                jnp.array([1.5, 2.5])]
>>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
Array([3.5, 8.5], dtype=float32)
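
A common use is resampling onto a denser grid. The sketch below uses a hypothetical helper, upsample2x, that inserts one linearly interpolated sample between each pair of neighbors in the example array. Because order and mode are plain Python values fixed inside the function, and the grid size depends only on the (static) input shape, the helper can be wrapped in jax.jit.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def upsample2x(image):
    """Hypothetical helper: bilinear 2x upsampling of a 2-D array.

    An (h, w) input becomes (2h - 1, 2w - 1); original samples are kept
    and new samples fall halfway between neighbors.
    """
    h, w = image.shape  # static under jit, so the grid size is fixed
    rows = jnp.linspace(0.0, h - 1, 2 * h - 1)
    cols = jnp.linspace(0.0, w - 1, 2 * w - 1)
    rr, cc = jnp.meshgrid(rows, cols, indexing="ij")
    return map_coordinates(image, [rr, cc], order=1)

image = jnp.arange(12.0).reshape(3, 4)
up = jax.jit(upsample2x)(image)
print(up.shape)  # (5, 7)
```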

Note

Interpolation near boundaries differs from the SciPy function because JAX fixed an outstanding bug (see google/jax#11097). This function interprets the mode argument as documented by SciPy, not as implemented by SciPy.
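
As a rough illustration of how JAX handles points near the boundary (this sketch describes JAX's behavior only and makes no claim about what SciPy returns), the example below samples a 1-D array half a step outside each edge, and once fully outside, with order=1. In 'constant' mode the out-of-range neighbor contributes cval, so a point half a step outside is blended halfway between the edge value and cval. In 'nearest' mode the edge value is simply repeated.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

x = jnp.array([1.0, 2.0, 3.0])
coords = [jnp.array([-0.5, 2.5, -1.5])]  # half a step outside each edge, then fully outside

const = map_coordinates(x, coords, order=1, mode="constant", cval=0.0)
near = map_coordinates(x, coords, order=1, mode="nearest")
print(const)  # holds 0.5, 1.5, 0.0: edge value blended with cval, then pure cval
print(near)   # holds 1.0, 3.0, 1.0: edge values repeated
```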