jax.numpy.stack
- jax.numpy.stack(arrays, axis=0, out=None, dtype=None)
Join a sequence of arrays along a new axis.
LAX-backend implementation of numpy.stack(). Original docstring below.
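As a rough illustration of what "LAX-backend implementation of numpy.stack()" means in practice, the sketch below checks that jnp.stack returns the same values as numpy.stack for the same inputs, and that it can be traced under jax.jit. The input arrays and the lambda are arbitrary choices made for this example.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Two 1-D inputs of the same shape (3,), chosen only for illustration.
xs = [np.arange(3), np.arange(3, 6)]

# jnp.stack follows numpy.stack's semantics for the same arguments.
expected = np.stack(xs, axis=1)
result = jnp.stack(xs, axis=1)

# Because it is built on LAX operations, it can be traced and compiled with jax.jit.
stack_rows = jax.jit(lambda a, b: jnp.stack([a, b]))
compiled = stack_rows(jnp.asarray(xs[0]), jnp.asarray(xs[1]))

print(np.array_equal(np.asarray(result), expected))  # True
print(compiled.shape)  # (2, 3)
```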
The axis parameter specifies the index of the new axis in the dimensions of the result. For example, if axis=0, it will be the first dimension, and if axis=-1, it will be the last dimension.
New in version 1.10.0.
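To make the placement of the new axis concrete, here is a small sketch stacking two arrays of shape (3, 4) with different axis values. The array contents are arbitrary; only the resulting shapes and the position of each input inside the result matter.

```python
import jax.numpy as jnp

# Two inputs with identical shape (3, 4) and distinguishable values.
a = jnp.arange(12).reshape(3, 4)
b = a + 100

# axis=0: the new axis is the first dimension of the result.
s0 = jnp.stack([a, b], axis=0)
# axis=1: the new axis is inserted between the two original dimensions.
s1 = jnp.stack([a, b], axis=1)
# axis=-1: the new axis is the last dimension of the result.
s_last = jnp.stack([a, b], axis=-1)

print(s0.shape)      # (2, 3, 4)
print(s1.shape)      # (3, 2, 4)
print(s_last.shape)  # (3, 4, 2)

# Indexing along the new axis recovers each original input.
print(bool((s0[1] == b).all()), bool((s_last[..., 1] == b).all()))  # True True
```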
- Parameters
arrays (sequence of array_like) – Each array must have the same shape.
axis (int, optional) – The axis in the result array along which the input arrays are stacked.
dtype (str or dtype) – If provided, the destination array will have this dtype. Cannot be provided together with out.
out (None) – Not supported in JAX; must be left as None.
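The sketch below exercises two of the parameters listed above: dtype, which casts the stacked result, and the requirement that every input in arrays has the same shape. The exact exception type raised for mismatched shapes is an assumption here (catching both ValueError and TypeError to be safe); the example only relies on an error being raised.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

# dtype: the stacked result is produced with the requested dtype (integers -> float32 here).
z = jnp.stack([x, y], dtype=jnp.float32)
print(z.shape, z.dtype)  # (2, 3) float32

# arrays: every input must have the same shape, so mismatched inputs are rejected.
mismatch_raised = False
try:
    jnp.stack([x, jnp.array([7, 8])])
except (ValueError, TypeError) as err:  # exact exception type is assumed; it may vary by version
    mismatch_raised = True
    print("rejected:", err)
```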
- Returns
stacked – The stacked array has one more dimension than the input arrays.
- Return type
ndarray
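One way to see why the result "has one more dimension than the input arrays": stacking along axis gives the same array as first inserting a length-1 axis into each input (jnp.expand_dims) and then joining them with jnp.concatenate along that axis. The sketch below checks this equivalence for several axis values; it is offered as an illustrative identity, not as a description of JAX's internal code path.

```python
import jax.numpy as jnp

# Four inputs, each of shape (2, 3), with distinct values.
arrays = [jnp.arange(6).reshape(2, 3) + 10 * i for i in range(4)]

for axis in (0, 1, 2, -1):
    stacked = jnp.stack(arrays, axis=axis)
    # Equivalent formulation: add a length-1 axis to each input, then concatenate along it.
    manual = jnp.concatenate([jnp.expand_dims(a, axis) for a in arrays], axis=axis)
    # The result always has exactly one more dimension than each input.
    assert stacked.ndim == arrays[0].ndim + 1
    assert stacked.shape == manual.shape
    assert bool(jnp.array_equal(stacked, manual))
    print(axis, stacked.shape)
```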