jax.numpy.stack

jax.numpy.stack(arrays, axis=0, out=None, dtype=None)

Join a sequence of arrays along a new axis.

LAX-backend implementation of numpy.stack().

Original docstring below.

The axis parameter specifies the index of the new axis in the dimensions of the result. For example, if axis=0 it will be the first dimension and if axis=-1 it will be the last dimension.
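As a minimal sketch of how axis places the new dimension, the following stacks two length-3 vectors first with axis=0 and then with axis=-1. The names a, b, first, and last are illustrative only and are not part of the API.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# axis=0: the new axis becomes the first dimension of the result.
first = jnp.stack([a, b], axis=0)

# axis=-1: the new axis becomes the last dimension of the result.
last = jnp.stack([a, b], axis=-1)

print(first.shape)  # (2, 3)
print(last.shape)   # (3, 2)
```

With axis=0 each input becomes a row of the result; with axis=-1 each input becomes a column.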

Added in NumPy version 1.10.0.

Parameters:
  • arrays (sequence of array_like) – Each array must have the same shape.

  • axis (int, optional) – The axis in the result array along which the input arrays are stacked.

  • dtype (str or dtype) – If provided, the destination array will have this dtype. Cannot be provided together with out.

  • out (None) – Not supported by JAX; leave as None.

Returns:

stacked – The stacked array has one more dimension than the input arrays.

Return type:

ndarray
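The following sketch illustrates the dtype parameter and the extra dimension of the result, assuming two same-shaped float inputs. The names x, y, and stacked are illustrative, and the choice of int32 as the target dtype is arbitrary.

```python
import jax.numpy as jnp

# Two inputs with identical shape (2, 3), as stack requires.
x = jnp.zeros((2, 3))
y = jnp.ones((2, 3))

# Stack along a new middle axis and request an integer result dtype.
stacked = jnp.stack([x, y], axis=1, dtype=jnp.int32)

print(stacked.shape)  # (2, 2, 3)
print(stacked.ndim)   # 3, one more than the inputs' 2
print(stacked.dtype)  # int32
```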