jax.lax.dynamic_index_in_dim

jax.lax.dynamic_index_in_dim(operand, index, axis=0, keepdims=True)

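Convenience wrapper around dynamic_slice to perform integer indexing along a single axis. Because it is built on dynamic_slice, index may be a runtime (traced) value rather than a Python constant.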
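As a minimal sketch of how this relates to dynamic_slice (assuming a recent JAX release; the array x and index i are illustrative), indexing one row of a 2-D array with dynamic_index_in_dim gives the same result as a dynamic_slice call whose slice size is 1 along the indexed axis and the full extent elsewhere:

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 3x4 array: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
x = jnp.arange(12).reshape(3, 4)
i = 1

# Select row i along axis 0; keepdims=True (the default) keeps a size-1 axis.
row = lax.dynamic_index_in_dim(x, i, axis=0)

# The same selection spelled with dynamic_slice: one start index per
# dimension, and slice sizes of 1 on the indexed axis, full size elsewhere.
row_via_slice = lax.dynamic_slice(x, (i, 0), (1, x.shape[1]))

print(row.shape)                              # (1, 4)
print(row.tolist())                           # [[4, 5, 6, 7]]
print(bool(jnp.array_equal(row, row_via_slice)))  # True
```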

Parameters
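  • operand (Array) – the array to index.

  • index (Array) – scalar integer index into dimension axis. It may be a traced value; an out-of-range index is clamped into bounds, following dynamic_slice semantics.

  • axis (int) – the dimension along which to index. Defaults to 0.

  • keepdims (bool) – if True (the default), the indexed dimension is kept with size 1; if False, it is removed from the result.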

Return type

Array
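The following sketch, again with illustrative names (take_column is a hypothetical helper, not part of JAX), shows keepdims=False dropping the indexed axis, the index being passed as a traced argument under jax.jit, and an out-of-range index being clamped:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative 3x4 array: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
x = jnp.arange(12).reshape(3, 4)

@jax.jit
def take_column(x, i):
    # i is a traced scalar inside jit, so no retrace happens when its value changes.
    # keepdims=False removes axis 1, so the result has shape (3,) instead of (3, 1).
    return lax.dynamic_index_in_dim(x, i, axis=1, keepdims=False)

col = take_column(x, 2)
print(col.shape)     # (3,)
print(col.tolist())  # [2, 6, 10]

# Index 10 is out of range for an axis of size 4; the start index is clamped
# so the slice stays in bounds, which selects the last column instead of failing.
print(take_column(x, 10).tolist())  # [3, 7, 11]
```

Because i is a runtime value, one compiled take_column serves every index; only the shapes and dtypes seen at trace time need to stay fixed.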