jax.numpy.eye
- jax.numpy.eye(N, M=None, k=0, dtype=None)
Return a 2-D array with ones on the diagonal and zeros elsewhere.
LAX-backend implementation of numpy.eye(). Original docstring below.
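The following is a minimal usage sketch. It assumes JAX is installed and compares the result with numpy.eye() to show that the LAX-backend version returns the same values for the same arguments. JAX's default floating dtype is float32 unless 64-bit mode is enabled, so only the values are compared, not the dtype.

```python
import jax.numpy as jnp
import numpy as np

# 3x3 identity matrix (float32 by default unless jax_enable_x64 is set)
I3 = jnp.eye(3)
print(I3)

# The values match NumPy's eye for the same arguments; dtypes may differ
assert np.array_equal(np.asarray(I3), np.eye(3))
```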
- Parameters
N (int) – Number of rows in the output.
M (int, optional) – Number of columns in the output. If None, defaults to N.
k (int, optional) – Index of the diagonal: 0 (the default) refers to the main diagonal, a positive value refers to an upper diagonal, and a negative value to a lower diagonal.
dtype (data-type, optional) – Data-type of the returned array.
- Returns
I – An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
- Return type
ndarray of shape (N, M)
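The sketch below illustrates the M, k, and dtype parameters, assuming JAX is installed. The `meshgrid`-based check is only an illustration of the rule "element (i, j) is 1 exactly when j - i == k". It is not how jax.numpy.eye is implemented internally.

```python
import jax.numpy as jnp

# Rectangular output: 2 rows, 4 columns, ones on the first upper diagonal (k=1)
upper = jnp.eye(2, 4, k=1)
# [[0. 1. 0. 0.]
#  [0. 0. 1. 0.]]

# A negative k selects a lower diagonal; dtype overrides the default float type
lower = jnp.eye(3, k=-1, dtype=jnp.int32)
# [[0 0 0]
#  [1 0 0]
#  [0 1 0]]

# Rebuild `upper` from row/column indices: 1 where j - i == k, else 0
i, j = jnp.meshgrid(jnp.arange(2), jnp.arange(4), indexing="ij")
manual = (j - i == 1).astype(upper.dtype)
assert (manual == upper).all()

print(upper.shape, lower.dtype)  # (2, 4) int32
```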