jax.numpy.concat



jax.numpy.concat(arrays, /, *, axis=0)

Join a sequence of arrays along an existing axis.

LAX-backend implementation of numpy.concatenate().

Original docstring below.

Parameters:
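  • arrays (Sequence[jax.typing.ArrayLike]) – The arrays to join. They must have the same shape, except in the dimension corresponding to axis.

  • axis (int, optional) – The axis along which the arrays are joined. If axis is None, the arrays are flattened before they are joined. Default is 0.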
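The effect of the axis parameter can be seen in a minimal sketch. It assumes a JAX release that exports jax.numpy.concat (imported under the common alias jnp) and uses small illustrative arrays. With axis=1 the inputs must agree along axis 0, while axis=None ravels each input before joining, as described above.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])  # shape (2, 2)
b = jnp.array([[5], [6]])        # shape (2, 1)

# axis=1: the sizes along axis 0 match (2 == 2), so the inputs are joined column-wise.
side_by_side = jnp.concat([a, b], axis=1)
# side_by_side has shape (2, 3): [[1, 2, 5], [3, 4, 6]]

# axis=None: each input is flattened first, then the 1-D results are joined.
flat = jnp.concat([a, b], axis=None)
# flat has shape (6,): [1, 2, 3, 4, 5, 6]
```

Note that joining these two arrays with the default axis=0 would fail, because their sizes along axis 1 (2 and 1) differ.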

Returns:

res – The concatenated array.

Return type:

ndarray
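The returned array keeps every dimension of the inputs except the join axis, whose length is the sum of the inputs' lengths along that axis. A short sketch of this shape rule for the default axis=0, again assuming jnp.concat is available and using illustrative zeros and ones:

```python
import jax.numpy as jnp

a = jnp.zeros((2, 3))
b = jnp.ones((4, 3))

# Default axis=0: lengths along axis 0 add up (2 + 4 = 6), while the
# remaining dimension (3) must match and is carried over unchanged.
res = jnp.concat([a, b])
print(res.shape)  # (6, 3)

# The first two rows come from a, the last four from b.
print(res[:2].sum(), res[2:].sum())
```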