jax.lax.dynamic_update_slice_in_dim
- jax.lax.dynamic_update_slice_in_dim(operand, update, start_index, axis)
Convenience wrapper around dynamic_update_slice() to update a slice along a single axis.
- Parameters:
operand (Array | np.ndarray) – the array whose slice is replaced.
update (ArrayLike) – an array containing the new values to write into operand.
start_index (ArrayLike) – a single scalar index giving the start position of the update along axis.
axis (int) – the axis along which the update is applied.
- Return type:
Array
- Returns:
The updated array, with the same shape as operand.
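To make "convenience wrapper" concrete, the sketch below builds a list of start indices that is zero on every axis except axis and passes it to lax.dynamic_update_slice. The helper name update_slice_in_dim_sketch is hypothetical and this is an illustration under that assumption, not JAX's own source; it only aims to agree with the library function on ordinary inputs.

```python
import jax.numpy as jnp
from jax import lax


def update_slice_in_dim_sketch(operand, update, start_index, axis):
    """Hypothetical re-implementation, for illustration only."""
    start = jnp.asarray(start_index)
    # Every axis other than `axis` starts at 0; zeros_like keeps the same
    # integer dtype as start_index so all start indices agree.
    start_indices = [jnp.zeros_like(start)] * operand.ndim
    start_indices[axis] = start
    return lax.dynamic_update_slice(operand, update, start_indices)


x = jnp.zeros((4, 4))
y = jnp.ones((4, 2))
out = update_slice_in_dim_sketch(x, y, 1, axis=1)  # ones in columns 1 and 2
```

Because the other axes always start at 0, the update is anchored at the beginning of every axis except the one being sliced.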
Examples
>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_slice_in_dim
>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 1., 1., 0.], dtype=float32)
If the update does not fit in the array when placed at start_index, the start index is clamped so that the update fits:
>>> dynamic_update_slice_in_dim(x, y, 3, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice_in_dim(x, y, 5, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
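The clamping shown above can be written out explicitly: the largest valid start is operand.shape[axis] - update.shape[axis], and any larger start is pulled back to that value. The helper clamped_start below is hypothetical, and it assumes a non-negative start_index; how negative indices are handled is not modeled here. The loop compares the library call against a static slice update at the clamped position.

```python
import jax.numpy as jnp
from jax import lax


def clamped_start(operand_len, update_len, start_index):
    """Hypothetical helper: the start position actually used along the axis.

    Assumes start_index >= 0; negative indices are not modeled.
    """
    # Valid starts run from 0 to operand_len - update_len inclusive.
    return min(start_index, operand_len - update_len)


x = jnp.zeros(6)
y = jnp.ones(3)
for start in range(8):
    eff = clamped_start(x.shape[0], y.shape[0], start)
    # Static-slice update at the clamped start, for comparison.
    expected = x.at[eff:eff + y.shape[0]].set(y)
    actual = lax.dynamic_update_slice_in_dim(x, y, start, axis=0)
    assert jnp.array_equal(actual, expected)
```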
Here is an example of a two-dimensional slice update:
>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 4))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [0., 0., 0., 0.]], dtype=float32)
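Since start_index is an ArrayLike rather than a Python int, it can be a traced value inside jax.jit: the shapes of operand and update must be fixed, but the position can vary between calls. The sketch below updates along axis=1 instead of axis=0; the function name write_columns is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def write_columns(x, cols, start):
    # Under jit, `start` is a traced scalar; only the array shapes are
    # fixed, so the same compiled function serves every start position.
    return lax.dynamic_update_slice_in_dim(x, cols, start, axis=1)


x = jnp.zeros((3, 4))
cols = jnp.ones((3, 2))
left = write_columns(x, cols, 0)   # ones in columns 0 and 1
right = write_columns(x, cols, 2)  # ones in columns 2 and 3
```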
Note that the shape of the additional axes in update need not match the associated dimensions of the operand:
>>> y = jnp.ones((2, 3))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)
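In this case the smaller update is placed at index 0 along every axis other than axis, which is why the ones land in the leftmost three columns. The sketch below checks that against lax.dynamic_update_slice with an explicit (1, 0) start, and shows that calling lax.dynamic_update_slice directly with a full tuple of start indices is one way to offset the other axes as well.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))
y = jnp.ones((2, 3))

# In-dim form: row offset 1; the column offset is implicitly 0.
a = lax.dynamic_update_slice_in_dim(x, y, 1, axis=0)
# The same update written with a full tuple of start indices.
b = lax.dynamic_update_slice(x, y, (1, 0))
# The full-index form can also shift the columns: rows 1-2, columns 1-3.
c = lax.dynamic_update_slice(x, y, (1, 1))
```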