jax.lax.dynamic_update_slice

jax.lax.dynamic_update_slice(operand, update, start_indices)

Wraps XLA’s DynamicUpdateSlice operator.
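To see this wrapping in practice, one can trace a call with `jax.make_jaxpr`; the traced program contains a `dynamic_update_slice` primitive, which JAX lowers to XLA's DynamicUpdateSlice. This is a minimal sketch. The exact jaxpr text varies between JAX versions, so no printed output is shown.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(6)
y = jnp.ones(3)

# Trace the call without running it; the start index becomes a traced value.
jaxpr = jax.make_jaxpr(lax.dynamic_update_slice)(x, y, (2,))

# The printed jaxpr includes a `dynamic_update_slice` equation
# (plus, depending on the JAX version, some index-normalization ops).
print(jaxpr)
```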

Parameters
  • operand (Any) – the array to update.

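  • update (Any) – an array containing the new values to write into operand. It must have the same number of dimensions as operand, and each of its dimensions must be no larger than the corresponding dimension of operand.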

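  • start_indices (Any) – a sequence of scalar integer indices, one per dimension of operand, giving the position at which update is written. These values may be dynamic (traced).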

Return type

Any

Returns

A copy of operand in which the region beginning at start_indices has been overwritten with update.
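Because JAX arrays are immutable, "a copy" is meant literally: the input operand is left unchanged and the result is a new array. A short check, using only `jax.numpy` and `jax.lax`:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(6)
y = jnp.ones(3)

z = lax.dynamic_update_slice(x, y, (2,))

# `z` holds the updated values; `x` is still all zeros.
print(z)
print(x)
```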

Examples

Here is an example of a one-dimensional slice update:

>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_slice
>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice(x, y, (2,))
DeviceArray([0., 0., 1., 1., 1., 0.], dtype=float32)

If the update would extend past the end of the array, the start index is clamped so that the update fits entirely inside it:

>>> dynamic_update_slice(x, y, (3,))
DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice(x, y, (5,))
DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32)
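The clamping rule can be written out as a plain NumPy reference. This is a sketch only: `dynamic_update_slice_reference` is a hypothetical helper, not part of JAX, and it assumes non-negative start indices. It clamps each start index to the range `[0, operand.shape[d] - update.shape[d]]`, which is how XLA's DynamicUpdateSlice keeps the update in bounds.

```python
import numpy as np


def dynamic_update_slice_reference(operand, update, start_indices):
    """Hypothetical NumPy reference for the clamping behaviour (non-negative indices only)."""
    operand = np.asarray(operand)
    update = np.asarray(update)
    result = operand.copy()  # never modify the input in place

    # Clamp each start index so that start + update size <= operand size.
    starts = [
        int(np.clip(s, 0, op_dim - up_dim))
        for s, op_dim, up_dim in zip(start_indices, operand.shape, update.shape)
    ]

    # Overwrite the (now in-bounds) region with the update.
    region = tuple(slice(s, s + size) for s, size in zip(starts, update.shape))
    result[region] = update
    return result


# Start index 5 is clamped to 6 - 3 = 3, matching the example above.
print(dynamic_update_slice_reference(np.zeros(6), np.ones(3), (5,)))
```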

Here is an example of a two-dimensional slice update:

>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 2))
>>> dynamic_update_slice(x, y, (1, 2))
DeviceArray([[0., 0., 0., 0.],
             [0., 0., 1., 1.],
             [0., 0., 1., 1.],
             [0., 0., 0., 0.]], dtype=float32)
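The "dynamic" in the name refers to the start indices, which may be traced values rather than Python constants. A common pattern, sketched here with a hypothetical `fill_rows` helper, writes one row per iteration of `lax.fori_loop`, where the loop counter is only known at run time:

```python
import jax
import jax.numpy as jnp
from jax import lax


def fill_rows(num_rows, width):
    """Hypothetical helper: write row i (filled with the value i) into a zero buffer."""
    buf = jnp.zeros((num_rows, width))

    def body(i, buf):
        # `i` is a traced loop counter, so Python slice syntax with `i` as a
        # bound cannot be used (slice bounds must be static); the start
        # indices of dynamic_update_slice may be traced.
        row = jnp.full((1, width), i, dtype=buf.dtype)
        return lax.dynamic_update_slice(buf, row, (i, 0))

    return lax.fori_loop(0, num_rows, body, buf)


# Shapes must be static under jit, so both arguments are marked static.
fill_rows_jit = jax.jit(fill_rows, static_argnums=(0, 1))

print(fill_rows_jit(3, 2))
```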