jax.scipy.linalg.block_diag

jax.scipy.linalg.block_diag(*arrs)

Create a block diagonal matrix from input arrays.

JAX implementation of scipy.linalg.block_diag().

Parameters:

*arrs (jax.typing.ArrayLike) – arrays of at most two dimensions

Returns:

A 2D block-diagonal array constructed by placing the input arrays along the diagonal, with zeros everywhere else.
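To illustrate what "placing the input arrays along the diagonal" means, here is a minimal from-scratch sketch. The helper `block_diag_sketch` is hypothetical and is not part of JAX. It assumes every input is already 2D, which is a simplification, and it is not how JAX itself implements the function. Each block is written into a zero matrix with its top-left corner at the running (row, column) offset. The result is then compared against the library function.

```python
import jax.numpy as jnp
import jax.scipy.linalg


def block_diag_sketch(*arrs):
    """Hypothetical illustrative helper; assumes every input is already 2D."""
    # The output height/width is the sum of the block heights/widths.
    n_rows = sum(a.shape[0] for a in arrs)
    n_cols = sum(a.shape[1] for a in arrs)
    dtype = jnp.result_type(*arrs)
    out = jnp.zeros((n_rows, n_cols), dtype=dtype)
    r = c = 0
    for a in arrs:
        h, w = a.shape
        # Write block `a` with its top-left corner at (r, c); entries
        # outside every block remain zero.
        out = out.at[r:r + h, c:c + w].set(a)
        r += h
        c += w
    return out


A = jnp.ones((1, 1))
B = jnp.ones((2, 2))
C = jnp.ones((3, 3))
print(jnp.array_equal(block_diag_sketch(A, B, C),
                      jax.scipy.linalg.block_diag(A, B, C)))  # True
```

Because JAX arrays are immutable, the sketch uses the functional `.at[...].set(...)` update rather than in-place slice assignment.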

Return type:

Array

Example

>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> A = jnp.ones((1, 1))
>>> B = jnp.ones((2, 2))
>>> C = jnp.ones((3, 3))
>>> jax.scipy.linalg.block_diag(A, B, C)
Array([[1., 0., 0., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.]], dtype=float32)
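The blocks do not have to be square. In this small sketch (the values are chosen only for illustration), a 1×2 block and a 2×1 block produce a 3×3 result. The number of rows is the sum of the blocks' row counts and the number of columns is the sum of their column counts.

```python
import jax.numpy as jnp
import jax.scipy.linalg

P = jnp.array([[1., 2.]])      # shape (1, 2)
Q = jnp.array([[3.], [4.]])    # shape (2, 1)
M = jax.scipy.linalg.block_diag(P, Q)
print(M.shape)  # (3, 3)
# M holds the rows [1, 2, 0], [0, 0, 3], [0, 0, 4]:
# P occupies row 0, columns 0-1; Q occupies rows 1-2, column 2.
print(M)
```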