jax.numpy.block

jax.numpy.block(arrays)

Assemble an nd-array from nested lists of blocks.

LAX-backend implementation of block(). Original docstring below.

Blocks in the innermost lists are concatenated (see concatenate) along the last dimension (-1), then these are concatenated along the second-last dimension (-2), and so on until the outermost list is reached.
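The concatenation order described above can be sketched by hand with concatenate. This is a minimal illustration using NumPy (the library the original docstring describes); the array names and shapes are made up for the example, and jax.numpy.block is assumed, not verified here, to follow the same semantics.

```python
import numpy as np

# Hypothetical blocks chosen only for illustration.
a = np.ones((2, 2))
b = np.zeros((2, 3))
c = np.full((1, 2), 5.0)
d = np.full((1, 3), 7.0)

# Innermost lists are joined along the last axis (-1) ...
top = np.concatenate([a, b], axis=-1)       # shape (2, 5)
bottom = np.concatenate([c, d], axis=-1)    # shape (1, 5)
# ... and the outer list is then joined along the second-last axis (-2).
manual = np.concatenate([top, bottom], axis=-2)   # shape (3, 5)

blocked = np.block([[a, b], [c, d]])
same = np.array_equal(manual, blocked)      # the two constructions agree
```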

Blocks can be of any dimension, but will not be broadcast using the normal rules. Instead, leading axes of size 1 are inserted, to make block.ndim the same for all blocks. This is primarily useful for working with scalars, and means that code like np.block([v, 1]) is valid, where v.ndim == 1.
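The difference between inserting leading axes and broadcasting can be seen in a small sketch. It uses NumPy and illustrative names (v, m); the second case is expected to fail because v is only promoted to shape (1, 3) and is not stretched to match m.

```python
import numpy as np

v = np.array([1, 2, 3])        # v.ndim == 1
row = np.block([v, 1])         # the scalar 1 is treated as a length-1 block
# row now holds the three entries of v followed by 1

m = np.ones((2, 2), dtype=int)
# v is promoted to shape (1, 3), not broadcast to (2, 3), so joining it with
# the (2, 2) block along the last axis is a shape mismatch.
try:
    np.block([m, v])
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```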

When the nested list is two levels deep, this allows block matrices to be constructed from their components.

New in NumPy version 1.13.0.

Returns

  • block_array (ndarray) – The array assembled from the given blocks.

    The dimensionality of the output is equal to the greatest of:

      • the dimensionality of all the inputs
      • the depth to which the input list is nested
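As a rough check of the dimensionality rule, the following NumPy sketch compares the list depth and the input dimensionality for a few hypothetical inputs; the variable names are illustrative only.

```python
import numpy as np

x = np.arange(3)                        # input ndim 1

depth1 = np.block([x, x])               # depth 1, input ndim 1 -> output ndim 1
depth2 = np.block([[x], [x]])           # depth 2 exceeds input ndim -> output ndim 2
depth3 = np.block([[[x]]])              # depth 3 exceeds input ndim -> output ndim 3

cube = np.ones((2, 2, 2))               # input ndim 3
deep_input = np.block([cube, cube])     # input ndim exceeds depth 1 -> output ndim 3

ndims = [arr.ndim for arr in (depth1, depth2, depth3, deep_input)]
```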

Raises

ValueError

  • If list depths are mismatched - for instance, [[a, b], c] is illegal, and should be spelt [[a, b], [c]]

  • If lists are empty - for instance, [[a, b], []]
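Both error cases can be reproduced with a short NumPy sketch. The helper raises_value_error and the block shapes are hypothetical, introduced only for this illustration.

```python
import numpy as np

a = np.ones((2, 2))
b = np.ones((2, 2))
c = np.ones((2, 4))

def raises_value_error(nested):
    """Return True if np.block rejects the nested list with a ValueError."""
    try:
        np.block(nested)
    except ValueError:
        return True
    return False

mismatched = raises_value_error([[a, b], c])    # list depths 2 and 1 differ
empty = raises_value_error([[a, b], []])        # one inner list is empty

# Wrapping c in its own list gives every block the same depth.
fixed = np.block([[a, b], [c]])                 # shape (4, 4)
```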

See also

concatenate()

Join a sequence of arrays along an existing axis.

stack()

Join a sequence of arrays along a new axis.

vstack()

Stack arrays in sequence vertically (row wise).

hstack()

Stack arrays in sequence horizontally (column wise).

dstack()

Stack arrays in sequence depth wise (along third axis).

column_stack()

Stack 1-D arrays as columns into a 2-D array.

vsplit()

Split an array into multiple sub-arrays vertically (row-wise).

Notes

When called with only scalars, np.block is equivalent to an ndarray call. So np.block([[1, 2], [3, 4]]) is equivalent to np.array([[1, 2], [3, 4]]).

This function does not enforce that the blocks lie on a fixed grid. np.block([[a, b], [c, d]]) is not restricted to arrays of the form:

    AAAbb
    AAAbb
    cccDD

It is also allowed to produce, for some a, b, c, d:

    AAAbb
    AAAbb
    cDDDD

Since concatenation happens along the last axis first, block is not capable of producing the following directly:

    AAAbb
    cccbb
    cccDD
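The layouts discussed in these notes can be sketched with NumPy as follows. The fill values 1 to 4 stand in for the blocks A, b, c and D, and the shapes are assumptions chosen to match the diagrams. The last layout is obtained here by nesting two calls (each column is stacked first, then the columns are joined), which is one possible workaround and not something a single call does.

```python
import numpy as np

# Off-grid layout: the split point in the bottom row differs from the top rows.
a = np.full((2, 3), 1)
b = np.full((2, 2), 2)
c = np.full((1, 1), 3)
d = np.full((1, 4), 4)
off_grid = np.block([[a, b], [c, d]])       # shape (3, 5)

# Layout where blocks overlap row boundaries: build each column separately,
# then join the two columns along the last axis.
A = np.full((1, 3), 1)
c2 = np.full((2, 3), 3)
b2 = np.full((2, 2), 2)
D = np.full((1, 2), 4)
left = np.block([[A], [c2]])                # shape (3, 3)
right = np.block([[b2], [D]])               # shape (3, 2)
two_step = np.block([left, right])          # shape (3, 5)
```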

MATLAB’s “square bracket stacking”, [A, B, ...; p, q, ...], is equivalent to np.block([[A, B, ...], [p, q, ...]]).

Examples

The most common use of this function is to build a block matrix:

>>> A = np.eye(2) * 2
>>> B = np.eye(3) * 3
>>> np.block([
...     [A,               np.zeros((2, 3))],
...     [np.ones((3, 2)), B               ]
... ])
array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [1., 1., 3., 0., 0.],
       [1., 1., 0., 3., 0.],
       [1., 1., 0., 0., 3.]])

With a list of depth 1, block can be used in place of hstack:

>>> np.block([1, 2, 3])              # hstack([1, 2, 3])
array([1, 2, 3])
>>> a = np.array([1, 2, 3])
>>> b = np.array([2, 3, 4])
>>> np.block([a, b, 10])             # hstack([a, b, 10])
array([ 1,  2,  3,  2,  3,  4, 10])
>>> A = np.ones((2, 2), int)
>>> B = 2 * A
>>> np.block([A, B])                 # hstack([A, B])
array([[1, 1, 2, 2],
       [1, 1, 2, 2]])

With a list of depth 2, block can be used in place of vstack:

>>> a = np.array([1, 2, 3])
>>> b = np.array([2, 3, 4])
>>> np.block([[a], [b]])             # vstack([a, b])
array([[1, 2, 3],
       [2, 3, 4]])
>>> A = np.ones((2, 2), int)
>>> B = 2 * A
>>> np.block([[A], [B]])             # vstack([A, B])
array([[1, 1],
       [1, 1],
       [2, 2],
       [2, 2]])

It can also be used in place of atleast_1d and atleast_2d:

>>> a = np.array(0)
>>> b = np.array([1])
>>> np.block([a])                    # atleast_1d(a)
array([0])
>>> np.block([b])                    # atleast_1d(b)
array([1])
>>> np.block([[a]])                  # atleast_2d(a)
array([[0]])
>>> np.block([[b]])                  # atleast_2d(b)
array([[1]])