jax.numpy.concatenate

jax.numpy.concatenate(arrays, axis=0)

Join a sequence of arrays along an existing axis.
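As a minimal illustration of joining along an existing axis, the sketch below stacks a (2, 2) array and a (1, 2) array along the default axis 0. The array values are arbitrary examples. No new axis is created; the length of axis 0 simply grows.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])  # shape (2, 2)
b = jnp.array([[5, 6]])          # shape (1, 2)

# Default axis=0: the shapes agree on every axis except axis 0,
# so the rows of b are appended after the rows of a.
out = jnp.concatenate([a, b])
print(out.shape)  # (3, 2)
```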

LAX-backend implementation of numpy.concatenate().
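Because this wrapper is built on LAX, it can be compared with the lower-level jax.lax.concatenate primitive and used inside jax.jit. The sketch below assumes same-dtype inputs (both float32 here), so no dtype promotion is involved. Note that lax.concatenate takes a positional dimension rather than an axis keyword.

```python
import jax
import jax.numpy as jnp
from jax import lax

a = jnp.arange(6.0).reshape(2, 3)  # float32, shape (2, 3)
b = jnp.ones((1, 3))               # float32, shape (1, 3)

# NumPy-style wrapper versus the underlying LAX primitive.
via_jnp = jnp.concatenate([a, b], axis=0)
via_lax = lax.concatenate([a, b], 0)

# The call can be traced and compiled; input shapes are static under jit.
joined = jax.jit(lambda x, y: jnp.concatenate([x, y], axis=0))(a, b)
```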

Original docstring below.

Parameters

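arrays (sequence of array_like) – The arrays must have the same shape, except in the dimension corresponding to axis (the first, by default).

axis (int, optional) – The axis along which the arrays will be joined. If axis is None, arrays are flattened before use. Default is 0.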
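The sketch below contrasts a non-default axis with axis=None, using arbitrary example arrays. With axis=1 the inputs must agree on axis 0. With axis=None each input is flattened first, so their shapes need not match.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])  # shape (2, 2)
b = jnp.array([[5], [6]])        # shape (2, 1)

# axis=1: shapes agree on axis 0, so columns are appended -> shape (2, 3).
cols = jnp.concatenate([a, b], axis=1)

# axis=None: both inputs are flattened, then joined into one 1-D array.
flat = jnp.concatenate([a, b], axis=None)
```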

Returns

res – The concatenated array.

Return type

ndarray