jax.Array.take

abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)

Take elements from an array.

Refer to jax.numpy.take() for full documentation.
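A minimal usage sketch showing how axis changes what the indices refer to. It assumes, as stated above, that the method behaves like jax.numpy.take(); the array values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp

# A small 2x3 example array (values are arbitrary, for illustration only).
x = jnp.array([[10, 20, 30],
               [40, 50, 60]])
idx = jnp.array([2, 0])

# axis=None (the default): x is treated as flattened in row-major order,
# so idx refers to positions in [10, 20, 30, 40, 50, 60].
flat = x.take(idx)

# axis=1: pick columns 2 and 0 from every row.
cols = x.take(idx, axis=1)

# axis=0: pick row 1; the taken axis is replaced by the shape of the indices.
rows = x.take(jnp.array([1]), axis=0)

print(flat.tolist())  # [30, 10]
print(cols.tolist())  # [[30, 10], [60, 40]]
print(rows.tolist())  # [[40, 50, 60]]
```

The method form and the function form should give the same result, so `x.take(idx, axis=1)` can be read as shorthand for `jnp.take(x, idx, axis=1)`.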

Parameters:
  • self (Array)

  • indices (ArrayLike)

  • axis (int | None)

  • out (None)

  • mode (str | None)

  • unique_indices (bool)

  • indices_are_sorted (bool)

  • fill_value (StaticScalar | None)

Return type:

Array
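The mode, fill_value, unique_indices and indices_are_sorted parameters listed above mostly matter for out-of-bounds or performance-sensitive indexing. The sketch below is a minimal illustration, assuming the "clip" and "fill" modes described in the jax.numpy.take() documentation; the array and indices are made up for the example.

```python
import jax.numpy as jnp

x = jnp.array([[10, 20, 30],
               [40, 50, 60]])
oob = jnp.array([5])  # column index 5 is out of bounds for an axis of size 3

# mode="clip": out-of-bounds indices are clamped into range (here, to 2).
clipped = x.take(oob, axis=1, mode="clip")

# mode="fill": out-of-bounds positions are replaced by fill_value.
filled = x.take(oob, axis=1, mode="fill", fill_value=-1)

# unique_indices / indices_are_sorted are promises about `indices` that may
# allow faster code on some backends; only pass True when they actually hold,
# otherwise the results may be incorrect.
picked = x.take(jnp.array([0, 2]), axis=1,
                unique_indices=True, indices_are_sorted=True)

print(clipped.tolist())  # [[30], [60]]
print(filled.tolist())   # [[-1], [-1]]
print(picked.tolist())   # [[10, 30], [40, 60]]
```

Passing an explicit fill_value makes the out-of-bounds result predictable regardless of dtype; when it is omitted, jax.numpy.take() picks a dtype-dependent default (for example NaN for floating-point arrays).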