jax.lax.index_take

jax.lax.index_take(src, idxs, axes)
Parameters
src
idxs
axes
Return type
Any
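The page gives no description of the parameters, so the sketch below rests on stated assumptions. It assumes `idxs` is a sequence of equal-length 1-D integer arrays, one per entry in `axes` (with `axes` sorted ascending). It also assumes the operation gathers `src` at the coordinates `(idxs[0][i], idxs[1][i], ...)` along those axes, wrapping each index modulo the axis size and keeping every other axis whole. `index_take_sketch` is a hypothetical name. It is built on the public `jax.lax.gather` primitive and is not the JAX source or a guaranteed match for it.

```python
import jax.numpy as jnp
from jax import lax


def index_take_sketch(src, idxs, axes):
    """Hypothetical illustration of an index_take-style gather (not the JAX source).

    Assumptions: idxs is a sequence of 1-D integer arrays of equal length N,
    one per entry in axes; axes is sorted ascending. Row i of the result is
    src at coordinates (idxs[0][i], idxs[1][i], ...) along `axes`, with all
    remaining axes of src kept in their original order.
    """
    # Stack the per-axis index vectors into an (N, len(axes)) index matrix.
    indices = jnp.stack([jnp.asarray(i) for i in idxs], axis=1)
    # Assumed behavior: wrap out-of-range (including negative) indices.
    sizes = jnp.array([src.shape[ax] for ax in axes])
    indices = indices % sizes
    # Take a size-1 slice along each indexed axis, the full extent elsewhere.
    slice_sizes = list(src.shape)
    for ax in axes:
        slice_sizes[ax] = 1
    # The indexed axes are collapsed away; the remaining axes of src become
    # output dims 1.., after the leading length-N dimension.
    offset_dims = tuple(range(1, src.ndim - len(axes) + 1))
    dnums = lax.GatherDimensionNumbers(
        offset_dims=offset_dims,
        collapsed_slice_dims=tuple(axes),
        start_index_map=tuple(axes),
    )
    return lax.gather(src, indices, dimension_numbers=dnums,
                      slice_sizes=tuple(slice_sizes))


# Usage: gather src[0, 2, :] and src[1, 0, :] from a (2, 3, 4) array.
src = jnp.arange(24).reshape(2, 3, 4)
out = index_take_sketch(src, (jnp.array([0, 1]), jnp.array([2, 0])), (0, 1))
# out has shape (2, 4); its rows are [8, 9, 10, 11] and [12, 13, 14, 15].
```

Under these assumptions the result matches NumPy-style advanced indexing such as `src[idxs[0], idxs[1]]` once indices are wrapped. Expressing it through `lax.gather` keeps it a single XLA gather with static slice sizes.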