jax.lax.dynamic_update_index_in_dim

jax.lax.dynamic_update_index_in_dim(operand, update, index, axis)

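Convenience wrapper around dynamic_update_slice() that overwrites the size-1 slice of operand at position index along a single axis with update. Along every other axis, the update starts at index 0.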
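To make the relationship concrete, here is a rough sketch that builds the same result by hand from the public jax.lax.dynamic_update_slice API: give update a size-1 axis at axis if it lacks one, then start the write at index along axis and at 0 along all other axes. The helper name update_index_by_hand is hypothetical and not part of JAX; it is only meant to illustrate the semantics, not to reproduce JAX's internal implementation.

```python
import jax.numpy as jnp
from jax import lax


def update_index_by_hand(operand, update, index, axis):
    """Hypothetical re-implementation, for illustration only."""
    update = jnp.asarray(update, dtype=operand.dtype)
    if update.ndim + 1 == operand.ndim:
        # Insert the missing size-1 axis so update has the same rank as operand.
        update = jnp.expand_dims(update, axis)
    # Start at `index` along `axis`, and at 0 along every other axis.
    start = [0] * operand.ndim
    start[axis] = index
    return lax.dynamic_update_slice(operand, update, start)


x = jnp.zeros((3, 4))
row = jnp.arange(4.0)
a = update_index_by_hand(x, row, 2, axis=0)
b = lax.dynamic_update_index_in_dim(x, row, 2, axis=0)
print(bool(jnp.array_equal(a, b)))  # True
```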

Parameters:
  • operand (Array | ndarray) – the array to update.

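  • update (jax.typing.ArrayLike) – the new values to write into operand. It must either have the same rank as operand with size 1 along axis, or have one fewer dimension, in which case a size-1 axis is inserted at axis.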

  • index (jax.typing.ArrayLike) – a scalar integer index along axis. It may be a traced value, for example inside jax.jit.

  • axis (int) – the axis along which to apply the update.

Returns:

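A new array with the same shape and dtype as operand, in which the slice at index along axis has been replaced by update. operand itself is not modified.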

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_index_in_dim
>>> x = jnp.zeros(6)
>>> y = 1.0
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
>>> y = jnp.array([1.0])
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)

If index is out of bounds, no error is raised; instead it is clamped so that the update stays inside operand (here, to position 5 for an axis of length 6):

>>> dynamic_update_index_in_dim(x, y, 10, axis=0)
Array([0., 0., 0., 0., 0., 1.], dtype=float32)
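Because the index is dynamic, it does not have to be known at trace time. The following sketch, assuming a hypothetical jitted helper named write_at, passes the index as a regular argument to a jax.jit-compiled function, so it is a traced value inside the function; the same clamping applies at run time.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def write_at(x, value, i):
    # `i` is traced here, so the write position is only known at run time.
    return lax.dynamic_update_index_in_dim(x, value, i, axis=0)


x = jnp.zeros(6)
print(write_at(x, 1.0, 2).tolist())   # [0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
print(write_at(x, 1.0, 10).tolist())  # [0.0, 0.0, 0.0, 0.0, 0.0, 1.0] (clamped to 5)
```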

Here is an example of a two-dimensional dynamic index update:

>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones(4)
>>> dynamic_update_index_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 1.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
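The same call can target a column instead of a row by choosing axis=1. In this sketch, col has shape (3,), one fewer dimension than x, so it is treated as shape (3, 1) and written into column 2.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((3, 4))
col = jnp.array([1.0, 2.0, 3.0])  # one value per row of x

# With axis=1, col is expanded to shape (3, 1) and written into column 2.
out = lax.dynamic_update_index_in_dim(x, col, 2, axis=1)
print(out.tolist())
# [[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 3.0, 0.0]]
```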

The remaining axes of update need not match the corresponding dimensions of operand; they only need to be no larger. Along those axes, the update is written starting at index 0:

>>> y = jnp.ones((1, 3))
>>> dynamic_update_index_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
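Since the offset along the other axes is always 0, a shorter update lands at the start of the row. If a different offset is needed, one option is to call lax.dynamic_update_slice directly with explicit start indices for every axis. The sketch below compares the two on the same x and y.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))
y = jnp.ones((1, 3))

# Row 1, columns 0..2: the non-indexed axis always starts at 0.
a = lax.dynamic_update_index_in_dim(x, y, 1, axis=0)
# Row 1, columns 1..3: dynamic_update_slice takes a start index per axis.
b = lax.dynamic_update_slice(x, y, (1, 1))

print(a[1].tolist())  # [1.0, 1.0, 1.0, 0.0]
print(b[1].tolist())  # [0.0, 1.0, 1.0, 1.0]
```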