jax.numpy.block
- jax.numpy.block(arrays)
Create an array from a list of blocks.
JAX implementation of numpy.block().
- Parameters:
arrays (ArrayLike | list[ArrayLike]) – an array, or a nested list of arrays, which will be concatenated together to form the final array.
- Returns:
a single array constructed from the inputs.
- Return type:
Array
Examples
Consider these blocks:
>>> zeros = jnp.zeros((2, 2))
>>> ones = jnp.ones((2, 2))
>>> twos = jnp.full((2, 2), 2)
>>> threes = jnp.full((2, 2), 3)
Passing a single array to block() returns the array:
>>> jnp.block(zeros)
Array([[0., 0.],
       [0., 0.]], dtype=float32)
Passing a simple list of arrays concatenates them along the last axis:
>>> jnp.block([zeros, ones])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.]], dtype=float32)
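As a quick check of the rule just described, the following sketch (assuming JAX is installed; the names `via_block` and `via_concat` are illustrative) compares `jnp.block` on a flat list with an explicit `jnp.concatenate` along `axis=-1`:

```python
import jax.numpy as jnp

zeros = jnp.zeros((2, 2))
ones = jnp.ones((2, 2))

# A flat list of blocks is joined along the last axis, which is what
# jnp.concatenate does when given axis=-1.
via_block = jnp.block([zeros, ones])
via_concat = jnp.concatenate([zeros, ones], axis=-1)

print(via_block.shape)                                # (2, 4)
print(bool(jnp.array_equal(via_block, via_concat)))  # True
```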
Passing a doubly-nested list of arrays concatenates the inner list along the last axis, and the outer list along the second-to-last axis:
>>> jnp.block([[zeros, ones],
...            [twos, threes]])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [2., 2., 3., 3.],
       [2., 2., 3., 3.]], dtype=float32)
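The doubly-nested case can be reproduced by hand with two rounds of `jnp.concatenate`: first each inner list along `axis=-1`, then the resulting rows along `axis=-2`. A minimal sketch, assuming JAX is installed; `top`, `bottom`, and `manual` are illustrative names:

```python
import jax.numpy as jnp

zeros = jnp.zeros((2, 2))
ones = jnp.ones((2, 2))
twos = jnp.full((2, 2), 2)
threes = jnp.full((2, 2), 3)

# Inner lists first: each row of blocks is joined along the last axis.
top = jnp.concatenate([zeros, ones], axis=-1)      # shape (2, 4)
bottom = jnp.concatenate([twos, threes], axis=-1)  # shape (2, 4)

# Then the outer list: the rows are joined along the second-to-last axis.
manual = jnp.concatenate([top, bottom], axis=-2)   # shape (4, 4)

result = jnp.block([[zeros, ones], [twos, threes]])
print(result.shape)                            # (4, 4)
print(bool(jnp.array_equal(result, manual)))  # True
```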
Note that blocks need not align in all dimensions: within each concatenation, the blocks must match in size along every axis except the one being concatenated. For example, the following is valid because the inner, horizontal concatenations produce blocks of shape (2, 4) and (1, 4), which have the same number of columns and can therefore be concatenated vertically.
>>> a = jnp.zeros((2, 1))
>>> b = jnp.ones((2, 3))
>>> c = jnp.full((1, 2), 2)
>>> d = jnp.full((1, 2), 3)
>>> jnp.block([[a, b], [c, d]])
Array([[0., 1., 1., 1.],
       [0., 1., 1., 1.],
       [2., 2., 3., 3.]], dtype=float32)
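Tracing the intermediate shapes makes the alignment rule concrete. This sketch (assuming JAX is installed; `row0`, `row1`, and `swapped_ok` are illustrative names) performs the inner concatenations explicitly, then shows that swapping `b` and `c` produces a row whose blocks disagree in their number of rows. It catches either `TypeError` or `ValueError` because it does not rely on which exception JAX raises for the mismatch:

```python
import jax.numpy as jnp

a = jnp.zeros((2, 1))
b = jnp.ones((2, 3))
c = jnp.full((1, 2), 2)
d = jnp.full((1, 2), 3)

# Inner (horizontal) step: blocks in the same row must agree on the number
# of rows (axis 0); their column counts may differ.
row0 = jnp.concatenate([a, b], axis=-1)  # (2, 1) + (2, 3) -> (2, 4)
row1 = jnp.concatenate([c, d], axis=-1)  # (1, 2) + (1, 2) -> (1, 4)
print(row0.shape, row1.shape)            # (2, 4) (1, 4)

# Outer (vertical) step: both rows have 4 columns, so they stack to (3, 4).
print(jnp.block([[a, b], [c, d]]).shape)  # (3, 4)

# Swapping b and c breaks the rule: a has 2 rows and c has 1 row,
# so they cannot be joined along the last axis.
try:
    jnp.block([[a, c], [b, d]])
    swapped_ok = True
except (TypeError, ValueError):
    swapped_ok = False
print(swapped_ok)  # False
```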
Note also that this logic generalizes to blocks in 3 or more dimensions. Here’s a 3-dimensional block-wise array:
>>> x = jnp.arange(6).reshape((1, 2, 3))
>>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)]
>>> jnp.block(blocks).shape
(5, 8, 9)
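To see why the nesting generalizes, here is a hypothetical re-implementation, `block_sketch`, written in terms of `jnp.concatenate`. It is not JAX's actual implementation: it assumes non-empty, uniformly nested lists and skips the error checking the library performs. Nesting level `k` (counting from 0 at the outermost list) is joined along axis `k - depth`, so the innermost lists use the last axis, and leaves are padded with leading size-1 axes so every block has the same number of dimensions:

```python
import jax.numpy as jnp


def _list_depth(node):
    # Nesting depth of a non-empty, uniformly nested list structure.
    return 1 + _list_depth(node[0]) if isinstance(node, list) else 0


def _leaves(node):
    # Yield every array found at the bottom of the nested lists.
    if isinstance(node, list):
        for item in node:
            yield from _leaves(item)
    else:
        yield jnp.asarray(node)


def block_sketch(arrays):
    """Hypothetical sketch of block() via nested concatenation."""
    depth = _list_depth(arrays)
    ndim = max(depth, max(leaf.ndim for leaf in _leaves(arrays)))

    def build(node, level):
        if isinstance(node, list):
            # Outermost list (level 0) joins along axis -depth;
            # innermost list (level depth - 1) joins along axis -1.
            parts = [build(item, level + 1) for item in node]
            return jnp.concatenate(parts, axis=level - depth)
        leaf = jnp.asarray(node)
        # Prepend size-1 axes so every leaf has `ndim` dimensions.
        return leaf.reshape((1,) * (ndim - leaf.ndim) + leaf.shape)

    return build(arrays, 0)


x = jnp.arange(6).reshape((1, 2, 3))
blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)]
print(block_sketch(blocks).shape)                                 # (5, 8, 9)
print(bool(jnp.array_equal(block_sketch(blocks), jnp.block(blocks))))  # True
```

The top-down recursion mirrors the prose description: 3 copies along axis -1 give 9 columns, 4 rows of those along axis -2 give 8, and 5 slabs along axis -3 give the leading 5.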