- jax.numpy.split(ary, indices_or_sections, axis=0)
Split an array into multiple sub-arrays as views into ary.
LAX-backend implementation of
The JAX version of this function may in some cases return a copy rather than a view of the input.
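The view-versus-copy caveat rarely matters in ordinary use, because JAX arrays are immutable: "updating" an array with the `.at[...].set(...)` syntax produces a new array, so a sub-array returned by `jnp.split` cannot change behind your back either way. A minimal sketch, assuming a working `jax` installation; the variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(6)               # values 0..5
left, right = jnp.split(x, 2)   # two pieces of length 3

# JAX arrays are immutable: .at[...].set(...) returns a new array
# instead of modifying x in place.
y = x.at[0].set(100)

print(x)     # the original input is unchanged
print(y)     # the new, updated array
print(left)  # the sub-array is unaffected, whether it is a view or a copy
```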
Original docstring below.
- Parameters
ary (ndarray) – Array to be divided into sub-arrays.
indices_or_sections (int or 1-D array) –
If indices_or_sections is an integer, N, the array will be divided into N equal arrays along axis. If such a split is not possible, an error is raised.
If indices_or_sections is a 1-D array of sorted integers, the entries indicate where along axis the array is split. For example, [2, 3] would, for axis=0, result in ary[:2], ary[2:3], and ary[3:].
If an index exceeds the dimension of the array along axis, an empty sub-array is returned correspondingly.
axis (int, optional) – The axis along which to split, default is 0.
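The three ways of specifying a split (an integer count, a list of sorted indices, and a non-default axis) can be compared on a small 2-D array. This is a short sketch assuming a working `jax` installation; the array and variable names are illustrative, and the example sticks to in-range indices.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(2, 6)  # shape (2, 6)

# Integer N: N equal pieces along the axis (axis=1 has length 6, so N=3 works).
a, b, c = jnp.split(x, 3, axis=1)       # each piece has shape (2, 2)

# Sorted indices: [2, 3] along axis=1 gives x[:, :2], x[:, 2:3], x[:, 3:].
p, q, r = jnp.split(x, [2, 3], axis=1)  # shapes (2, 2), (2, 1), (2, 3)

# Default axis=0: separate the two rows.
row0, row1 = jnp.split(x, 2)            # each has shape (1, 6)

# An integer that does not divide the axis length evenly raises an error.
try:
    jnp.split(x, 4, axis=1)             # 6 is not divisible by 4
except ValueError as err:
    print("ValueError:", err)
```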
- Returns
sub-arrays – A list of sub-arrays as views into ary.
- Return type
list of ndarrays
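Since the result is a plain Python list of arrays, it supports `len()`, indexing, iteration, and unpacking, and concatenating the pieces along the same axis rebuilds the input. A brief sketch, assuming a working `jax` installation; the names are illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(10)
pieces = jnp.split(x, [3, 7])   # pieces covering x[:3], x[3:7], x[7:]

# The result is an ordinary Python list of arrays.
print(type(pieces).__name__, len(pieces))
for piece in pieces:
    print(piece.shape)

# Concatenating the pieces along the same axis rebuilds the input.
rebuilt = jnp.concatenate(pieces, axis=0)
print(bool(jnp.array_equal(rebuilt, x)))
```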