jax.lax.dynamic_update_slice
- jax.lax.dynamic_update_slice(operand, update, start_indices)
Wraps XLA’s DynamicUpdateSlice operator.
- Parameters:
operand (Array | np.ndarray) – the array to update.
update (ArrayLike) – an array with the same rank as operand, containing the new values to write onto operand.
start_indices (Array | Sequence[ArrayLike]) – a sequence of scalar integer indices, one per dimension of operand, giving the position at which update begins.
- Return type:
Array
- Returns:
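A new array with the same shape and dtype as operand. It equals operand everywhere except in the region starting at start_indices, which holds the values of update.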
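Because JAX arrays are immutable, the function returns a new array instead of modifying operand in place, and the start indices may be traced values inside a compiled function. The following is a minimal sketch of both points; the wrapper name `write_at` is hypothetical and only used for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def write_at(buf, vals, i):
    # Under jit, i is a traced scalar whose value is unknown at trace time;
    # dynamic_update_slice accepts it as a start index.
    return lax.dynamic_update_slice(buf, vals, (i,))

x = jnp.zeros(6)
y = jnp.ones(3)
out = write_at(x, y, 1)
print(out)  # [0. 1. 1. 1. 0. 0.]
print(x)    # [0. 0. 0. 0. 0. 0.]  (the input array is not modified)
```

The same compiled function can be called with a different start index without retracing, since i is an array argument rather than a static Python value.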
Examples
Here is an example of a one-dimensional slice update:
>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_slice
>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice(x, y, (2,))
Array([0., 0., 1., 1., 1., 0.], dtype=float32)
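For a start index where the update fits, the result is the same as an indexed update of the corresponding slice with JAX's functional `.at[...].set(...)` syntax. A short sketch comparing the two, assuming the in-bounds start index 2 from the example above:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(6)
y = jnp.ones(3)

# In-bounds start index: equivalent to setting the slice x[2:5].
via_lax = lax.dynamic_update_slice(x, y, (2,))
via_at = x.at[2:2 + y.shape[0]].set(y)
assert (via_lax == via_at).all()
print(via_at)  # [0. 0. 1. 1. 1. 0.]
```

The equivalence is only claimed for in-bounds indices; the clamping behavior described next has no direct counterpart in plain slice assignment.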
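If the update would extend past the end of the array at the given start index, the start index is clamped so that the update lies entirely within the array; no error is raised. Here a start index of 5 is clamped to 3, the largest valid start for a length-3 update into a length-6 array: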
>>> dynamic_update_slice(x, y, (3,))
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice(x, y, (5,))
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
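The clamping rule can be modeled per dimension as `min(max(start, 0), operand_dim - update_dim)`, which matches XLA's documented clamping for DynamicUpdateSlice. The sketch below is a NumPy reference under that assumption; `clamped_update` is a hypothetical helper, not part of JAX. It only checks non-negative start indices, because JAX may preprocess negative indices before clamping depending on the version.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def clamped_update(operand, update, start_indices):
    """Hypothetical NumPy reference for the clamping behavior (non-negative starts)."""
    out = np.array(operand, copy=True)
    slices = []
    for start, dim, size in zip(start_indices, operand.shape, update.shape):
        # Clamp so that the update lies entirely inside this dimension.
        s = min(max(start, 0), dim - size)
        slices.append(slice(s, s + size))
    out[tuple(slices)] = update
    return out

x = np.zeros(6, dtype=np.float32)
y = np.ones(3, dtype=np.float32)
# Compare the reference against lax for every start index from 0 to 6.
for i in range(7):
    expected = clamped_update(x, y, (i,))
    actual = np.asarray(lax.dynamic_update_slice(jnp.asarray(x), jnp.asarray(y), (i,)))
    assert np.array_equal(expected, actual), i
print(clamped_update(x, y, (5,)))  # [0. 0. 0. 1. 1. 1.]
```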
Here is an example of a two-dimensional slice update:
>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 2))
>>> dynamic_update_slice(x, y, (1, 2))
Array([[0., 0., 0., 0.],
       [0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [0., 0., 0., 0.]], dtype=float32)
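A typical use of a two-dimensional update is writing one row per iteration of a compiled loop, where the row index is the traced loop counter. A minimal sketch using `jax.lax.fori_loop`; the function name `fill_rows` and the choice to fill row i with the value i are illustrative assumptions.

```python
import jax.numpy as jnp
from jax import lax

def fill_rows(n_rows, n_cols):
    buf = jnp.zeros((n_rows, n_cols))

    def body(i, buf):
        # A (1, n_cols) row holding the value i, with the same dtype as buf.
        row = jnp.full((1, n_cols), i, dtype=buf.dtype)
        # Start indices (i, 0): write the row at row i, column 0.
        return lax.dynamic_update_slice(buf, row, (i, 0))

    return lax.fori_loop(0, n_rows, body, buf)

out = fill_rows(4, 3)
print(out)
# [[0. 0. 0.]
#  [1. 1. 1.]
#  [2. 2. 2.]
#  [3. 3. 3.]]
```

Because the loop body is traced once, `i` is an abstract value inside `body`, which is the situation dynamic_update_slice is designed for.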
See also
lax.dynamic_update_index_in_dim
lax.dynamic_update_slice_in_dim
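The two related functions are convenience wrappers that update along a single axis and use a start index of 0 on every other axis. The sketch below illustrates that relationship, assuming the signatures `dynamic_update_slice_in_dim(operand, update, start_index, axis)` and `dynamic_update_index_in_dim(operand, update, index, axis)`, and assuming the latter accepts an update that omits the indexed axis.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))

# A (2, 4) block of rows starting at row 1: same as dynamic_update_slice
# with start index 0 on the other axis.
rows = jnp.ones((2, 4))
a = lax.dynamic_update_slice_in_dim(x, rows, 1, axis=0)
b = lax.dynamic_update_slice(x, rows, (1, 0))
assert (a == b).all()

# A single column at index 2; the update here omits the indexed axis.
col = jnp.ones(4)
c = lax.dynamic_update_index_in_dim(x, col, 2, axis=1)
d = lax.dynamic_update_slice(x, col[:, None], (0, 2))
assert (c == d).all()
print(c)
# [[0. 0. 1. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 1. 0.]]
```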