jax.numpy.hsplit

jax.numpy.hsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays horizontally (column-wise).

LAX-backend implementation of numpy.hsplit().

Original docstring below.

Please refer to the split documentation. hsplit is equivalent to split with axis=1: the array is always split along the second axis, except for 1-D arrays, where it is split along axis=0.
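
The behaviour described above can be illustrated with a minimal sketch. The array contents and variable names here are illustrative assumptions, not part of the original docstring; the sketch only uses `jnp.hsplit` and `jnp.split` as documented.

```python
import jax.numpy as jnp

# A 4x4 example array (illustrative values).
x = jnp.arange(16).reshape(4, 4)

# Integer argument: split the columns into 2 equal sections.
left, right = jnp.hsplit(x, 2)

# Sequence argument: split before column 1 and before column 3.
a, b, c = jnp.hsplit(x, [1, 3])

# For 2-D input, hsplit matches split with axis=1.
left_ref, right_ref = jnp.split(x, 2, axis=1)

# For 1-D input, the split happens along axis 0 instead.
v = jnp.arange(6)
parts = jnp.hsplit(v, 3)

print([p.shape for p in (left, right)])
print([p.shape for p in (a, b, c)])
print([p.shape for p in parts])
```

With an integer argument the split axis must divide evenly into that many sections; the sequence form is the one to use for sections of unequal width.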

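Parameters

ary: the array to split.
indices_or_sections: an integer N (split into N equal sections) or a 1-D sequence of sorted indices at which to split.

Return type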

List[Array]