jax.lax.scatter_add

jax.lax.scatter_add(operand, scatter_indices, updates, dimension_numbers, *, indices_are_sorted=False, unique_indices=False, mode=None)

Scatter-add operator.

Wraps XLA’s Scatter operator, using addition to combine each update with the corresponding value of operand.

The semantics of scatter are complicated, and its API might change in the future. For most use cases, prefer the jax.numpy.ndarray.at property of JAX arrays, which uses familiar NumPy indexing syntax (for example, x.at[idx].add(y)).
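The sketch below contrasts the two spellings on a 1-D array, assuming a JAX version whose lax.ScatterDimensionNumbers accepts the update_window_dims, inserted_window_dims and scatter_dims_to_operand_dims fields by keyword. Index 1 appears twice, so both of its updates are added. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros(5, dtype=jnp.float32)
# Each row of `indices` is one index vector into `operand`; index 1 appears twice.
indices = jnp.array([[1], [3], [1]])
updates = jnp.array([10.0, 20.0, 30.0], dtype=jnp.float32)

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),              # each update is a scalar: no window dims
    inserted_window_dims=(0,),          # operand dim 0 is indexed, window size 1
    scatter_dims_to_operand_dims=(0,),  # index component 0 addresses operand dim 0
)

out = lax.scatter_add(operand, indices, updates, dnums)

# The same result with NumPy-style indexing via the .at property.
out_at = operand.at[jnp.array([1, 3, 1])].add(updates)
# out == [0., 40., 0., 20., 0.]  (10 + 30 accumulate at index 1)
```

For simple cases like this one, the .at form avoids spelling out dimension numbers by hand.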

Parameters
  • operand (Any) – an array to which the scatter should be applied

  • scatter_indices (Any) – an array that gives the indices in operand to which each update in updates should be applied.

  • updates (Any) – the updates that should be scattered onto operand.

  • dimension_numbers (ScatterDimensionNumbers) – a lax.ScatterDimensionNumbers object that describes how the dimensions of operand, scatter_indices, updates and the output relate.

  • indices_are_sorted (bool) – whether scatter_indices is known to be sorted. If True, this may improve performance on some backends.

  • unique_indices (bool) – whether the indices to be updated in operand are guaranteed not to overlap with each other. If True, this may improve performance on some backends.

  • mode (Union[str, GatherScatterMode, None]) – how to handle indices that are out of bounds: when set to ‘clip’, indices are clamped so that the slice is within bounds; when set to ‘fill’ or ‘drop’, out-of-bounds updates are dropped. The behavior for out-of-bounds indices when set to ‘promise_in_bounds’ is implementation-defined.
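As a sketch of how dimension_numbers and mode interact, the example below adds whole rows into a 4×3 matrix, with one row index (7) out of bounds. It assumes the same keyword form of lax.ScatterDimensionNumbers as above; the arrays are made up for illustration. With ‘drop’ the out-of-bounds row update is discarded, while with ‘clip’ the index is clamped to the last valid row.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros((4, 3), dtype=jnp.float32)
# Three row indices; 7 is out of bounds for an operand with 4 rows.
indices = jnp.array([[0], [2], [7]])
updates = jnp.ones((3, 3), dtype=jnp.float32)  # one full row per index

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(1,),            # updates dim 1 spans operand dim 1 (a row)
    inserted_window_dims=(0,),          # operand dim 0 is indexed, window size 1
    scatter_dims_to_operand_dims=(0,),  # index component 0 addresses operand dim 0
)

dropped = lax.scatter_add(operand, indices, updates, dnums, mode="drop")
clipped = lax.scatter_add(operand, indices, updates, dnums, mode="clip")
# dropped: rows 0 and 2 are ones; the update for index 7 is discarded.
# clipped: index 7 is clamped to 3, so rows 0, 2 and 3 are ones.
```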

Return type

Any

Returns

An array with the same shape and dtype as operand, containing the sum of operand and the scattered updates; updates that map to the same location are all added.
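A small check of this return value, under the same assumptions as the earlier sketches: scattering onto a nonzero operand gives the same result as adding the operand to the updates scattered onto zeros. Because these illustrative indices are sorted and non-overlapping, the indices_are_sorted and unique_indices hints can be set safely here.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(5, dtype=jnp.float32)     # [0., 1., 2., 3., 4.]
indices = jnp.array([[0], [2], [4]])           # sorted and unique
updates = jnp.array([100.0, 200.0, 300.0], dtype=jnp.float32)

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)

# The indices really are sorted and non-overlapping, so both hints are valid.
out = lax.scatter_add(operand, indices, updates, dnums,
                      indices_are_sorted=True, unique_indices=True)

# Same updates scattered onto zeros, then added to operand.
scattered = lax.scatter_add(jnp.zeros_like(operand), indices, updates, dnums)
# out == operand + scattered == [100., 1., 202., 3., 304.]
# `operand` itself is unchanged: JAX arrays are immutable.
```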