jax.numpy.vsplit



jax.numpy.vsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays vertically (row-wise).

LAX-backend implementation of numpy.vsplit().

Original docstring below.

Please refer to the split documentation. vsplit is equivalent to split with axis=0 (the default): the array is always split along the first axis, regardless of how many dimensions it has.
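The following minimal sketch illustrates the equivalence with split(axis=0) and the fact that the first axis is used even for a 3-D input. The array shapes are arbitrary choices for illustration.

```python
import jax.numpy as jnp

# A 4x3 matrix: rows run along the first axis (axis 0).
a = jnp.arange(12).reshape(4, 3)

# vsplit into 2 equal pieces along axis 0 ...
top, bottom = jnp.vsplit(a, 2)

# ... gives the same result as split with axis=0 (its default).
ref_top, ref_bottom = jnp.split(a, 2, axis=0)
print(top.shape, bottom.shape)  # (2, 3) (2, 3)

# For a 3-D array the split still happens along the first axis.
b = jnp.arange(24).reshape(4, 2, 3)
parts = jnp.vsplit(b, 2)
print([p.shape for p in parts])  # [(2, 2, 3), (2, 2, 3)]
```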

Parameters:
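  • ary (ArrayLike) – the array to split along its first axis (axis 0).

  • indices_or_sections (int | Sequence[int] | ArrayLike) – either an integer N, which splits ary into N equal sub-arrays along axis 0 (N must divide ary.shape[0] evenly), or a sequence of sorted integers giving the row indices at which ary is split.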
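A short sketch of the two accepted forms of indices_or_sections, following NumPy's split semantics: an integer for equal sections, or a list of split points. The error case assumes JAX raises ValueError for an uneven integer split, as NumPy does.

```python
import jax.numpy as jnp

a = jnp.arange(12).reshape(6, 2)  # 6 rows, 2 columns

# Integer N: N equal sections along axis 0 (6 rows / 3 = 2 rows each).
thirds = jnp.vsplit(a, 3)
print([t.shape for t in thirds])  # [(2, 2), (2, 2), (2, 2)]

# Sequence of indices: split before rows 1 and 4, i.e. a[:1], a[1:4], a[4:].
pieces = jnp.vsplit(a, [1, 4])
print([p.shape for p in pieces])  # [(1, 2), (3, 2), (2, 2)]

# An integer that does not divide the row count evenly is rejected.
raised = False
try:
    jnp.vsplit(a, 4)
except ValueError:
    raised = True
print(raised)  # True
```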

Return type:

list[Array]
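Because the result is a Python list whose length depends on indices_or_sections, that argument has to be known when a function is traced. The sketch below assumes it is passed as a plain Python int inside a jax.jit-compiled function (split_rows is a hypothetical helper name). The compiled function then returns a list of jax.Array values.

```python
import jax
import jax.numpy as jnp

@jax.jit
def split_rows(x):
    # The section count is a Python int, so it is static at trace time
    # and the number of outputs is fixed when the function is compiled.
    return jnp.vsplit(x, 2)

x = jnp.ones((4, 3))
out = split_rows(x)
print(type(out).__name__, len(out))  # list 2
print([o.shape for o in out])        # [(2, 3), (2, 3)]
```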