jax.ops.index

jax.ops.index = <jax._src.ops.scatter._Indexable object>

Helper object for building indexes for indexed update functions.

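Deprecated since version 0.2.22: Prefer the jax.numpy.ndarray.at property. If an explicit index object is needed, use jax.numpy.index_exp, which is subscripted with square brackets (for example, jax.numpy.index_exp[1:2, 3]) rather than called.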
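The following is a minimal sketch of the recommended replacements. It assumes a JAX version that provides both the .at property and jax.numpy.index_exp; the array shape and values are arbitrary. It shows that an index built with jnp.index_exp can be passed to .at[...] exactly like an inline subscript.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 4))

# Preferred: subscript through the .at property and apply a functional update.
# JAX arrays are immutable, so .set returns a new array and x is unchanged.
y = x.at[1:2, ::2].set(1.0)

# When the index must be built and stored separately, jnp.index_exp
# packs the subscript into a tuple that .at[...] accepts directly.
idx = jnp.index_exp[1:2, ::2]
z = x.at[idx].set(1.0)

print(idx)
print(y)
```

Both updates write 1.0 into row 1, columns 0 and 2, so y and z are equal.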

This is a singleton object that overrides the __getitem__ method to return the index it is passed.

>>> jax.ops.index[1:2, 3, None, ..., ::2]
(slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2))
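The behavior described above can be reproduced in a few lines of plain Python. This is a sketch of the documented behavior, not JAX's actual _Indexable source; the class name _IndexHelper is hypothetical.

```python
class _IndexHelper:
    """Hypothetical stand-in mirroring the described behavior of jax.ops.index."""

    def __getitem__(self, index):
        # Python has already converted the subscript into an int, slice,
        # None, Ellipsis, or a tuple of these; return it unchanged.
        return index


# A single module-level instance plays the role of the singleton.
index = _IndexHelper()

print(index[1:2, 3, None, ..., ::2])
# (slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2))
```

Because the subscript is returned unchanged, a single index such as index[3] yields 3 rather than a one-element tuple. By contrast, jax.numpy.index_exp (like numpy.index_exp) always returns a tuple, so jax.numpy.index_exp[3] yields (3,).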