jax.scipy.ndimage.map_coordinates
- jax.scipy.ndimage.map_coordinates(input, coordinates, order, mode='constant', cval=0.0)
Map the input array to new coordinates using interpolation.
JAX implementation of scipy.ndimage.map_coordinates().
Given an input array and a set of coordinates, this function returns the interpolated values of the input array at those coordinates.
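The coordinates do not have to be flat lists of points; they can be whole grids. A minimal sketch, assuming `jax` is installed (the names `image`, `rows`, `cols`, and `resampled` are illustrative only): one coordinate array is passed per input dimension, and the output takes the shape of those coordinate arrays. Here a 3x4 array is resampled halfway between neighboring columns.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# A 3x4 array whose value at (row, col) is 4*row + col.
image = jnp.arange(12.0).reshape(3, 4)

# One coordinate array per input dimension: rows first, then columns.
# indexing="ij" makes rows vary along axis 0 and columns along axis 1.
rows, cols = jnp.meshgrid(jnp.array([0.0, 1.0, 2.0]),
                          jnp.array([0.5, 1.5, 2.5]),
                          indexing="ij")

# Sample halfway between neighboring columns with linear interpolation.
resampled = map_coordinates(image, [rows, cols], order=1)

print(resampled.shape)  # (3, 3): same shape as rows and cols
print(resampled)        # values 4*row + col, e.g. 0.5, 1.5, 2.5 in the first row
```

Because the input is a linear function of its indices, linear interpolation reproduces 4*row + col exactly at every grid point.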
- Parameters:
input (Array | ndarray | bool_ | number | bool | int | float | complex) – N-dimensional input array from which values are interpolated.
coordinates (Sequence[Array | ndarray | bool_ | number | bool | int | float | complex]) – length-N sequence of arrays, one per dimension of input, specifying the coordinates at which to evaluate the interpolated values.
order (int) – The order of interpolation. JAX supports the following:
0: Nearest-neighbor
1: Linear
mode (str) – Points outside the boundaries of the input are filled according to the given mode. JAX supports one of ('constant', 'nearest', 'mirror', 'wrap', 'reflect'). Default is 'constant'.
cval (Array | ndarray | bool_ | number | bool | int | float | complex) – Value used for points outside the boundaries of the input if mode='constant'. Default is 0.0.
- Returns:
The interpolated values at the specified coordinates.
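The effect of the order parameter is easiest to see on a 1-D signal. A minimal sketch, assuming `jax` is installed (`signal`, `coords`, `nearest`, and `linear` are illustrative names, not part of the API): with order=0 each point takes the value of the closest sample, while order=1 blends the two neighboring samples in proportion to distance. The coordinates deliberately avoid exact .5 ties, so the nearest-neighbor result does not depend on how ties are rounded.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

signal = jnp.array([0.0, 10.0, 20.0, 30.0])
# A 1-D input needs a length-1 sequence of coordinate arrays.
coords = [jnp.array([0.25, 1.75, 2.75])]

# order=0: snap each coordinate to the closest sample index (0, 2, 3).
nearest = map_coordinates(signal, coords, order=0)  # values 0, 20, 30
# order=1: weight the two neighboring samples by distance.
linear = map_coordinates(signal, coords, order=1)   # values 2.5, 17.5, 27.5

print(nearest)
print(linear)
```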
Examples
>>> import jax.numpy as jnp
>>> import jax.scipy.ndimage
>>> input = jnp.arange(12.0).reshape(3, 4)
>>> input
Array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.]], dtype=float32)
>>> coordinates = [jnp.array([0.5, 1.5]),
...                jnp.array([1.5, 2.5])]
>>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
Array([3.5, 8.5], dtype=float32)
Note
Interpolation near boundaries differs from the SciPy function, because JAX fixed an outstanding bug; see google/jax#11097. This function interprets the mode argument as documented by SciPy, but not as implemented by SciPy.
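To see what each mode does at the boundary, the sketch below (assuming `jax` is installed; `signal`, `coords`, and `results` are illustrative names) samples one point before the first element and one past the last, using order=0 so each output is a single, possibly remapped, input element. The values noted in the comments follow the boundary conventions SciPy documents for these modes, for example 'mirror' extends a b c d as d c b | a b c d | c b a, while 'reflect' repeats the edge element as d c b a | a b c d | d c b a.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

signal = jnp.array([1.0, 2.0, 3.0, 4.0])
# Two out-of-bounds points: one before index 0, one past the last index (3).
coords = [jnp.array([-1.0, 4.0])]

# Expected values, per SciPy's documented mode conventions:
#   constant -> 0, 0   (cval fills everything outside)
#   nearest  -> 1, 4   (edge values repeated)
#   mirror   -> 2, 3   (reflected about the edge samples, edge not repeated)
#   wrap     -> 4, 1   (periodic extension)
#   reflect  -> 1, 4   (reflected about the array edges, edge repeated)
results = {}
for mode in ("constant", "nearest", "mirror", "wrap", "reflect"):
    results[mode] = map_coordinates(signal, coords, order=0, mode=mode, cval=0.0)
    print(mode, results[mode])
```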