jax.numpy.eye

jax.numpy.eye(N, M=None, k=0, dtype=None)

Return a 2-D array with ones on the diagonal and zeros elsewhere.

LAX-backend implementation of numpy.eye(). The original NumPy docstring follows.
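Because the function is built on LAX, it can be called inside a jax.jit-compiled function, provided the shape arguments are concrete Python integers. Array shapes are static under jit, so deriving N and M from an input's shape works. This is a minimal sketch; add_superdiagonal is a hypothetical helper, not part of JAX.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_superdiagonal(x):
    # Hypothetical helper. Inside jit, x.shape holds concrete Python ints
    # (shapes are static), so they are valid N and M arguments for jnp.eye.
    n, m = x.shape
    # Add ones on the first superdiagonal (k=1), matching x's dtype.
    return x + jnp.eye(n, m, k=1, dtype=x.dtype)


x = jnp.zeros((3, 4), dtype=jnp.float32)
print(add_superdiagonal(x))
```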

Parameters
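  • N (int) – Number of rows in the output.

  • M (int, optional) – Number of columns in the output. If None, defaults to N.

  • k (int, optional) – Index of the diagonal: 0 (the default) refers to the main diagonal, a positive value refers to an upper diagonal, and a negative value to a lower diagonal.

  • dtype (data-type, optional) – Data-type of the returned array. If None, the default floating-point type is used.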
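A short sketch of how M and k shape the result: M controls the column count (giving a non-square array), and the sign of k picks a diagonal above or below the main one.

```python
import jax.numpy as jnp

# M sets the number of columns; M=None would give a square N x N result.
rect = jnp.eye(2, 4, k=1)
# A negative k selects a diagonal below the main one; dtype sets the element type.
lower = jnp.eye(3, k=-1, dtype=jnp.int32)

print(rect)
print(lower)
```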

Returns

I – An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.

Return type

ndarray of shape (N,M)
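Stated element-wise, the returned array I satisfies I[i, j] = 1 exactly when j - i == k, and 0 otherwise. The sketch below encodes that rule with broadcasting; eye_reference is a hypothetical illustration of the definition, not how JAX implements eye internally.

```python
import jax.numpy as jnp


def eye_reference(N, M=None, k=0, dtype=jnp.float32):
    """Hypothetical reference: element (i, j) is one exactly when j - i == k."""
    M = N if M is None else M
    rows = jnp.arange(N)[:, None]  # column vector of row indices i
    cols = jnp.arange(M)[None, :]  # row vector of column indices j
    # Broadcasting yields an (N, M) boolean mask, cast to the requested dtype.
    return (cols - rows == k).astype(dtype)


print(eye_reference(4, 3, k=-2))
```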

See also

identity()

(almost) equivalent function

diag()

diagonal 2-D array from a 1-D array specified by the user.
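The relationship to identity() and diag() can be checked directly. identity() accepts neither M nor k, which is why it is only "almost" equivalent: it matches eye() with the defaults. diag() places a 1-D array on the k-th diagonal of a square array of size len(v) + abs(k), so a vector of ones reproduces a square eye() with that offset.

```python
import jax.numpy as jnp

n = 4
# identity(n) equals eye(n) with the defaults M=None and k=0.
same_as_identity = jnp.array_equal(jnp.eye(n), jnp.identity(n))
# n - 1 ones on diagonal k=1 give an n x n array, matching eye(n, k=1).
same_as_diag = jnp.array_equal(jnp.diag(jnp.ones(n - 1), k=1), jnp.eye(n, k=1))
print(same_as_identity, same_as_diag)
```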

Examples

>>> np.eye(2, dtype=int)
array([[1, 0],
       [0, 1]])
>>> np.eye(3, k=1)
array([[0.,  1.,  0.],
       [0.,  0.,  1.],
       [0.,  0.,  0.]])
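The examples above come from the NumPy docstring and call np.eye. Running the same calls through jax.numpy gives the same values, but the dtype follows JAX's defaults. With 64-bit mode disabled (the default configuration), dtype=int maps to int32 and the floating-point default is float32, whereas NumPy typically uses 64-bit types.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.eye(2, dtype=int)
b = jnp.eye(3, k=1)

# The values match NumPy's eye for the same arguments.
print(np.array_equal(np.asarray(a), np.eye(2, dtype=int)))
print(np.array_equal(np.asarray(b), np.eye(3, k=1)))
# The dtypes follow JAX's defaults (32-bit unless 64-bit mode is enabled).
print(a.dtype, b.dtype)
```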