- jax.lax.scatter(operand, scatter_indices, updates, dimension_numbers, *, indices_are_sorted=False, unique_indices=False, mode=None)
Wraps XLA’s Scatter operator, in which values of operand are replaced by the corresponding values in updates.
If multiple updates are performed to the same index of operand, they may be applied in any order.
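The ordering caveat can be seen with two updates that target the same index. This is a minimal sketch that assumes a 1-D float operand and scalar updates. The dimension numbers shown are one valid way to describe a scalar-per-index scatter, not the only one. For contrast it also calls `lax.scatter_add`, the summing variant in `jax.lax`, which accumulates duplicates instead of picking one.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros(5)
# Two scatter points, both aimed at index 2; the trailing dim holds the index vector.
indices = jnp.array([[2], [2]])
updates = jnp.array([1.0, 2.0])

# Each update is a scalar: updates has no window dims, operand dim 0 gets a
# size-1 window that is not present in updates, and index component 0
# addresses operand dim 0.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)

replaced = lax.scatter(operand, indices, updates, dnums)
summed = lax.scatter_add(operand, indices, updates, dnums)

# replaced[2] is 1.0 or 2.0 depending on the order the backend applies updates;
# summed[2] accumulates both updates.
print(replaced[2], summed[2])
```

Code that needs a deterministic result for repeated indices should either deduplicate the indices first or use a combining variant such as `lax.scatter_add`.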
The semantics of scatter are complicated, and its API might change in the future. For most use cases, you should prefer the jax.numpy.ndarray.at property on JAX arrays, which uses the familiar NumPy indexing syntax.
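As a rough illustration of the relationship, the sketch below writes the same 1-D update twice: once with the recommended `.at[...].set(...)` form and once with `lax.scatter`. The operand, index, and update values are illustrative assumptions. The key low-level detail is that `scatter_indices` carries a trailing index-vector dimension, so a flat index array of shape `(2,)` becomes `(2, 1)`.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros(5)
idx = jnp.array([1, 3])
vals = jnp.array([10.0, 20.0])

# High-level form: NumPy-style indexed update that returns a new array.
via_at = operand.at[idx].set(vals)

# Low-level form with explicit dimension numbers.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),              # each update is a scalar
    inserted_window_dims=(0,),          # operand dim 0 has a size-1 window omitted from updates
    scatter_dims_to_operand_dims=(0,),  # index component 0 addresses operand dim 0
)
via_lax = lax.scatter(operand, idx[:, None], vals, dnums)

print(via_at)   # values: 0, 10, 0, 20, 0
print(via_lax)  # same values
```

Most of the complexity sits in `ScatterDimensionNumbers`, which `.at` derives for you from the indexing expression.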
- Parameters
operand (ArrayLike) – an array to which the scatter should be applied.
scatter_indices (ArrayLike) – an array that gives the indices in operand to which each update in updates should be applied.
updates (ArrayLike) – the updates that should be scattered onto operand.
dimension_numbers (ScatterDimensionNumbers) – a lax.ScatterDimensionNumbers object that describes how the dimensions of operand, scatter_indices, updates and the output relate.
indices_are_sorted (bool) – whether scatter_indices is known to be sorted. If True, this may improve performance on some backends.
unique_indices (bool) – whether the elements to be updated in operand are guaranteed not to overlap with each other. If True, this may improve performance on some backends. JAX does not check this promise: if the updated elements overlap when unique_indices is True, the behavior is undefined.
mode (str | GatherScatterMode | None) – how to handle indices that are out of bounds: when set to ‘clip’, indices are clamped so that the slice is within bounds; when set to ‘fill’ or ‘drop’, out-of-bounds updates are dropped. The behavior for out-of-bounds indices when set to ‘promise_in_bounds’ is implementation-defined.
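The difference between the out-of-bounds modes can be sketched with a single index that falls outside a length-5 operand. The values are illustrative assumptions, and `mode` is passed as a plain string. With ‘drop’ the offending update is discarded. With ‘clip’ the index is clamped to the last valid start position for a size-1 window, here 4.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros(5)
# Index 7 is out of bounds for a length-5 operand.
indices = jnp.array([[1], [7]])
updates = jnp.array([10.0, 20.0])

# Scalar-per-index scatter along operand dim 0.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)

# 'drop': the update aimed at index 7 is discarded.
dropped = lax.scatter(operand, indices, updates, dnums, mode="drop")
# 'clip': index 7 is clamped to 4, so that update lands in the last slot.
clipped = lax.scatter(operand, indices, updates, dnums, mode="clip")

print(dropped)  # values: 0, 10, 0, 0, 0
print(clipped)  # values: 0, 10, 0, 0, 20
```

‘promise_in_bounds’ is omitted from the sketch on purpose: its behavior for an index like 7 is implementation-defined, so there is no portable result to show.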
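- Returns
An array with the same shape and dtype as operand, in which the positions addressed by scatter_indices hold the values from updates. Those values replace the original entries of operand rather than being added to them.
- Return type
Array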
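To make the return value concrete, the sketch below overwrites whole rows of a 2-D operand and compares the result with `lax.scatter_add`, which does add. The shapes and values are illustrative assumptions. Because the indices are sorted and distinct, passing `indices_are_sorted=True` and `unique_indices=True` is a truthful promise here. The example also checks that the input array is left unchanged, since JAX arrays are immutable.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.ones((3, 4))
# Target rows 0 and 2.
indices = jnp.array([[0], [2]])
# One full replacement row per index.
updates = jnp.full((2, 4), 5.0)

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(1,),            # updates dim 1 is the row window (length 4)
    inserted_window_dims=(0,),          # operand dim 0 has a size-1 window omitted from updates
    scatter_dims_to_operand_dims=(0,),  # the index selects the row
)

out = lax.scatter(operand, indices, updates, dnums,
                  indices_are_sorted=True, unique_indices=True)
added = lax.scatter_add(operand, indices, updates, dnums)

print(out.shape, out.dtype)    # matches operand: (3, 4) float32
print(out[0, 0], added[0, 0])  # replaced value vs. summed value
print(operand[0, 0])           # the input array itself is unchanged
```

Here `out` has rows 0 and 2 set to 5.0 and row 1 left at 1.0, while `added` has 6.0 in those rows.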