jax.ops.index_min
jax.ops.index_min(x, idx, y, indices_are_sorted=False, unique_indices=False)

Pure equivalent of x[idx] = minimum(x[idx], y).

Returns the value of x that would result from the NumPy-style indexed assignment:

x[idx] = minimum(x[idx], y)
Note that the index_min operator is pure: x itself is not modified. Instead, the new value that x would have taken is returned.
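As a rough illustration of these semantics, here is a minimal NumPy sketch. It is an assumption-labeled analogue, not JAX's implementation: the helper `index_min_reference` is hypothetical, and it relies on NumPy's `np.minimum.at`, which applies the minimum in place on a copy so the caller's array is never mutated.

```python
import numpy as np

def index_min_reference(x, idx, y):
    """Hypothetical NumPy analogue of jax.ops.index_min (not the JAX code).

    Returns a new array; the input x is left unchanged.
    """
    out = np.array(x, copy=True)  # work on a copy so the caller's x is untouched
    # np.minimum.at is unbuffered: every update is applied, including repeats
    np.minimum.at(out, idx, y)
    return out

x = np.ones((5, 6), dtype=np.float32)
result = index_min_reference(x, (slice(2, 4), slice(3, None)), 0.0)
print(result[2:4, 3:])  # the updated block is all zeros
print(x.min())          # x itself still holds only ones
```

The copy is what makes the sketch "pure" in the sense described above; JAX arrays are immutable, so no explicit copy is needed there.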
Unlike the NumPy code x[idx] = minimum(x[idx], y), if multiple indices refer to the same location, the final value will be the overall minimum. (NumPy would only look at the last update, rather than all of the updates.)

Parameters

x – an array with the values to be updated.
idx – a NumPy-style index, consisting of None, integers, slice objects, ellipses, ndarrays with integer dtypes, or a tuple of the above. A convenient syntactic sugar for forming indices is via the jax.ops.index object.
y – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].
indices_are_sorted – whether idx is known to be sorted.
unique_indices – whether idx is known to be free of duplicates.
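The idx and y parameters can be sketched in plain NumPy as follows. This assumes that an index-forming helper like jax.ops.index simply returns whatever is subscripted; `_IndexHelper` below is a hypothetical stand-in, not the JAX object. The sketch also checks that y broadcasts to the shape of x[idx] before applying the minimum with `np.minimum.at`.

```python
import numpy as np

class _IndexHelper:
    """Hypothetical stand-in for jax.ops.index: returns whatever is subscripted."""
    def __getitem__(self, idx):
        return idx

index = _IndexHelper()

# index[2:4, 3:] produces an ordinary tuple of slice objects
idx = index[2:4, 3:]
print(idx)

x = np.ones((5, 6))
# y must broadcast to x[idx].shape, which is (2, 3); a length-3 row works
y = np.array([0.5, 2.0, -1.0])
print(np.broadcast_shapes(y.shape, x[idx].shape))

out = x.copy()
np.minimum.at(out, idx, y)  # each row of the block becomes [0.5, 1.0, -1.0]
print(out[2:4, 3:])
```

Forming the index through a subscriptable object lets slice syntax such as `2:4` be written where a function argument would otherwise need an explicit `slice(2, 4)`.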
Returns

An array with the same shape as x, containing the updated values.
>>> import jax
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_min(x, jax.ops.index[2:4, 3:], 0.)
array([[1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1.]], dtype=float32)
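The duplicate-index behavior described above can be contrasted in plain NumPy. This is a sketch under the assumption that NumPy's unbuffered `np.minimum.at` is a fair stand-in for "every update participates"; it is not the JAX implementation. The fancy-assignment line shows the NumPy pattern the docstring warns about, where only one of the repeated writes survives.

```python
import numpy as np

x = np.full(4, 10.0)
idx = np.array([1, 1, 1])      # the same location three times
y = np.array([3.0, 1.0, 5.0])

# Plain NumPy fancy assignment: repeated writes overwrite each other,
# so x[1] ends up holding only one of the candidate minima (typically the last)
numpy_style = x.copy()
numpy_style[idx] = np.minimum(numpy_style[idx], y)
print(numpy_style)

# Unbuffered ufunc.at: all three updates participate, so x[1] becomes 1.0,
# the overall minimum, matching the behavior described for index_min
all_updates = x.copy()
np.minimum.at(all_updates, idx, y)
print(all_updates)
```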