jax.lax.index_take

jax.lax.index_take(src, idxs, axes)
Parameters
  • src (Array)
  • idxs (Array)
  • axes (Sequence[int])

Return type
  Array
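
The entry above gives only the signature, so the following is a minimal sketch of one way to call the function, not an authoritative description. It assumes, from the current implementation rather than from this page, that idxs is a sequence of equal-length 1-D integer arrays, one per entry of axes, and that the result gathers src at those coordinate tuples. The array contents and the names rows, cols, taken, and reference are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative 3x4 input; the values are made up for this example.
src = jnp.arange(12).reshape(3, 4)

# Assumption: one 1-D integer index array per axis listed in `axes`.
rows = jnp.array([0, 2, 1])
cols = jnp.array([1, 3, 0])

# `axes` is passed as a tuple so that it is hashable.
taken = lax.index_take(src, (rows, cols), axes=(0, 1))

# Under the assumption above, this should match ordinary integer-array indexing.
reference = src[rows, cols]
```

Comparing the result against ordinary indexing, as in the last line, is a cheap way to confirm the behaviour on the installed JAX version.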