jax.numpy.concatenate

jax.numpy.concatenate(arrays, axis=0, dtype=None)

Join a sequence of arrays along an existing axis.

LAX-backend implementation of numpy.concatenate().
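As a minimal sketch of joining along an existing axis, the snippet below builds two small integer arrays and concatenates them once along axis 0 and once along axis 1. The inputs must agree in every dimension except the one being joined, which is why `b` is transposed for the axis-1 case.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])  # shape (2, 2)
b = jnp.array([[5, 6]])          # shape (1, 2)

# Join along the first axis: (2, 2) + (1, 2) -> (3, 2)
rows = jnp.concatenate([a, b], axis=0)

# Join along the second axis: b.T has shape (2, 1), so (2, 2) + (2, 1) -> (2, 3)
cols = jnp.concatenate([a, b.T], axis=1)

print(rows)
print(cols)
```

No new axis is created in either call; the result has the same number of dimensions as the inputs.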

Original docstring below.

Parameters
  • arrays (np.ndarray | Array | Sequence[ArrayLike]) – The arrays to join. They must have the same shape, except in the dimension corresponding to axis (the first, by default).

  • axis (int, optional) – The axis along which the arrays will be joined. If axis is None, arrays are flattened before use. Default is 0.

  • dtype (str or dtype) – If provided, the result array will have this dtype. Unlike NumPy, this function has no out argument.
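The two optional parameters can be illustrated with a short sketch. With axis=None every input is flattened first, so inputs of different shapes can be joined into a single 1-D array; passing dtype casts the result, here turning integer inputs into a float32 array (assuming JAX's default 32-bit mode).

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])  # shape (2, 2)
b = jnp.array([5, 6, 7])         # shape (3,)

# axis=None flattens each input first, so the shapes need not match:
# 4 elements + 3 elements -> shape (7,)
flat = jnp.concatenate([a, b], axis=None)

# dtype sets the dtype of the result: integer inputs, float32 output
as_float = jnp.concatenate([a, a], axis=0, dtype=jnp.float32)

print(flat)            # [1 2 3 4 5 6 7]
print(as_float.dtype)  # float32
```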

Returns

res – The concatenated array.

Return type

ndarray