jax.ops.index_max(x, idx, y, indices_are_sorted=False, unique_indices=False)

Pure equivalent of x[idx] = maximum(x[idx], y).

Deprecated since version 0.2.22: Prefer the use of jax.numpy.ndarray.at.
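A minimal migration sketch, assuming a JAX version that provides the `.at` property named in the deprecation note. Recent JAX releases no longer ship `jax.ops.index_max`, so the sketch calls only `x.at[idx].max(y)`, which computes the same pure update. The array values are illustrative.

```python
import jax.numpy as jnp

# Starting array and a NumPy-style index (here the slice 1:4).
x = jnp.array([0.0, 5.0, 1.0, 7.0, 2.0])
idx = jnp.index_exp[1:4]
y = 3.0

# Non-deprecated spelling of index_max(x, idx, y): positions covered by idx
# receive maximum(x[idx], y); every other position keeps its value from x.
out = x.at[idx].max(y)
# out holds [0., 5., 3., 7., 2.]
```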

Returns the value of x that would result from the NumPy-style indexed assignment:

x[idx] = maximum(x[idx], y)

Note that the index_max operator is pure: x itself is not modified. Instead, the new value that x would have taken is returned.
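The purity can be observed directly. This sketch again assumes the `.at` replacement rather than the deprecated function and checks that the input array is left untouched. It also shows that JAX arrays reject NumPy-style in-place assignment with a `TypeError`.

```python
import jax.numpy as jnp

x = jnp.zeros(4)

# .at[...].max(...) returns a new array; x keeps its original contents.
x_new = x.at[2].max(10.0)

# JAX arrays are immutable, so NumPy-style in-place assignment raises TypeError.
try:
    x[2] = 1.0
    in_place_ok = True
except TypeError:
    in_place_ok = False
```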

Unlike the NumPy code x[idx] = maximum(x[idx], y), if multiple indices refer to the same location, the final value is the maximum of the original value and all of the updates to that location. (NumPy would only use the last update rather than all of them.)
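A small sketch of the duplicate-index difference, assuming the `.at` replacement on the JAX side. The NumPy assignment keeps only the last write to location 0, while JAX combines every update with max. NumPy's unbuffered `np.maximum.at` is included because it also combines repeated indices and therefore matches the JAX result.

```python
import numpy as np
import jax.numpy as jnp

idx = np.array([0, 0, 0])        # three updates aimed at the same location
y = np.array([5.0, 9.0, 2.0])

# NumPy buffered assignment: with repeated indices only the last write survives.
x_np = np.zeros(2)
x_np[idx] = np.maximum(x_np[idx], y)    # x_np[0] ends up 2.0

# NumPy unbuffered ufunc.at: every update is applied, like the JAX version.
x_ufunc = np.zeros(2)
np.maximum.at(x_ufunc, idx, y)          # x_ufunc[0] ends up 9.0

# JAX: all updates to location 0 are reduced with max.
x_jax = jnp.zeros(2).at[idx].max(y)     # x_jax[0] ends up 9.0
```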

Parameters

  • x (Any) – an array with the values to be updated.

  • idx (Union[None, int, slice, Sequence[int], Any, Tuple[Union[None, int, slice, Sequence[int], Any], …]]) – a NumPy-style index, consisting of None, integers, slice objects, ellipses, ndarrays with integer dtypes, or a tuple of the above. The jax.ops.index object provides convenient syntactic sugar for forming such indices.

  • y (Union[Any, complex, float, int, number]) – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].

  • indices_are_sorted (bool) – whether idx is known to be sorted.

  • unique_indices (bool) – whether idx is known to be free of duplicates.
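The parameters can be exercised with the `.at` replacement, assuming a JAX version whose `.at[...].max` accepts the `indices_are_sorted` and `unique_indices` keywords. The first update shows y broadcasting to the shape of x[idx]. The second passes both hints for an integer index that really is sorted and duplicate-free; the hints are promises to the compiler, so passing True when they do not hold may give incorrect results.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 4))

# y broadcasts to the shape of x[idx]: x[1:, :] has shape (2, 4), so a
# length-4 row is applied to both selected rows.
row = jnp.array([1.0, -1.0, 2.0, 0.5])
out = x.at[1:, :].max(row)

# Integer-array index that is sorted and has no duplicates, so both hints
# are truthful here.
idx = jnp.array([0, 2, 3])
vals = jnp.array([4.0, 5.0, 6.0])
out2 = x[0].at[idx].max(vals, indices_are_sorted=True, unique_indices=True)
```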

Returns

An array.

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.ones((5, 6))
>>> jax.ops.index_max(x, jnp.index_exp[2:4, 3:], 6.)
DeviceArray([[1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1., 1.],
             [1., 1., 1., 6., 6., 6.],
             [1., 1., 1., 6., 6., 6.],
             [1., 1., 1., 1., 1., 1.]], dtype=float32)