jax.numpy.split

jax.numpy.split(ary, indices_or_sections, axis=0)

Split an array into multiple sub-arrays as views into ary.

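LAX-backend implementation of numpy.split(). The description on this page is the original NumPy docstring; its references to views describe NumPy, and JAX arrays are immutable.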

Parameters
  • ary (ndarray) – Array to be divided into sub-arrays.

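  • indices_or_sections (int or 1-D array) – If indices_or_sections is an integer, N, the array is divided into N equal arrays along axis; if such a split is not possible, an error is raised. If indices_or_sections is a 1-D array of sorted integers, the entries indicate where along axis the array is split. For example, [2, 3] would, for axis=0, result in ary[:2], ary[2:3], and ary[3:]. In NumPy, an index that exceeds the dimension of the array along axis yields an empty sub-array.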

  • axis (int, optional) – The axis along which to split, default is 0.
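The two forms of indices_or_sections can be compared on a small 2-D array. This is a minimal sketch, assuming jax is installed and imported as jnp; the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)

# Integer form: divide axis 0 into 3 equal parts, each of shape (1, 4).
rows = jnp.split(x, 3, axis=0)

# Index form: cut axis 1 before columns 1 and 3, giving widths 1, 2 and 1.
cols = jnp.split(x, [1, 3], axis=1)

row_shapes = [r.shape for r in rows]
col_shapes = [c.shape for c in cols]
print(row_shapes)  # [(1, 4), (1, 4), (1, 4)]
print(col_shapes)  # [(3, 1), (3, 2), (3, 1)]
```

With the index form the pieces need not be equal in size, so it is the one to reach for when the cut points are known in advance.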

Returns

sub-arrays – A list of sub-arrays as views into ary.

Return type

list of ndarrays

Raises

ValueError – If indices_or_sections is given as an integer, but a split does not result in equal division.
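The error case can be illustrated as follows. This is a sketch, assuming jax is imported as jnp; it also shows array_split() as the alternative that tolerates an uneven division.

```python
import jax.numpy as jnp

x = jnp.arange(7)

# 7 elements cannot be divided into 3 equal parts, so split() raises.
try:
    jnp.split(x, 3)
    raised = False
except ValueError:
    raised = True

# array_split() accepts the same arguments but allows near-equal parts.
parts = jnp.array_split(x, 3)
sizes = [int(p.shape[0]) for p in parts]

print(raised)  # True
print(sizes)   # [3, 2, 2]
```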

See also

array_split()

Split an array into multiple sub-arrays of equal or near-equal size. Does not raise an exception if an equal division cannot be made.

hsplit()

Split array into multiple sub-arrays horizontally (column-wise).

vsplit()

Split array into multiple sub-arrays vertically (row-wise).

dsplit()

Split array into multiple sub-arrays along the 3rd axis (depth).

concatenate()

Join a sequence of arrays along an existing axis.

stack()

Join a sequence of arrays along a new axis.

hstack()

Stack arrays in sequence horizontally (column-wise).

vstack()

Stack arrays in sequence vertically (row-wise).

dstack()

Stack arrays in sequence depth-wise (along the third dimension).

Examples

>>> x = np.arange(9.0)
>>> np.split(x, 3)
[array([0.,  1.,  2.]), array([3.,  4.,  5.]), array([6.,  7.,  8.])]
>>> x = np.arange(8.0)
>>> np.split(x, [3, 5, 6, 10])
[array([0.,  1.,  2.]),
 array([3.,  4.]),
 array([5.]),
 array([6.,  7.]),
 array([], dtype=float64)]
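
The examples above use NumPy (np). A rough JAX counterpart, assuming jax is imported as jnp, is sketched below. It keeps every index within bounds and joins the pieces back together with concatenate() to show that the two operations are inverses along the same axis.

```python
import jax.numpy as jnp

x = jnp.arange(8.0)

# Cut before positions 3, 5 and 6, giving pieces of length 3, 2, 1 and 2.
parts = jnp.split(x, [3, 5, 6])
sizes = [int(p.shape[0]) for p in parts]

# Concatenating the pieces along the same axis recovers the input.
restored = jnp.concatenate(parts)
round_trip_ok = bool(jnp.array_equal(restored, x))

print(sizes)          # [3, 2, 1, 2]
print(round_trip_ok)  # True
```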