- jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)
Wraps XLA’s Gather operator.
The semantics of gather are complicated, and its API might change in the future. For most use cases, you should prefer Numpy-style indexing (e.g., x[:, (1,4,7), …]), rather than using gather directly.
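To see how the two approaches relate, the sketch below gathers columns 1, 4 and 7 of a small (3, 10) array twice: once with NumPy-style indexing and once with `lax.gather` and a hand-built `lax.GatherDimensionNumbers`. The names `x`, `idx` and `dnums` are illustrative only. The two results should be identical.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(30).reshape(3, 10)

# Preferred: NumPy-style indexing selects columns 1, 4 and 7.
expected = x[:, jnp.array([1, 4, 7])]

# Equivalent direct gather. The last axis of `idx` is the index vector;
# each one-element vector is a start position along operand axis 1.
idx = jnp.array([[1], [4], [7]])
dnums = lax.GatherDimensionNumbers(
    offset_dims=(0,),           # the kept slice axis (operand axis 0) becomes output axis 0
    collapsed_slice_dims=(1,),  # operand axis 1 has slice size 1 and is dropped
    start_index_map=(1,),       # idx[..., 0] indexes operand axis 1
)
# Each slice is one full column (3 rows x 1 column). The leading axis of
# `idx` becomes the remaining output axis (axis 1).
out = lax.gather(x, idx, dnums, slice_sizes=(3, 1))

print(out.tolist())  # [[1, 4, 7], [11, 14, 17], [21, 24, 27]]
```

Writing the dimension numbers by hand is error-prone, which is why the indexing form is usually the better choice.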
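- Parameters
operand (ArrayLike) – an array from which slices should be taken.
start_indices (ArrayLike) – the indices at which slices should be taken.
dimension_numbers (GatherDimensionNumbers) – a lax.GatherDimensionNumbers object that describes how dimensions of operand, start_indices and the output relate.
slice_sizes (Shape) – the size of each slice. Must be a sequence of non-negative integers with length equal to ndim(operand).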
indices_are_sorted (bool) – whether start_indices is known to be sorted. If True, this may improve performance on some backends.
unique_indices (bool) – whether the elements gathered from operand are guaranteed not to overlap with each other. If True, this may improve performance on some backends. JAX does not check this promise: if the elements overlap, the behavior is undefined.
mode (str | GatherScatterMode | None) – how to handle indices that are out of bounds: when set to 'clip', indices are clamped so that the slice is within bounds, and when set to 'fill' or 'drop', gather returns a slice full of fill_value for the affected slice. The behavior for out-of-bounds indices when set to 'promise_in_bounds' is implementation-defined.
fill_value – the fill value to return for out-of-bounds slices when mode is 'fill' (or 'drop'). Ignored otherwise. Defaults to NaN for inexact types, the most negative representable value for signed integer types, the largest representable value for unsigned integer types, and True for booleans.
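The out-of-bounds handling of mode and fill_value can be illustrated with a minimal sketch, assuming a length-5 operand and one deliberately out-of-range index (7). The dimension numbers gather single elements, and the variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# Gather single elements from a 1-D array: each one-entry index vector maps
# to operand axis 0, and the size-1 slice axis is collapsed away.
dnums = lax.GatherDimensionNumbers(
    offset_dims=(),
    collapsed_slice_dims=(0,),
    start_index_map=(0,),
)
idx = jnp.array([[1], [7]])  # 7 is out of bounds for a length-5 axis

x = jnp.arange(5.0)  # float32: [0., 1., 2., 3., 4.]
clipped = lax.gather(x, idx, dnums, slice_sizes=(1,), mode="clip")
filled = lax.gather(x, idx, dnums, slice_sizes=(1,), mode="fill")
custom = lax.gather(x, idx, dnums, slice_sizes=(1,), mode="fill", fill_value=-1.0)

# A signed integer operand with the default fill value.
xi = jnp.arange(5, dtype=jnp.int32)
filled_int = lax.gather(xi, idx, dnums, slice_sizes=(1,), mode="fill")

print(clipped.tolist())     # [1.0, 4.0]  (start index 7 clamped to 4)
print(filled.tolist())      # [1.0, nan]  (default fill for a float dtype)
print(custom.tolist())      # [1.0, -1.0]
print(filled_int.tolist())  # [1, -2147483648]  (most negative int32)
```

With 'clip' the result is silently a valid but different element, whereas 'fill' makes out-of-range reads visible in the output.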
- Returns
An array containing the gather output.
- Return type
Array
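The shape of the output follows from dimension_numbers and slice_sizes: every axis of start_indices except the last (index-vector) axis becomes a batch axis of the output, and the uncollapsed slice axes are placed at the positions listed in offset_dims. As a hypothetical example, the sketch below extracts two 2x2 windows from a 4x4 matrix, giving an output of shape (2, 2, 2). It also passes indices_are_sorted and unique_indices, which is only valid because these particular start indices are in ascending order and the two windows do not overlap.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16).reshape(4, 4)

# Two (row, col) start positions; the last axis is the index vector.
starts = jnp.array([[0, 0], [2, 1]])
dnums = lax.GatherDimensionNumbers(
    offset_dims=(1, 2),       # both slice axes are kept, as output axes 1 and 2
    collapsed_slice_dims=(),  # no slice axis is collapsed
    start_index_map=(0, 1),   # starts[..., 0] -> operand axis 0, starts[..., 1] -> axis 1
)
windows = lax.gather(
    x, starts, dnums, slice_sizes=(2, 2),
    # Both hints are true for these inputs. JAX does not check the
    # unique_indices promise, so a false hint gives undefined behavior.
    indices_are_sorted=True, unique_indices=True,
)

print(windows.shape)     # (2, 2, 2)
print(windows.tolist())  # [[[0, 1], [4, 5]], [[9, 10], [13, 14]]]
```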