jax.lax.gather

jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes)

Gather operator.

Wraps XLA’s Gather operator.

The semantics of gather are complicated, and its API might change in the future. For most use cases, you should prefer NumPy-style indexing (e.g., x[:, (1,4,7), ...]) rather than using gather directly.
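To make the comparison concrete, the sketch below selects columns 1, 4 and 7 of a small matrix twice: once with NumPy-style indexing and once with lax.gather. The names x, idx and dnums are illustrative only. It assumes, per JAX's GatherDimensionNumbers convention, that the index vector lives in the last axis of start_indices, which is why idx has a trailing axis of size 1.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(30.0).reshape(3, 10)

# NumPy-style indexing: columns 1, 4 and 7 of every row.
expected = x[:, jnp.array([1, 4, 7])]

# Same result via lax.gather. The last axis of start_indices holds the
# index vector, so each row of `idx` is one (single-element) start index.
idx = jnp.array([[1], [4], [7]])
dnums = lax.GatherDimensionNumbers(
    offset_dims=(0,),           # output axis 0 holds the slice (all 3 rows)
    collapsed_slice_dims=(1,),  # the size-1 column axis of each slice is dropped
    start_index_map=(1,),       # idx[..., 0] indexes operand axis 1 (columns)
)
out = lax.gather(x, idx, dimension_numbers=dnums, slice_sizes=(3, 1))

print(out.shape)  # (3, 3)
print(bool(jnp.array_equal(out, expected)))  # True
```

Even for this simple selection, three dimension-number fields and a slice_sizes tuple must agree with each other, which is why the indexing form is usually easier to read and maintain.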

Parameters
  • operand (Any) – an array from which slices should be taken

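  • start_indices (Any) – the indices at which slices should be taken. The last dimension of start_indices holds the index vectors; to gather with scalar indices, add a trailing dimension of size 1.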

  • dimension_numbers (GatherDimensionNumbers) – a lax.GatherDimensionNumbers object that describes how dimensions of operand, start_indices and the output relate.

  • slice_sizes (Sequence[Union[int, Any]]) – the size of each slice. Must be a sequence of non-negative integers with length equal to ndim(operand).
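The following sketch illustrates how dimension_numbers and slice_sizes work together when each slice is a window rather than a single element. It extracts two 2x2 patches from a 4x4 array, starting at (row, col) positions (0, 0) and (2, 1). The variable names are illustrative, and the example assumes the trailing-axis index-vector convention for start_indices.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16).reshape(4, 4)

# Two start positions (row, col); the last axis is the index vector.
starts = jnp.array([[0, 0], [2, 1]])

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1, 2),       # output axes 1 and 2 hold each 2x2 window
    collapsed_slice_dims=(),  # keep both window axes in the output
    start_index_map=(0, 1),   # starts[..., 0] -> operand axis 0, starts[..., 1] -> axis 1
)
# slice_sizes has one entry per operand axis: here a 2x2 window.
patches = lax.gather(x, starts, dimension_numbers=dnums, slice_sizes=(2, 2))

print(patches.shape)  # (2, 2, 2)
print(bool(jnp.array_equal(patches[0], x[0:2, 0:2])))  # True
print(bool(jnp.array_equal(patches[1], x[2:4, 1:3])))  # True
```

Because no slice axis is collapsed, the output keeps one batch axis (one entry per start position) plus both window axes.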

Return type

Any

Returns

An array containing the gather output, with the same dtype as operand.
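The shape of the returned array follows from start_indices, dimension_numbers and slice_sizes. As a rough sketch, the hypothetical helper gather_output_shape below (not part of JAX) predicts that shape under two assumptions: the index vector is the last axis of start_indices, and any operand_batching_dims (present only in newer JAX versions, read here with getattr) are dropped from the output like collapsed_slice_dims. The sketch checks the prediction against an actual lax.gather call.

```python
import jax.numpy as jnp
from jax import lax


def gather_output_shape(indices_shape, dnums, slice_sizes):
    """Hypothetical helper: predict lax.gather's output shape.

    Assumes the index vector lives in the last axis of start_indices.
    """
    # Batch axes: every axis of start_indices except the trailing index-vector axis.
    batch = list(indices_shape[:-1])
    # Slice axes that do not appear in the output.
    dropped = set(dnums.collapsed_slice_dims) | set(
        getattr(dnums, "operand_batching_dims", ()))
    # Remaining slice sizes become the offset axes, in operand order.
    offset = [s for i, s in enumerate(slice_sizes) if i not in dropped]
    batch_it, offset_it = iter(batch), iter(offset)
    rank = len(batch) + len(offset)
    # Positions listed in offset_dims take slice sizes; the rest take batch sizes.
    return tuple(next(offset_it) if d in dnums.offset_dims else next(batch_it)
                 for d in range(rank))


x = jnp.zeros((5, 7, 9))
idx = jnp.zeros((4, 6, 2), dtype=jnp.int32)  # 4x6 batch of (axis 0, axis 2) starts
dnums = lax.GatherDimensionNumbers(
    offset_dims=(1, 3),
    collapsed_slice_dims=(0,),
    start_index_map=(0, 2),
)
slice_sizes = (1, 7, 3)
out = lax.gather(x, idx, dimension_numbers=dnums, slice_sizes=slice_sizes)

print(out.shape, gather_output_shape(idx.shape, dnums, slice_sizes))  # both (4, 7, 6, 3)
```

Here axis 0 of each slice is collapsed, so the two remaining slice sizes (7 and 3) land at output positions 1 and 3, and the 4x6 batch of start indices fills positions 0 and 2.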