jax.lax.gather

jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)

Gather operator.

Wraps XLA’s Gather operator.

The semantics of gather are complicated, and its API might change in the future. For most use cases, you should prefer NumPy-style indexing (e.g., x[:, (1,4,7), …]) rather than using gather directly.
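The following is a minimal sketch of the relationship between the two styles. The array `x` and the index vector `idx` are illustrative. The example selects whole rows with NumPy-style indexing and then reproduces the same result with `lax.gather`. The dimension numbers shown are one valid configuration for this particular case, not the only one.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.0).reshape(4, 3)  # illustrative operand, shape (4, 3)
idx = jnp.array([2, 0, 3])          # illustrative row indices

# NumPy-style indexing: the preferred way to express this.
rows_np = x[idx]  # shape (3, 3)

# The same selection written directly with lax.gather.
dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),           # output dim 1 holds the columns of each slice
    collapsed_slice_dims=(0,),  # the size-1 row dimension of each slice is dropped
    start_index_map=(0,),       # each index vector addresses operand dim 0
)
rows_lax = lax.gather(
    x,
    idx[:, None],               # shape (3, 1): the last dim is the index vector
    dimension_numbers=dnums,
    slice_sizes=(1, 3),         # one full row per index
)
```

Here each output row `rows_lax[i]` equals `x[idx[i], :]`, the same as `rows_np[i]`.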

Parameters:
  • operand (jax.typing.ArrayLike) – an array from which slices should be taken

  • start_indices (jax.typing.ArrayLike) – the indices at which slices should be taken. The last dimension of start_indices holds the index vectors.

  • dimension_numbers (GatherDimensionNumbers) – a lax.GatherDimensionNumbers object that describes how dimensions of operand, start_indices and the output relate.

  • slice_sizes (Sequence[int | Any]) – the size of each slice. Must be a sequence of non-negative integers with length equal to ndim(operand).

  • indices_are_sorted (bool) – whether start_indices is known to be sorted. If True, this may improve performance on some backends.

  • unique_indices (bool) – whether the elements gathered from operand are guaranteed not to overlap with each other. If True, this may improve performance on some backends. JAX does not check this promise: if the elements overlap, the behavior is undefined.

  • mode (str | GatherScatterMode | None) – how to handle indices that are out of bounds. When set to 'clip', indices are clamped so that the slice lies within bounds. When set to 'fill' or 'drop', gather returns a slice filled with fill_value for each affected slice. When set to 'promise_in_bounds', the behavior for out-of-bounds indices is implementation-defined.

  • fill_value – the value returned for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the minimum (most negative) value for signed integer types, the maximum value for unsigned integer types, and True for booleans.
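The following sketch contrasts the 'clip' and 'fill' modes on a single out-of-bounds index. The operand `x` and indices `idx` are illustrative. Index 5 lies past the end of a length-4 operand.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([10.0, 20.0, 30.0, 40.0])  # illustrative float operand
idx = jnp.array([[1], [5]])              # index 5 is out of bounds

# Gather single elements: each size-1 slice is collapsed away.
dnums = lax.GatherDimensionNumbers(
    offset_dims=(),
    collapsed_slice_dims=(0,),
    start_index_map=(0,),
)

# 'clip' clamps the start index to 3, so the slice stays in bounds.
clipped = lax.gather(x, idx, dnums, slice_sizes=(1,), mode="clip")
# 'fill' returns fill_value for the out-of-bounds slice.
filled = lax.gather(x, idx, dnums, slice_sizes=(1,), mode="fill", fill_value=-1.0)
# Without an explicit fill_value, a float operand is filled with NaN.
default_fill = lax.gather(x, idx, dnums, slice_sizes=(1,), mode="fill")

print(clipped)       # [20. 40.]
print(filled)        # [20. -1.]
print(default_fill)  # [20. nan]
```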

Returns:

An array containing the gather output.

Return type:

Array
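As a further illustration of how dimension_numbers and slice_sizes determine the output shape, the sketch below extracts 2x2 windows from a 2D array. The operand `x` and the start positions in `starts` are illustrative. The batch dimension of the output comes from `starts` (all but its last dimension). The window dimensions are placed at the positions listed in `offset_dims`.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16).reshape(4, 4)   # illustrative operand, shape (4, 4)
starts = jnp.array([[0, 0],        # (row, col) start of the first window
                    [1, 2]])       # (row, col) start of the second window

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1, 2),        # output dims 1 and 2 hold each 2x2 window
    collapsed_slice_dims=(),   # keep both slice dimensions
    start_index_map=(0, 1),    # starts[:, 0] -> operand dim 0, starts[:, 1] -> dim 1
)
patches = lax.gather(x, starts, dnums, slice_sizes=(2, 2))

print(patches.shape)  # (2, 2, 2)
print(patches[1])     # [[ 6  7]
                      #  [10 11]]
```

Each `patches[i]` matches the window that `lax.dynamic_slice(x, starts[i], (2, 2))` would return.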