jax.numpy.eye
- jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)
Create a square or rectangular identity matrix.
JAX implementation of numpy.eye().
- Parameters:
N (DimSize) – integer specifying the first dimension of the array.
M (DimSize | None) – optional integer specifying the second dimension of the array; defaults to the same value as N.
k (int | ArrayLike) – optional integer specifying the offset of the diagonal. Use positive values for upper diagonals, and negative values for lower diagonals. Default is zero.
dtype (DTypeLike | None) – optional dtype; defaults to floating point.
device (xc.Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.
- Returns:
Identity array of shape (N, M), or (N, N) if M is not specified.
- Return type:
Array
See also
jax.numpy.identity(): Simpler API for generating square identity matrices.
Examples
A simple 3x3 identity matrix:
>>> jnp.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
Integer identity matrices with offset diagonals:
>>> jnp.eye(3, k=1, dtype=int)
Array([[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]], dtype=int32)
>>> jnp.eye(3, k=-1, dtype=int)
Array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0]], dtype=int32)
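One way to read the offset k: the entry at row i, column j is one exactly where j - i == k. The sketch below checks that reading against jnp.eye for a few offsets. The helper eye_from_indices is hypothetical and written only for this illustration; it is not part of JAX.

```python
import jax.numpy as jnp


def eye_from_indices(n, m, k):
    # Hypothetical helper (not a JAX API): rebuild the matrix by
    # comparing column and row indices directly.
    rows = jnp.arange(n)[:, None]   # shape (n, 1)
    cols = jnp.arange(m)[None, :]   # shape (1, m)
    return (cols - rows == k).astype(jnp.int32)


# Compare against jnp.eye for lower, main, and upper diagonals.
matches = all(
    bool(jnp.array_equal(jnp.eye(3, 4, k=k, dtype=jnp.int32),
                         eye_from_indices(3, 4, k)))
    for k in (-2, -1, 0, 1, 2)
)
print(matches)
```

The comparison relies only on broadcasting, so the same check works for any N, M, and k, including offsets that push the diagonal partly outside the array.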
Non-square identity matrix:
>>> jnp.eye(3, 5, k=1)
Array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)