jax.ops.index

jax.ops.index = <jax._src.ops.scatter._Indexable object>

Helper object for building indexes for indexed update functions.

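This is a singleton object that defines __getitem__ to return the index it is passed, unchanged. As a result, jax.ops.index[idx] evaluates to idx itself: an int, a slice, None, Ellipsis, or a tuple of these.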

>>> jax.ops.index[1:2, 3, None, ..., ::2]
(slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2))
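The same trick can be sketched in plain Python. This is a minimal sketch under stated assumptions, not JAX's actual _Indexable implementation. The class _IndexHelper and the function index_update are hypothetical names used only for illustration, and index_update is a list-based analogue of an indexed update function rather than a JAX API. Before Python calls __getitem__, it has already turned a:b:c into a slice object and a comma-separated key into a tuple. Returning the argument unchanged is therefore all the helper has to do.

```python
class _IndexHelper:
    """Hypothetical stand-in for jax.ops.index: indexing it returns the index."""

    __slots__ = ()  # no per-instance state is needed

    def __getitem__(self, idx):
        # Python has already converted `a:b:c` to slice(a, b, c) and
        # `i, j` to a tuple, so the raw index expression is returned as-is.
        return idx


# One shared instance, mirroring the singleton design described above.
index = _IndexHelper()


def index_update(values, idx, new):
    """Hypothetical functional update on a list: copy, then assign at idx.

    The input list is left untouched, loosely mimicking how updates to
    immutable arrays return a new value instead of mutating in place.
    """
    out = list(values)
    out[idx] = new
    return out


print(index[1:2, 3, None, ..., ::2])
# (slice(1, 2, None), 3, None, Ellipsis, slice(None, None, 2))

xs = [0, 1, 2, 3, 4, 5]
ys = index_update(xs, index[1:5:2], [10, 30])  # positions 1 and 3
print(ys)  # [0, 10, 2, 30, 4, 5]
print(xs)  # [0, 1, 2, 3, 4, 5]
```

Capturing the index as a value lets it be built once, stored, and passed to an update function separately from the data. NumPy's numpy.s_ applies the same idea to ordinary NumPy indexing.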