jax.numpy.dsplit

jax.numpy.dsplit(ary, indices_or_sections)

Split an array into sub-arrays depth-wise.

JAX implementation of numpy.dsplit().

Refer to the documentation of jax.numpy.split() for details. dsplit is equivalent to split with axis=2, so the input must have at least three dimensions.
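The equivalence with split can be checked directly. The sketch below uses an arbitrary illustrative input of shape (2, 2, 6) (not part of the original examples) and compares jnp.dsplit against jnp.split(..., axis=2) piece by piece.

```python
import jax.numpy as jnp

# Illustrative input: a 3-D array whose third axis (axis=2) has length 6.
x = jnp.arange(24).reshape(2, 2, 6)

# dsplit always cuts along axis 2 ...
parts_d = jnp.dsplit(x, 3)
# ... so it should agree with split given an explicit axis=2.
parts_s = jnp.split(x, 3, axis=2)

same = all(bool(jnp.array_equal(a, b)) for a, b in zip(parts_d, parts_s))
print(len(parts_d), parts_d[0].shape, same)  # 3 (2, 2, 2) True
```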

Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(12).reshape(3, 1, 4)
>>> print(x)
[[[ 0  1  2  3]]

 [[ 4  5  6  7]]

 [[ 8  9 10 11]]]
>>> x1, x2 = jnp.dsplit(x, 2)
>>> print(x1)
[[[0 1]]

 [[4 5]]

 [[8 9]]]
>>> print(x2)
[[[ 2  3]]

 [[ 6  7]]

 [[10 11]]]
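The example above passes an integer. As a further sketch, reusing the same illustrative input (recreated so the block runs on its own), a list of indices splits at those positions along axis 2, and the resulting pieces may differ in width.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 1, 4)

# Split before index 1 and before index 3 along axis 2:
# the pieces cover depth ranges [0:1], [1:3], and [3:4].
a, b, c = jnp.dsplit(x, [1, 3])

print(a.shape, b.shape, c.shape)  # (3, 1, 1) (3, 1, 2) (3, 1, 1)
```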

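See also

  • jax.numpy.split(): split an array along any axis.

  • jax.numpy.vsplit(): split vertically, i.e. along axis 0.

  • jax.numpy.hsplit(): split horizontally, i.e. along axis 1 (axis 0 for 1-D input).

  • jax.numpy.array_split(): like jax.numpy.split(), but allows indices_or_sections to be an integer that does not evenly divide the size of the array.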

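Parameters:
  • ary (ArrayLike): the array to split; it must have at least three dimensions.

  • indices_or_sections (int | Sequence[int] | ArrayLike): if an integer N, the array is divided into N equal sub-arrays along axis 2, and an error is raised if the length of that axis is not divisible by N. If a sequence (or 1-D array) of sorted integers, the entries give the indices along axis 2 at which the array is split. Because this argument determines the output shapes, its values must be concrete (static) under jax.jit.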

Return type:

list[Array]
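Because an integer indices_or_sections must divide the length of axis 2 evenly, a non-divisible count is rejected. The sketch below (illustrative input, not from the original examples) shows the failure and uses jax.numpy.array_split with axis=2 as the alternative that tolerates uneven pieces; catching ValueError reflects the error type current JAX versions raise for this case.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 1, 4)  # axis 2 has length 4

# 4 is not divisible by 3, so dsplit rejects this request.
try:
    jnp.dsplit(x, 3)
    raised = False
except ValueError:
    raised = True

# array_split allows uneven pieces; the first pieces get the extra element.
uneven = jnp.array_split(x, 3, axis=2)
print(raised, [p.shape for p in uneven])
# True [(3, 1, 2), (3, 1, 1), (3, 1, 1)]
```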