jax.lax.dynamic_update_index_in_dim

jax.lax.dynamic_update_index_in_dim(operand, update, index, axis)

Convenience wrapper around dynamic_update_slice() that overwrites a slice of size 1 along a single axis of operand. The slice starts at index along axis and at 0 along every other axis.
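A minimal sketch of typical calls, assuming a recent JAX install. It writes one element of a 1-D array and one column of a 2-D array, then checks the 2-D result against the equivalent dynamic_update_slice() call whose start indices are 0 on every axis except axis. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# 1-D operand: overwrite the single element at position 2 along axis 0.
x = jnp.zeros(6)
y = lax.dynamic_update_index_in_dim(x, jnp.array([1.0]), 2, axis=0)
# y holds [0, 0, 1, 0, 0, 0]

# 2-D operand: overwrite column 1 (axis=1) with an update of shape (3, 1).
m = jnp.zeros((3, 4))
col = jnp.ones((3, 1))
m2 = lax.dynamic_update_index_in_dim(m, col, 1, axis=1)

# Equivalent dynamic_update_slice call: start index 1 on axis 1, 0 elsewhere.
m3 = lax.dynamic_update_slice(m, col, (0, 1))
```

The wrapper only saves building the full start-index tuple by hand; both calls produce the same array.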

Parameters
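  • operand (Any) – the array to update.

  • update (Any) – the new values to write; must have the same dtype as operand.

  • index (Any) – the integer position along axis at which to write update; may be a traced (dynamic) scalar.

  • axis (int) – the axis along which to write the update.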
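Because index is an ordinary array argument, it can be a traced value such as the counter of lax.fori_loop. The sketch below uses a hypothetical helper, fill_rows, that writes row i on iteration i. It also relies on a detail read from the JAX source rather than a documented guarantee: when update has one fewer dimension than operand, the wrapper inserts a size-1 axis at position axis before delegating to dynamic_update_slice().

```python
import jax
import jax.numpy as jnp
from jax import lax

def fill_rows(n, width):
    """Hypothetical helper: build an (n, width) int32 array whose row i is all i."""
    def body(i, acc):
        # Row update of shape (width,): one fewer dimension than acc,
        # so the wrapper expands it to shape (1, width) along axis 0.
        row = jnp.full((width,), i, dtype=acc.dtype)
        # i is a traced loop counter here, i.e. a dynamic index.
        return lax.dynamic_update_index_in_dim(acc, row, i, axis=0)

    init = jnp.zeros((n, width), dtype=jnp.int32)
    return lax.fori_loop(0, n, body, init)

# n and width determine shapes, so they are marked static for jit.
out = jax.jit(fill_rows, static_argnums=(0, 1))(4, 3)
```

Passing a row of shape (1, width) instead would give the same result; the rank-reduced form just avoids an explicit reshape.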

Return type

Any
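In practice the return value is a new array with the same shape and dtype as operand; JAX arrays are immutable, so operand itself is left unchanged. The sketch below checks this and also shows an out-of-range index being clamped so that the size-1 slice still fits. The clamping is inherited from dynamic_update_slice(), which adjusts start indices rather than raising, so treat it as that function's behavior, not something this wrapper adds.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5.0)        # float32 [0., 1., 2., 3., 4.]
update = jnp.array([9.0])  # same dtype (float32) as x

# Write 9.0 at position 1; the result is a fresh array.
y = lax.dynamic_update_index_in_dim(x, update, 1, axis=0)

# Index 10 is past the end; the start index is clamped to 4 (the last
# position where a size-1 slice fits) instead of raising an error.
z = lax.dynamic_update_index_in_dim(x, update, 10, axis=0)
```

If silent clamping is undesirable, validate index before the call (for example with jax.experimental.checkify) rather than relying on an error from this function.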