jax.Array.take
abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)
Take elements from an array.

Refer to jax.numpy.take() for full documentation.
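The following is a short usage sketch. The array values are made up for illustration. The out-of-bounds behavior shown for mode="clip" and mode="fill" reflects jax.numpy.take in recent JAX releases, so check the jax.numpy.take page for the version you use. Unlike NumPy, JAX does not raise on out-of-bounds indices by default.

```python
import jax.numpy as jnp

x = jnp.array([[10, 20, 30],
               [40, 50, 60]])
idx = jnp.array([0, 2])

# axis=None (the default): x is flattened first, so the indices
# refer to positions in [10, 20, 30, 40, 50, 60].
flat = x.take(idx)

# axis=1: select columns 0 and 2 from every row.
cols = x.take(idx, axis=1)

# The method forwards to jax.numpy.take, so both spellings agree.
same = jnp.take(x, idx, axis=1)

# mode="clip": the out-of-bounds index 5 is clamped to the last valid index (2).
clipped = x.take(jnp.array([0, 5]), axis=1, mode="clip")

# mode="fill": out-of-bounds positions receive fill_value instead of raising.
filled = x.take(jnp.array([0, 5]), axis=1, mode="fill", fill_value=-1)

print(flat.tolist())     # [10, 30]
print(cols.tolist())     # [[10, 30], [40, 60]]
print(clipped.tolist())  # [[10, 30], [40, 60]]
print(filled.tolist())   # [[10, -1], [40, -1]]
```

Passing an explicit fill_value makes the result of mode="fill" predictable. Without it, JAX chooses a dtype-dependent default, such as NaN for floating-point arrays.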