jax.numpy.hsplit

jax.numpy.hsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays horizontally (column-wise).
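A minimal sketch of the two accepted forms of indices_or_sections, assuming a working JAX installation. The variable names are illustrative only. An integer N divides the columns into N equal parts, and a sequence of indices gives the column boundaries at which to cut.

```python
import jax.numpy as jnp

# A 2 x 4 array: rows are [0, 1, 2, 3] and [4, 5, 6, 7].
x = jnp.arange(8).reshape(2, 4)

# Integer form: split the 4 columns into 2 equal parts of 2 columns each.
left, right = jnp.hsplit(x, 2)

# Index form: cut before columns 1 and 3, giving x[:, :1], x[:, 1:3], x[:, 3:].
a, b, c = jnp.hsplit(x, [1, 3])
```

Because the integer form requires an equal division, a count that does not divide the number of columns (for example 3 for a 4-column array) is rejected, as in NumPy. The index form allows unequal pieces.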

LAX-backend implementation of numpy.hsplit().

Original docstring below.

Please refer to the split documentation. hsplit is equivalent to split with axis=1: the array is always split along the second axis, except for 1-D arrays, which are split along axis 0.
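The sketch below checks the equivalence with split and the 1-D special case, assuming a working JAX installation. The variable names are illustrative only.

```python
import jax.numpy as jnp

# For arrays with ndim >= 2, hsplit should match split along axis 1.
x = jnp.arange(12).reshape(2, 6)
h_parts = jnp.hsplit(x, 3)
s_parts = jnp.split(x, 3, axis=1)
same = all(bool((h == s).all()) for h, s in zip(h_parts, s_parts))

# A 3-D array is still split along its second axis (length 4 -> two pieces of 2).
t = jnp.zeros((2, 4, 3))
t_parts = jnp.hsplit(t, 2)

# A 1-D array has no second axis, so hsplit splits along axis 0 instead.
v = jnp.arange(6)
v_parts = jnp.hsplit(v, 3)
```

Each call returns a list of arrays, so the pieces can be compared element-wise or unpacked directly.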