jax.numpy.array_split

jax.numpy.array_split(ary, indices_or_sections, axis=0)

Split an array into multiple sub-arrays.

LAX-backend implementation of numpy.array_split().

Original docstring below.

Please refer to the split documentation. The only difference between these functions is that array_split allows indices_or_sections to be an integer that does not equally divide the axis. For an array of length l along the split axis that is divided into n sections, it returns l % n sub-arrays of size l//n + 1, followed by n - l % n sub-arrays of size l//n.
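As a rough illustration, the sketch below checks the sizing rule above against jnp.array_split. The helper expected_sizes is hypothetical, written only for this example, and is not part of the JAX API. The sketch also shows jnp.split rejecting the same uneven request, which in current JAX versions raises a ValueError.

```python
import jax.numpy as jnp

def expected_sizes(l, n):
    # Hypothetical helper encoding the documented rule:
    # l % n sub-arrays of size l // n + 1, then the rest of size l // n
    return [l // n + 1] * (l % n) + [l // n] * (n - l % n)

x = jnp.arange(10)
parts = jnp.array_split(x, 3)
sizes = [int(p.shape[0]) for p in parts]
print(sizes)  # [4, 3, 3]
assert sizes == expected_sizes(10, 3)

# jnp.split, by contrast, rejects a section count that does not divide the axis
try:
    jnp.split(x, 3)
except ValueError as err:
    print("split failed:", err)
```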

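Parameters:
  • ary (ArrayLike): the array to split.

  • indices_or_sections (int | Sequence[int] | ArrayLike): either the number of sections to split into, or a sequence of indices marking the split points along axis.

  • axis (int): the axis along which to split. Defaults to 0.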
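The short sketch below, assuming a small 3 x 4 integer matrix chosen only for illustration, contrasts the two forms of indices_or_sections when splitting along axis=1: a sequence of split indices, and an integer section count that does not divide the axis length.

```python
import jax.numpy as jnp

m = jnp.arange(12).reshape(3, 4)

# A sequence of indices gives explicit split points along the chosen axis:
# columns [0:1], [1:3] and [3:4]
a, b, c = jnp.array_split(m, [1, 3], axis=1)
print(a.shape, b.shape, c.shape)  # (3, 1) (3, 2) (3, 1)

# An integer that does not divide the axis length (4 columns into 3 sections)
# gives 4 % 3 = 1 section of width 2, then two sections of width 1
cols = jnp.array_split(m, 3, axis=1)
print([p.shape[1] for p in cols])  # [2, 1, 1]
```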

Return type:

list[Array]
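Because the result is a Python list whose length and element shapes depend on indices_or_sections, that argument has to be a concrete value rather than a traced one. The sketch below assumes it is passed to a jitted function as a static argument; split_and_sum is a hypothetical function written for this illustration. It also checks that concatenating the returned pieces recovers the input.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical example function; n is marked static because the number
# and shapes of the sub-arrays depend on it
@partial(jax.jit, static_argnums=1)
def split_and_sum(x, n):
    parts = jnp.array_split(x, n)
    return [p.sum() for p in parts]

x = jnp.arange(10)
parts = jnp.array_split(x, 4)

# The result is a plain Python list of jax Arrays
print(type(parts))  # <class 'list'>

# Concatenating the pieces along the split axis recovers the input
print(bool(jnp.array_equal(jnp.concatenate(parts), x)))  # True

print(split_and_sum(x, 4))
```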