jax.lax.index_take

jax.lax.index_take(src, idxs, axes)
Parameters:
src (Array)
idxs (Array)
axes (Sequence[int])
Return type:
Array
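A short usage sketch follows. It assumes `idxs` is passed as a tuple holding one 1-D integer index array per entry of `axes`, which is how JAX's own tests call this function. The result shapes in the comments follow from the `lax.gather` call that `index_take` is built on. The wrapping of negative indices and the NumPy cross-check at the end are illustrations, not documented guarantees.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# 3x4 source array where src[r, c] == 4 * r + c.
src = jnp.arange(12).reshape(3, 4)

# Assumption: idxs is a tuple with one 1-D index array per entry in axes.
# Take the elements at positions (0, 1) and (2, 3) along axes 0 and 1.
pairs = lax.index_take(src, (jnp.array([0, 2]), jnp.array([1, 3])), (0, 1))
print(pairs)  # [ 1 11]

# Indexing a single axis keeps the full extent of the other axis:
# rows 2 and 0 of src, so the result has shape (2, 4).
rows = lax.index_take(src, (jnp.array([2, 0]),), (0,))
print(rows.shape)  # (2, 4)

# Negative indices appear to wrap modulo the axis size, so -1 picks row 2.
last = lax.index_take(src, (jnp.array([-1]),), (0,))
print(last)  # [[ 8  9 10 11]]

# Illustrative cross-check against NumPy advanced indexing.
src_np = np.arange(12).reshape(3, 4)
assert np.array_equal(np.asarray(pairs), src_np[[0, 2], [1, 3]])
assert np.array_equal(np.asarray(rows), src_np[[2, 0]])
assert np.array_equal(np.asarray(last), src_np[[-1]])
```

For indices that are all in range, this behaves like NumPy advanced indexing on the listed axes, with any axes not in `axes` carried through whole.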