jax.lax.dynamic_index_in_dim

jax.lax.dynamic_index_in_dim(operand, index, axis=0, keepdims=True)

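Convenience wrapper around dynamic_slice that performs integer indexing: it extracts the entry at position index along dimension axis of operand. Because the index is passed through dynamic_slice, it may be a traced value known only at runtime (for example, inside jax.jit).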
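A minimal sketch of the basic use, assuming a small example array x and a hypothetical helper pick_row (neither is part of the API). It indexes with a traced value under jax.jit and checks the result against an equivalent call to lax.dynamic_slice.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)  # rows: [0..3], [4..7], [8..11]

@jax.jit
def pick_row(x, i):
    # Hypothetical helper: i is traced inside jit, so the row is selected
    # at runtime via dynamic_slice rather than a static Python index.
    return lax.dynamic_index_in_dim(x, i, axis=0, keepdims=False)

row = pick_row(x, 1)

# The same row built directly with dynamic_slice: start at (1, 0),
# take a (1, 4) window, then drop the leading size-1 axis.
via_slice = lax.dynamic_slice(x, (1, 0), (1, 4))[0]

print(row.tolist())  # [4, 5, 6, 7]
```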

Parameters
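  • operand (Any) – the array to index.

  • index (Any) – scalar integer index into dimension axis. It may be a traced value. As with dynamic_slice, an out-of-bounds index is clamped so that the slice stays within bounds rather than raising an error.

  • axis (int) – the dimension to index along (default 0). Must be a static Python int.

  • keepdims (bool) – if True (default), the indexed dimension is kept with size 1; if False, it is removed from the result.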
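A short sketch of how keepdims and out-of-bounds indices behave, again on an illustrative 3×4 array x (an assumption, not part of the API). The clamping shown in the last call follows from dynamic_slice's documented behavior of clamping start indices.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)

# keepdims=True (the default): the indexed axis stays, with size 1.
kept = lax.dynamic_index_in_dim(x, 2, axis=1)
print(kept.shape)  # (3, 1)

# keepdims=False: the indexed axis is removed, leaving column 2 as a 1-D array.
dropped = lax.dynamic_index_in_dim(x, 2, axis=1, keepdims=False)
print(dropped.tolist())  # [2, 6, 10]

# Out-of-bounds index: axis 0 has size 3, so index 10 is clamped to 2
# and the last row is returned instead of raising an error.
clamped = lax.dynamic_index_in_dim(x, jnp.array(10), axis=0, keepdims=False)
print(clamped.tolist())  # [8, 9, 10, 11]
```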

Return type

Any