jax.numpy.array_split

jax.numpy.array_split(ary, indices_or_sections, axis=0)

Split an array into multiple sub-arrays.

LAX-backend implementation of array_split(). Original docstring below.

Please refer to the split documentation.
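The sketch below shows how array_split behaves, assuming JAX and NumPy are installed. It covers three cases: a section count that does not evenly divide the axis, an explicit list of split indices, and a non-default axis. The expected outputs in the comments follow NumPy's documented rule. When an axis of length l is split into n sections, the first l % n pieces get l // n + 1 elements and the rest get l // n. As an assumption, indices_or_sections is treated as a concrete Python value here, because shapes in JAX must be static under tracing.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(7)

# 3 sections do not evenly divide length 7: the first 7 % 3 = 1 piece
# gets an extra element, so the sizes are 3, 2, 2.
parts = jnp.array_split(x, 3)
print([p.tolist() for p in parts])  # [[0, 1, 2], [3, 4], [5, 6]]

# A sequence of indices splits at those positions: x[:2], x[2:5], x[5:].
pieces = jnp.array_split(x, [2, 5])
print([p.tolist() for p in pieces])  # [[0, 1], [2, 3, 4], [5, 6]]

# Split the 4 columns of a 3x4 array into 3 pieces along axis=1.
m = jnp.arange(12).reshape(3, 4)
cols = jnp.array_split(m, 3, axis=1)
print([c.shape for c in cols])  # [(3, 2), (3, 1), (3, 1)]

# The pieces match NumPy's own array_split on the same input.
np_parts = np.array_split(np.arange(7), 3)
print(all(np.array_equal(a, b) for a, b in zip(parts, np_parts)))  # True
```

Unlike split, array_split does not raise an error when the sections are uneven. That makes it the safer choice when the axis length is not known to be a multiple of the section count.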