jax.lax.gather
- jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)
Gather operator.
Wraps XLA’s Gather operator.
The semantics of gather are complicated, and its API might change in the future. For most use cases, you should prefer NumPy-style indexing (e.g., x[:, (1,4,7), …]) rather than using gather directly.
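To show how the two styles relate, the sketch below computes two selections both with NumPy-style indexing and with lax.gather: whole rows (x[rows]) and individual elements (x[r, c]). The operand, the index values, and the particular GatherDimensionNumbers settings are illustrative choices for this sketch, not the only way to express these selections.

```python
import jax.numpy as jnp
from jax import lax

# A small 5x3 example operand (illustrative values).
x = jnp.arange(15.0).reshape(5, 3)

# Case 1: select whole rows, equivalent to x[rows].
rows = jnp.array([1, 3])
row_dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),           # the kept slice dimension (length 3) becomes output dim 1
    collapsed_slice_dims=(0,),  # the size-1 row dimension of each slice is dropped
    start_index_map=(0,),       # each index vector addresses operand dimension 0
)
rows_gather = lax.gather(
    x,
    rows[:, None],              # shape (2, 1): the last dim holds the index vector
    dimension_numbers=row_dnums,
    slice_sizes=(1, 3),         # one full row per index
)

# Case 2: select individual elements, equivalent to x[r, c].
r = jnp.array([0, 2, 4])
c = jnp.array([1, 0, 2])
elem_dnums = lax.GatherDimensionNumbers(
    offset_dims=(),               # every slice dimension is collapsed
    collapsed_slice_dims=(0, 1),
    start_index_map=(0, 1),       # each index vector is (row, column)
)
elems_gather = lax.gather(
    x,
    jnp.stack([r, c], axis=-1),   # shape (3, 2)
    dimension_numbers=elem_dnums,
    slice_sizes=(1, 1),           # one element per index vector
)

print(jnp.array_equal(rows_gather, x[rows]))   # True
print(jnp.array_equal(elems_gather, x[r, c]))  # True
```

The indexing expressions are shorter and harder to get wrong, which is why they are the recommended path; lax.gather is mainly useful when you need direct control over the XLA operation.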
- Parameters:
operand (jax.typing.ArrayLike) – an array from which slices should be taken
start_indices (jax.typing.ArrayLike) – the indices at which slices should be taken
dimension_numbers (GatherDimensionNumbers) – a lax.GatherDimensionNumbers object that describes how dimensions of operand, start_indices and the output relate.
slice_sizes (Sequence[int | Any]) – the size of each slice. Must be a sequence of non-negative integers with length equal to ndim(operand).
indices_are_sorted (bool) – whether start_indices is known to be sorted. If True, this may improve performance on some backends.
unique_indices (bool) – whether the elements gathered from operand are guaranteed not to overlap with each other. If True, this may improve performance on some backends. JAX does not check this promise: if the elements overlap, the behavior is undefined.
mode (str | GatherScatterMode | None) – how to handle indices that are out of bounds: when set to 'clip', indices are clamped so that the slice is within bounds, and when set to 'fill' or 'drop', gather returns a slice full of fill_value for the affected slice. The behavior for out-of-bounds indices when set to 'promise_in_bounds' is implementation-defined.
fill_value – the fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the most negative representable value for signed integer types, the largest representable value for unsigned integer types, and True for booleans.
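The sketch below contrasts the 'clip' and 'fill' modes on a 1-D array. It uses slice_sizes to take length-2 windows and deliberately places one window past the end of the operand. The operand, the start positions, and the fill_value of -1.0 are illustrative assumptions; the expected results follow the mode descriptions above.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10.0)

# Three windows of length 2 starting at positions 0, 4 and 9.
# The window starting at 9 would need x[9:11], which runs past the end.
starts = jnp.array([[0], [4], [9]])   # shape (3, 1): one index vector per window

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),          # window elements go to output dim 1
    collapsed_slice_dims=(),   # keep the length-2 window dimension
    start_index_map=(0,),      # index vectors address operand dimension 0
)

# 'clip': the out-of-bounds start 9 is clamped to 8, so the window is x[8:10].
clipped = lax.gather(x, starts, dnums, slice_sizes=(2,), mode="clip")

# 'fill': the whole out-of-bounds window is replaced by fill_value.
filled = lax.gather(x, starts, dnums, slice_sizes=(2,), mode="fill", fill_value=-1.0)

print(clipped.tolist())  # [[0.0, 1.0], [4.0, 5.0], [8.0, 9.0]]
print(filled.tolist())   # [[0.0, 1.0], [4.0, 5.0], [-1.0, -1.0]]
```

Note that under 'fill' the entire affected window is replaced, including the element x[9] that was itself in bounds.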
- Returns:
An array containing the gather output.
- Return type: