jax.numpy.dsplit

jax.numpy.dsplit(ary, indices_or_sections)

Split array into multiple sub-arrays along the 3rd axis (depth).

LAX-backend implementation of numpy.dsplit().
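Because this is an implementation of numpy.dsplit(), its results are expected to match NumPy's for the same inputs. The short check below illustrates that on one small array; the shape (1, 2, 6) and the section count 3 are arbitrary choices for illustration, and the sketch assumes both NumPy and JAX are installed.

```python
import numpy as np
import jax.numpy as jnp

# A small 3-D NumPy array of shape (1, 2, 6)
x_np = np.arange(12).reshape(1, 2, 6)

# Split into 3 equal sections along axis 2 with both libraries
np_parts = np.dsplit(x_np, 3)
jnp_parts = jnp.dsplit(jnp.asarray(x_np), 3)

# Compare piece by piece: same shapes and same values
# (JAX may use int32 where NumPy uses int64; only the values are compared)
matches = all(
    a.shape == b.shape and np.array_equal(a, np.asarray(b))
    for a, b in zip(np_parts, jnp_parts)
)
print(len(jnp_parts), matches)  # 3 True
```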

Original docstring below.

Please refer to the split documentation. dsplit is equivalent to split with axis=2: the array is always split along the third axis, provided the array has at least 3 dimensions.
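The equivalence with split(axis=2) can be checked directly. This sketch uses an illustrative (2, 3, 4) array. It also shows that an integer which does not divide the length of axis 2 is rejected; that this is a ValueError mirrors NumPy's split behaviour and is stated here as an expectation rather than a documented guarantee.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)  # axis 2 has length 4

# An integer asks for that many equal sections along the third axis
d_parts = jnp.dsplit(x, 2)
# The same request written as split with axis=2
s_parts = jnp.split(x, 2, axis=2)

shapes = [p.shape for p in d_parts]  # [(2, 3, 2), (2, 3, 2)]
same = all(bool(jnp.array_equal(a, b)) for a, b in zip(d_parts, s_parts))

# 3 does not divide 4, so an equal split is impossible
raised = False
try:
    jnp.dsplit(x, 3)
except ValueError:
    raised = True
```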

Parameters:
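  • ary (ArrayLike) – input array with at least 3 dimensions.

  • indices_or_sections (int | Sequence[int] | ArrayLike) – if an integer N, the array is divided into N equal sections along axis 2, and N must divide the length of that axis; if a sequence of sorted integers, the entries give the positions along axis 2 at which the array is split.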

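When indices_or_sections is a sequence, each entry marks a split point along axis 2, so [1, 4] yields the slices [:1], [1:4] and [4:]. The array below is an illustrative choice whose third axis has length 6, and the split points are kept within bounds.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(1, 2, 6)  # axis 2 has length 6

# Split points 1 and 4 along axis 2 give three pieces of unequal width:
# x[:, :, :1], x[:, :, 1:4] and x[:, :, 4:]
parts = jnp.dsplit(x, [1, 4])
shapes = [p.shape for p in parts]  # [(1, 2, 1), (1, 2, 3), (1, 2, 2)]
```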
Return type:

list[Array]
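The return value is a plain Python list of arrays, so it can be passed straight to a concatenation function. A minimal sketch, on an illustrative (2, 3, 4) array, of undoing the split with jnp.concatenate along axis 2, or equivalently with jnp.dstack since every piece is already 3-D:

```python
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)

# 4 equal sections along axis 2: a list of 4 arrays, each of shape (2, 3, 1)
parts = jnp.dsplit(x, 4)
is_list = isinstance(parts, list)

# Concatenating along axis 2 restores the original array
restored = jnp.concatenate(parts, axis=2)
# dstack stacks along the third axis, which is the same thing for 3-D inputs
restored_d = jnp.dstack(parts)
```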