jax.lax.index_take

jax.lax.index_take(src, idxs, axes)

Parameters:
  src (Array)
  idxs (Array)
  axes (Sequence[int])

Return type: Array
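The page lists only the signature. The sketch below is based on our reading of the source implementation, which lowers `index_take` to a `lax.gather`. Under that reading, `idxs` stacks one 1-D index vector per entry of `axes`, giving shape `(len(axes), N)`. The result has `N` leading entries, and any axes of `src` not listed in `axes` are kept as trailing dimensions. Treat this as an assumption and check it against your JAX version. The example uses only in-range indices because out-of-range handling is not documented here. The matrix and index values are illustrative, and `axes` is passed as a tuple.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative data: a 3x4 matrix holding 0..11.
src = jnp.arange(12).reshape(3, 4)

# Assumption: idxs has shape (len(axes), N); row k indexes axis axes[k].
# Take the N=2 scalar elements at positions (0, 1) and (2, 3).
pairs = lax.index_take(src, jnp.array([[0, 2], [1, 3]]), (0, 1))

# Take rows 2 and 0; axis 1 is not indexed, so it stays as a trailing dim.
rows = lax.index_take(src, jnp.array([[2, 0]]), (0,))

# Cross-check against ordinary NumPy-style advanced indexing.
assert jnp.array_equal(pairs, src[jnp.array([0, 2]), jnp.array([1, 3])])
assert jnp.array_equal(rows, src[jnp.array([2, 0])])

print(pairs)
print(rows)
```

For simple cases like these, `src[...]` advanced indexing or `jnp.take` is usually clearer. A direct `lax` call mainly helps when you want a single gather with a fixed set of indexed axes.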