jax.numpy.block

jax.numpy.block(arrays)

Create an array from a list of blocks.

JAX implementation of numpy.block().

Parameters:

arrays (ArrayLike | list[ArrayLike]) – an array, or nested list of arrays which will be concatenated together to form the final array.

Returns:

a single array constructed from the inputs.

Return type:

Array

Examples

Consider these blocks:

>>> import jax.numpy as jnp
>>> zeros = jnp.zeros((2, 2))
>>> ones = jnp.ones((2, 2))
>>> twos = jnp.full((2, 2), 2)
>>> threes = jnp.full((2, 2), 3)

Passing a single array to block() returns the array:

>>> jnp.block(zeros)
Array([[0., 0.],
       [0., 0.]], dtype=float32)

Passing a simple list of arrays concatenates them along the last axis:

>>> jnp.block([zeros, ones])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.]], dtype=float32)
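
As a rough illustration of the rule above, the following sketch compares block() on a flat list with an explicit jnp.concatenate along the last axis. The variable names blocked and manual are illustrative only, and the equivalence is checked here just for this one input rather than claimed as a statement about how block() is implemented.

```python
import jax.numpy as jnp

zeros = jnp.zeros((2, 2))
ones = jnp.ones((2, 2))

# block() on a flat list of 2-D arrays.
blocked = jnp.block([zeros, ones])

# Explicit concatenation along the last axis, for comparison.
manual = jnp.concatenate([zeros, ones], axis=-1)

print(blocked.shape)                           # (2, 4)
print(bool(jnp.array_equal(blocked, manual)))  # True
```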

Passing a doubly-nested list of arrays concatenates the inner list along the last axis, and the outer list along the second-to-last axis:

>>> jnp.block([[zeros, ones],
...            [twos, threes]])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [2., 2., 3., 3.],
       [2., 2., 3., 3.]], dtype=float32)
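
The two-level rule can be spelled out by hand, as in this minimal sketch: each inner list is joined along the last axis, and the resulting rows are then joined along the second-to-last axis. The names top, bottom, and manual are illustrative, and the blocks are built as floats here (an assumption made only to keep dtypes uniform in the comparison).

```python
import jax.numpy as jnp

zeros = jnp.zeros((2, 2))
ones = jnp.ones((2, 2))
twos = jnp.full((2, 2), 2.0)
threes = jnp.full((2, 2), 3.0)

# Inner lists: concatenate along the last axis.
top = jnp.concatenate([zeros, ones], axis=-1)
bottom = jnp.concatenate([twos, threes], axis=-1)

# Outer list: concatenate the rows along the second-to-last axis.
manual = jnp.concatenate([top, bottom], axis=-2)

blocked = jnp.block([[zeros, ones],
                     [twos, threes]])

print(blocked.shape)                           # (4, 4)
print(bool(jnp.array_equal(blocked, manual)))  # True
```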

Note that blocks need not have identical shapes: at each concatenation step, only the sizes along the axes that are not being concatenated must match. For example, this is valid because after the inner, horizontal concatenation, the resulting blocks have a valid shape for the outer, vertical concatenation:

>>> a = jnp.zeros((2, 1))
>>> b = jnp.ones((2, 3))
>>> c = jnp.full((1, 2), 2)
>>> d = jnp.full((1, 2), 3)
>>> jnp.block([[a, b], [c, d]])
Array([[0., 1., 1., 1.],
       [0., 1., 1., 1.],
       [2., 2., 3., 3.]], dtype=float32)
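
To see why these shapes are compatible, it may help to inspect the intermediate results. The sketch below performs the horizontal step by hand with jnp.concatenate and prints each shape; top and bottom are illustrative names, and the blocks are created as floats for simplicity.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 1))
b = jnp.ones((2, 3))
c = jnp.full((1, 2), 2.0)
d = jnp.full((1, 2), 3.0)

# Horizontal step: row counts match within each inner list (2 and 2, 1 and 1).
top = jnp.concatenate([a, b], axis=-1)
bottom = jnp.concatenate([c, d], axis=-1)
print(top.shape)     # (2, 4)
print(bottom.shape)  # (1, 4)

# Vertical step: both rows now have 4 columns, so they can be stacked.
result = jnp.block([[a, b], [c, d]])
print(result.shape)  # (3, 4)
```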

Note also that this logic generalizes to blocks in 3 or more dimensions. Here’s a 3-dimensional block-wise array:

>>> x = jnp.arange(6).reshape((1, 2, 3))
>>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)]
>>> jnp.block(blocks).shape
(5, 8, 9)
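
One way to read the resulting shape is that each nesting level multiplies one axis of x: the innermost list (3 items) extends the last axis, the middle list (4 items) the second-to-last, and the outermost list (5 items) the third-to-last. The following sketch checks this reading against manual concatenations; inner, middle, and outer are illustrative names.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape((1, 2, 3))

# Innermost level: 3 copies along the last axis.
inner = jnp.concatenate([x] * 3, axis=-1)
# Middle level: 4 copies along the second-to-last axis.
middle = jnp.concatenate([inner] * 4, axis=-2)
# Outermost level: 5 copies along the third-to-last axis.
outer = jnp.concatenate([middle] * 5, axis=-3)

blocks = [[[x for _ in range(3)] for _ in range(4)] for _ in range(5)]
blocked = jnp.block(blocks)

print(inner.shape, middle.shape, outer.shape)  # (1, 2, 9) (1, 8, 9) (5, 8, 9)
print(bool(jnp.array_equal(blocked, outer)))   # True
```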