jax.numpy.block

jax.numpy.block(arrays)

Assemble an nd-array from nested lists of blocks.

LAX-backend implementation of numpy.block().

Original docstring below.

Blocks in the innermost lists are concatenated (see concatenate) along the last dimension (-1), then these are concatenated along the second-last dimension (-2), and so on until the outermost list is reached.
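The concatenation order described above can be checked with a small sketch. The block names a, b, c, d and their shapes are illustrative assumptions, not part of the original docstring.

```python
import jax.numpy as jnp

# Four hypothetical 2-D blocks; names and shapes are illustrative only.
a = jnp.full((2, 2), 1.0)
b = jnp.full((2, 3), 2.0)
c = jnp.full((1, 2), 3.0)
d = jnp.full((1, 3), 4.0)

# Nested list, two levels deep.
blocked = jnp.block([[a, b], [c, d]])

# The same array built by hand: innermost lists are joined along axis -1,
# then the outer list is joined along axis -2.
top = jnp.concatenate([a, b], axis=-1)
bottom = jnp.concatenate([c, d], axis=-1)
manual = jnp.concatenate([top, bottom], axis=-2)

print(blocked.shape)                           # (3, 5)
print(bool(jnp.array_equal(blocked, manual)))  # True
```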

Blocks can be of any dimension, but will not be broadcast using the normal rules. Instead, leading axes of size 1 are inserted so that block.ndim is the same for all blocks. This is primarily useful for working with scalars, and means that code like np.block([v, 1]) is valid, where v.ndim == 1.
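A minimal sketch of the leading-axis rule, using jax.numpy in place of np. The names v, m and row are illustrative assumptions.

```python
import jax.numpy as jnp

# Scalar next to a 1-D array: the scalar is treated as shape (1,).
v = jnp.array([1, 2, 3])   # v.ndim == 1
out = jnp.block([v, 1])
print(out)                 # [1 2 3 1]

# A 1-D block placed under a 2-D block: the 1-D block gains a leading
# axis of size 1, so shape (2,) is treated as shape (1, 2).
m = jnp.ones((2, 2))
row = jnp.zeros(2)
stacked = jnp.block([[m], [row]])
print(stacked.shape)       # (3, 2)
```

Note that only leading axes are added; a block of shape (2,) is never stretched to match a block of shape (2, 2) the way ordinary broadcasting would.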

When the nested list is two levels deep, this allows block matrices to be constructed from their components.
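As an illustration of the two-level case, the following sketch assembles a block-diagonal matrix. The names A, B and M are hypothetical and chosen only for this example.

```python
import jax.numpy as jnp

A = jnp.eye(2)          # top-left block, shape (2, 2)
B = 2.0 * jnp.eye(3)    # bottom-right block, shape (3, 3)

# Each inner list is one block row; zero blocks fill the off-diagonal.
M = jnp.block([
    [A, jnp.zeros((2, 3))],
    [jnp.zeros((3, 2)), B],
])

print(M.shape)  # (5, 5)
```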

Added in NumPy version 1.13.0.

Parameters:

arrays (nested list of array_like or scalars (but not tuples)) –

If passed a single ndarray or scalar (a nested list of depth 0), this is returned unmodified (and not copied).

Element shapes must match along the appropriate axes (without broadcasting), but leading 1s will be prepended to the shape as necessary to make the dimensions match.
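The constraints on arrays can be probed with a short sketch. The helper try_block is hypothetical, and the exception types raised for invalid input are deliberately not assumed here; the helper only reports whichever type occurs.

```python
import jax.numpy as jnp

def try_block(arrays):
    """Return the result shape, or the exception type name if assembly fails."""
    try:
        return jnp.block(arrays).shape
    except Exception as exc:  # exception type is not assumed in this sketch
        return type(exc).__name__

# Shapes agree along axis 0, so joining along the last axis works.
ok = try_block([jnp.ones((2, 2)), jnp.ones((2, 3))])

# Shapes disagree along axis 0 and are not broadcast, so this fails.
bad = try_block([jnp.ones((2, 2)), jnp.ones((3, 2))])

# A tuple is not accepted in place of a list.
tup = try_block((jnp.ones(2), jnp.ones(2)))

print(ok)   # (2, 5)
print(bad)
print(tup)
```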

Returns:

block_array – The array assembled from the given blocks.

The dimensionality of the output is equal to the greatest of:

* the dimensionality of all the inputs
* the depth to which the input list is nested
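The rule for the output dimensionality can be seen in a few small cases. This is a sketch with illustrative variable names.

```python
import jax.numpy as jnp

# Scalars (ndim 0) in a list of depth 1: output ndim is 1.
depth_one = jnp.block([1, 2])

# The same scalars in a list of depth 2: output ndim is 2.
depth_two = jnp.block([[1, 2]])

# A 3-D input in a list of depth 1: the input's ndim dominates.
from_3d = jnp.block([jnp.ones((2, 2, 2))])

print(depth_one.ndim, depth_one.shape)  # 1 (2,)
print(depth_two.ndim, depth_two.shape)  # 2 (1, 2)
print(from_3d.ndim, from_3d.shape)      # 3 (2, 2, 2)
```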

Return type:

ndarray