jax.lax.dynamic_update_index_in_dim
- jax.lax.dynamic_update_index_in_dim(operand, update, index, axis)
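Convenience wrapper around dynamic_update_slice() to update a slice of size 1 in a single axis.
- Parameters:
  - operand: the array to update.
  - update: the values to write into operand at position index along axis. This can be a scalar or array with one fewer dimension than operand, or an array with the same number of dimensions as operand and size 1 along axis.
  - index: a scalar index along axis.
  - axis: the axis along which the update is written.
- Returns:
  The updated array
- Return type:
  Array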
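Because the function is described as a wrapper around dynamic_update_slice(), its result can be cross-checked against a direct call. The sketch below assumes an update with one fewer dimension than operand. That update first gets a size-1 axis at position axis, and the start index is 0 along every other axis. The helper name `manual_update_index_in_dim` is hypothetical and exists only for illustration. It is not part of JAX.

```python
import jax.numpy as jnp
from jax import lax


def manual_update_index_in_dim(operand, update, index, axis):
    """Hypothetical re-derivation of dynamic_update_index_in_dim (illustration only).

    Assumes `update` has exactly one fewer dimension than `operand`.
    """
    # Insert a size-1 axis at position `axis` so the update's rank matches the operand.
    update = jnp.expand_dims(update, axis)
    # Start at `index` along `axis` and at 0 along every other axis.
    start_indices = [0] * operand.ndim
    start_indices[axis] = index
    return lax.dynamic_update_slice(operand, update, start_indices)


x = jnp.zeros((3, 4))
row = jnp.arange(4.0)
expected = lax.dynamic_update_index_in_dim(x, row, 1, axis=0)
result = manual_update_index_in_dim(x, row, 1, axis=0)
print(jnp.array_equal(result, expected))  # True
```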
Examples
>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_index_in_dim
>>> x = jnp.zeros(6)
>>> y = 1.0
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
A length-1 array update gives the same result:
>>> y = jnp.array([1.0])
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
If the specified index is out of bounds, the index will be clipped to the valid range:
>>> dynamic_update_index_in_dim(x, y, 10, axis=0)
Array([0., 0., 0., 0., 0., 1.], dtype=float32)
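The clipping matters most when the index is a traced value inside jax.jit, where it cannot be bounds-checked at trace time. The following sketch assumes a one-dimensional float32 operand and only non-negative indices, because negative indices are not covered here. It checks that the result matches clipping the index to the range 0 to n - 1 with jnp.clip. The function name `set_one` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def set_one(x, i):
    # Hypothetical helper: `i` is traced here, so clipping happens at run time.
    return lax.dynamic_update_index_in_dim(x, 1.0, i, axis=0)


x = jnp.zeros(6)
for i in range(12):
    # For a size-1 update along an axis of length n, valid start indices are 0..n-1.
    clipped = jnp.clip(i, 0, x.shape[0] - 1)
    assert jnp.array_equal(set_one(x, i), set_one(x, clipped))

print(set_one(x, 10))  # [0. 0. 0. 0. 0. 1.]
```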
Here is an example of a two-dimensional dynamic index update:
>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones(4)
>>> dynamic_update_index_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 1.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
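Changing axis selects which dimension the index refers to. In this sketch, the operand and update values are chosen only for illustration. With axis=1, a one-dimensional update of length 4 is written as column 2 of a 4×4 operand instead of as a row:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))
y = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]

# Along axis=1 the update is treated as shape (4, 1) and written at column 2.
out = lax.dynamic_update_index_in_dim(x, y, 2, axis=1)
print(out)
# [[0. 0. 1. 0.]
#  [0. 0. 2. 0.]
#  [0. 0. 3. 0.]
#  [0. 0. 4. 0.]]
```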
Note that the sizes of the other axes of update need not match the corresponding dimensions of operand. They may be smaller, in which case the update is written starting at index 0 along those axes:
>>> y = jnp.ones((1, 3))
>>> dynamic_update_index_in_dim(x, y, 1, 0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
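A typical use of a dynamic index is writing one row per iteration of a loop whose counter is traced. The sketch below builds a (4, 3) array inside lax.fori_loop. The function `fill_rows` and the row contents (row i holds i times a base vector) are hypothetical and serve only as an illustration:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def fill_rows(base):
    # Hypothetical example: row i of the result holds i * base.
    buf = jnp.zeros((4, base.shape[0]))

    def body(i, acc):
        # `i` is a traced loop counter; `row` has one fewer dimension than `acc`.
        row = i * base
        return lax.dynamic_update_index_in_dim(acc, row, i, axis=0)

    return lax.fori_loop(0, buf.shape[0], body, buf)


print(fill_rows(jnp.array([1.0, 2.0, 3.0])))
# [[0. 0. 0.]
#  [1. 2. 3.]
#  [2. 4. 6.]
#  [3. 6. 9.]]
```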