jax.numpy.eye

jax.numpy.eye(N, M=None, k=0, dtype=None)

Return a 2-D array with ones on the diagonal and zeros elsewhere.

LAX-backend implementation of numpy.eye().
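Because the function mirrors `numpy.eye()`, a quick sanity check is to compare the two side by side. This is a minimal sketch assuming both `jax` and `numpy` are installed. The JAX result defaults to float32 (unless 64-bit mode is enabled), while NumPy's defaults to float64, so only the values are compared here.

```python
import numpy as np
import jax.numpy as jnp

# 3x3 identity matrix built with JAX and with NumPy.
i_jax = jnp.eye(3)
i_np = np.eye(3)

print(i_jax)
# np.asarray converts the JAX array to a NumPy array; array_equal compares
# values only, so the float32 vs. float64 dtype difference does not matter.
print(np.array_equal(np.asarray(i_jax), i_np))  # True
```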

Original docstring below.

Parameters:
  • N (int) – Number of rows in the output.

  • M (int, optional) – Number of columns in the output. If None, defaults to N.

  • k (int, optional) – Index of the diagonal: 0 (the default) refers to the main diagonal, a positive value refers to an upper diagonal, and a negative value to a lower diagonal.

  • dtype (data-type, optional) – Data-type of the returned array.
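The parameters above can be illustrated with a few calls. This is a sketch, not an exhaustive reference: `M` sets a non-square shape, `k` moves the band of ones above or below the main diagonal, and `dtype` selects the element type. An offset that lies entirely outside the matrix yields all zeros, as in NumPy.

```python
import jax.numpy as jnp

# Non-square output: 2 rows, 4 columns, ones on the main diagonal.
a = jnp.eye(2, 4)
# a = [[1, 0, 0, 0],
#      [0, 1, 0, 0]]

# k=1 places the ones on the first diagonal above the main one.
b = jnp.eye(3, k=1)
# b = [[0, 1, 0],
#      [0, 0, 1],
#      [0, 0, 0]]

# k=-1 places them on the first diagonal below the main one.
c = jnp.eye(3, k=-1)
# c = [[0, 0, 0],
#      [1, 0, 0],
#      [0, 1, 0]]

# An offset past the last column selects no elements, so the result is all zeros.
z = jnp.eye(3, k=3)

# dtype controls the element type of the result.
d = jnp.eye(2, dtype=jnp.int32)
print(d.dtype)  # int32
```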

Returns:

I – An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.

Return type:

ndarray of shape (N, M)
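The description of the return value can be restated as a rule: element `(i, j)` is one exactly when `j - i == k`, and zero otherwise. The sketch below rebuilds the matrix from that rule with broadcasting and checks it against `jnp.eye` for a few shapes and offsets. `eye_reference` is a hypothetical helper written only for this comparison; it is not part of JAX.

```python
import jax.numpy as jnp

def eye_reference(n, m, k):
    """Hypothetical helper (not a JAX API): builds an (n, m) matrix with
    I[i, j] = 1 where j - i == k and 0 elsewhere, via broadcasting."""
    rows = jnp.arange(n)[:, None]   # column vector of row indices, shape (n, 1)
    cols = jnp.arange(m)[None, :]   # row vector of column indices, shape (1, m)
    return (cols - rows == k).astype(jnp.float32)

# Square, wide, tall, and out-of-range-offset cases.
for n, m, k in [(3, 3, 0), (2, 4, 1), (4, 2, -1), (3, 3, 5)]:
    result = jnp.eye(n, m, k=k)
    assert result.shape == (n, m)
    # array_equal compares values, so a float64 result under x64 mode also matches.
    assert jnp.array_equal(result, eye_reference(n, m, k))
print("ok")
```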