jax.numpy.unstack



jax.numpy.unstack(x, /, *, axis=0)

Split an array along an axis into a tuple of arrays, each with that axis removed.

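Parameters:
  • x (jax.typing.ArrayLike): the array to unstack. Must have at least one dimension.

  • axis (int): the axis along which to unstack. Negative values count from the last axis, so it must satisfy -x.ndim <= axis < x.ndim. Defaults to 0.

Returns:

a tuple of x.shape[axis] arrays, each with the shape of x minus the axis dimension.

Return type:

tuple[Array, …]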
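A short usage sketch, assuming a JAX release that ships jax.numpy.unstack. The variable names are illustrative only. It shows the default axis=0 (one array per row), an explicit axis=1 (one array per column), and that jnp.stack along the same axis reassembles the original array.

```python
import jax.numpy as jnp

# A 2x3 array: [[0, 1, 2], [3, 4, 5]]
x = jnp.arange(6).reshape(2, 3)

# Default axis=0: a tuple of 2 arrays (the rows), each of shape (3,).
rows = jnp.unstack(x)
print(len(rows), rows[0].shape)

# axis is keyword-only: a tuple of 3 arrays (the columns), each of shape (2,).
cols = jnp.unstack(x, axis=1)
print(len(cols), cols[0])

# Stacking along the same axis undoes the unstack.
assert jnp.array_equal(jnp.stack(cols, axis=1), x)
```

Because x is positional-only and axis is keyword-only, a call such as jnp.unstack(x, 1) is rejected; pass the axis by name.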