jax.numpy.identity

jax.numpy.identity(n, dtype=None)

Create a square identity matrix.

JAX implementation of numpy.identity().

Parameters:
  • n (DimSize) – integer specifying the size of each array dimension.

  • dtype (DTypeLike | None) – optional dtype; defaults to JAX's default floating-point dtype.
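A minimal sketch of how the dtype parameter behaves. It assumes JAX's default configuration, where the default floating-point dtype is float32 (it becomes float64 when 64-bit mode is enabled), so the check below only asserts that the default is some floating type.

```python
import jax.numpy as jnp

# No dtype given: JAX's default floating-point dtype is used.
I_float = jnp.identity(3)

# An explicit dtype is honored, including non-numeric ones such as bool.
I_bool = jnp.identity(3, dtype=bool)
I_int = jnp.identity(3, dtype=jnp.int32)

print(I_float.dtype, I_bool.dtype, I_int.dtype)
```

A boolean identity can be convenient as a mask for selecting the diagonal of another array.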

Returns:

Identity array of shape (n, n).

Return type:

Array
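The (n, n) shape is what makes the result act as the multiplicative identity for matrix products. The sketch below is illustrative, using an arbitrary 3x4 matrix `A`. It checks that multiplying on the left by an identity of size 3, or on the right by one of size 4, leaves `A` unchanged.

```python
import jax.numpy as jnp

A = jnp.arange(12.0).reshape(3, 4)  # an arbitrary 3x4 float matrix
I3 = jnp.identity(3)                # shape (3, 3)
I4 = jnp.identity(4)                # shape (4, 4)

# The identity's size must match the dimension it multiplies against.
left = I3 @ A    # (3, 3) @ (3, 4) -> (3, 4)
right = A @ I4   # (3, 4) @ (4, 4) -> (3, 4)
```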

See also

jax.numpy.eye(): non-square and/or offset identity matrices.
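A short sketch of how the two functions relate. `identity(n)` produces the same array as `eye(n)`. `eye` additionally accepts a second dimension `M` and a diagonal offset `k`, which cover the non-square and shifted-diagonal cases that `identity` does not.

```python
import jax.numpy as jnp

n = 4
# identity(n) matches eye(n): ones on the main diagonal, zeros elsewhere.
same = bool(jnp.array_equal(jnp.identity(n), jnp.eye(n)))

# Cases only eye handles:
rect = jnp.eye(2, 3)     # non-square, shape (2, 3)
upper = jnp.eye(3, k=1)  # ones on the first superdiagonal
```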

Examples

A simple 3x3 identity matrix:

>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

A 2x2 integer identity matrix:

>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)
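Because `n` determines the output shape, it must be a concrete Python integer when `jnp.identity` is called inside `jax.jit`. One common pattern is to mark that argument static, as in the sketch below. The helper name `scaled_identity` is hypothetical and chosen only for illustration. Passing a traced value for `n` instead would fail, because shapes must be known at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp


# n is static (a compile-time Python int); scale may be a traced value.
@partial(jax.jit, static_argnums=0)
def scaled_identity(n, scale):
    return scale * jnp.identity(n)


out = scaled_identity(3, 2.0)
```

Each distinct value of `n` triggers a separate compilation, so this pattern works best when only a few sizes are used.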