jax.numpy.split
jax.numpy.split(ary, indices_or_sections, axis=0)
Split an array into multiple sub-arrays as views into ary.
LAX-backend implementation of numpy.split(). Original docstring below.
- Parameters
ary (ndarray) – Array to be divided into sub-arrays.
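indices_or_sections (int or 1-D array) – If indices_or_sections is an integer, N, the array will be divided into N equal arrays along axis. If such a split is not possible, an error is raised. If indices_or_sections is a 1-D array of sorted integers, the entries indicate where along axis the array is split. For example, [2, 3] would, for axis=0, result in ary[:2], ary[2:3], and ary[3:]. If an index exceeds the dimension of the array along axis, an empty sub-array is returned correspondingly.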
axis (int, optional) – The axis along which to split, default is 0.
- Returns
sub-arrays – A list of sub-arrays as views into ary.
- Return type
list of ndarrays
- Raises
ValueError – If indices_or_sections is given as an integer, but a split does not result in equal division.
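The two forms of indices_or_sections can be pictured as two ways of producing slice boundaries along one axis. The following is a minimal pure-Python sketch of that idea, not the library's implementation: the helper name split_bounds is hypothetical, it works on a plain length rather than an array, and rejecting non-positive section counts with ValueError is an assumption of this sketch.

```python
def split_bounds(length, indices_or_sections):
    """Return (start, stop) pairs describing each piece (hypothetical helper)."""
    if isinstance(indices_or_sections, int):
        n = indices_or_sections
        # Integer form: the axis must divide evenly into n pieces.
        # Rejecting n <= 0 here is an assumption of this sketch.
        if n <= 0 or length % n != 0:
            raise ValueError("array split does not result in an equal division")
        step = length // n
        cuts = [step * i for i in range(1, n)]
    else:
        # Index form: the entries are the cut points themselves.
        cuts = list(indices_or_sections)
    edges = [0] + cuts + [length]
    return list(zip(edges[:-1], edges[1:]))


data = list(range(8))
# Python slicing clamps out-of-range bounds, so a cut past the end
# produces an empty final piece instead of an error.
pieces = [data[a:b] for a, b in split_bounds(len(data), [3, 5, 6, 10])]
print(pieces)
```

Under these assumptions, a cut point beyond the end of the axis produces an empty piece rather than an error, while the integer form fails as soon as the length is not a multiple of the section count.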
See also
array_split()
Split an array into multiple sub-arrays of equal or near-equal size. Does not raise an exception if an equal division cannot be made.
hsplit()
Split array into multiple sub-arrays horizontally (column-wise).
vsplit()
Split array into multiple sub-arrays vertically (row wise).
dsplit()
Split array into multiple sub-arrays along the 3rd axis (depth).
concatenate()
Join a sequence of arrays along an existing axis.
stack()
Join a sequence of arrays along a new axis.
hstack()
Stack arrays in sequence horizontally (column wise).
vstack()
Stack arrays in sequence vertically (row wise).
dstack()
Stack arrays in sequence depth wise (along third dimension).
Examples
>>> import numpy as np
>>> x = np.arange(9.0)
>>> np.split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]

>>> x = np.arange(8.0)
>>> np.split(x, [3, 5, 6, 10])
[array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.]), array([], dtype=float64)]
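The examples above come from the original NumPy docstring and split along the default axis. As a complement, here is a small sketch that calls the JAX function itself with a non-default axis and then rejoins the pieces with concatenate(). It assumes jax is installed and imported as jnp; the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)

# Integer form: split the 4 columns into 2 equal pieces along axis 1.
left, right = jnp.split(x, 2, axis=1)
print(left.shape, right.shape)

# Index form: cut before columns 1 and 3, giving columns [0:1], [1:3], [3:4].
parts = jnp.split(x, [1, 3], axis=1)
print([p.shape for p in parts])

# Concatenating the pieces along the same axis rebuilds the input.
rebuilt = jnp.concatenate(parts, axis=1)
print(bool(jnp.array_equal(rebuilt, x)))
```

Splitting and concatenating along the same axis are inverse operations here, which is a convenient sanity check when choosing cut points.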