jax.numpy.vsplit
- jax.numpy.vsplit(ary, indices_or_sections)
Split an array into multiple sub-arrays vertically (row-wise).
LAX-backend implementation of numpy.vsplit(). Original docstring below.
Please refer to the split documentation. vsplit is equivalent to split with axis=0 (default); the array is always split along the first axis regardless of the array dimension.
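
As an illustration of the equivalence described above, here is a minimal sketch. The array contents and the variable names (x, top, bottom, parts, cube) are made up for the example; only jnp.vsplit and jnp.split come from the library.

```python
import jax.numpy as jnp

# Example 4x4 input; the values are arbitrary.
x = jnp.arange(16).reshape(4, 4)

# An integer asks for that many equal sections along the first axis.
top, bottom = jnp.vsplit(x, 2)

# The same split written with split and an explicit axis=0.
a, b = jnp.split(x, 2, axis=0)

# A list of indices gives the row boundaries: rows [:1], [1:3], [3:].
parts = jnp.vsplit(x, [1, 3])

# With a 3-D input, the split still happens along the first axis.
cube = jnp.zeros((4, 2, 3))
front, back = jnp.vsplit(cube, 2)

print(top.shape, bottom.shape)
print([p.shape for p in parts])
print(front.shape, back.shape)
```

With an integer argument, the length of the first axis should be divisible by the number of sections; passing a list of indices is the way to request unequal pieces.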