jax.numpy.identity

jax.numpy.identity(n, dtype=None)

Return the identity array.

LAX-backend implementation of numpy.identity(). The original NumPy docstring follows.

The identity array is a square array with ones on the main diagonal.
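As a quick check of that definition, the sketch below (assuming `jax` is installed) builds a 3 x 3 identity with `jax.numpy` and confirms that it is square, has ones on the diagonal, and matches `jnp.eye(n)`, which produces the same array when called with a single size argument.

```python
import jax.numpy as jnp

n = 3
I = jnp.identity(n)

# The result is square: n rows and n columns.
assert I.shape == (n, n)

# Ones on the main diagonal...
assert bool(jnp.all(jnp.diag(I) == 1))

# ...and zeros everywhere else, so the array equals jnp.eye(n).
assert bool(jnp.array_equal(I, jnp.eye(n)))
```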

Parameters
  • n (int) – Number of rows (and columns) in n x n output.

  • dtype (data-type, optional) – Data-type of the output. Defaults to float; in JAX this is the default floating-point dtype (float32 unless 64-bit mode is enabled).
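A minimal sketch of the `dtype` parameter, assuming JAX's default configuration (64-bit mode off): omitting `dtype` gives a floating-point array, while passing an integer dtype changes only the element type, not the values.

```python
import jax.numpy as jnp

# No dtype: JAX's default floating-point type
# (float32 unless jax_enable_x64 has been turned on).
default = jnp.identity(2)
print(default.dtype)

# An explicit dtype changes the element type; the diagonal is still 1.
as_int = jnp.identity(2, dtype=jnp.int32)
print(as_int)
```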

Returns

out – n x n array with its main diagonal set to one and all other elements set to zero.
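Because the diagonal is one and every other entry is zero, the returned array acts as the multiplicative identity for matrix products. The sketch below (the matrix `A` is an arbitrary illustrative input) multiplies a 2 x 3 matrix by identities of matching size on each side.

```python
import jax.numpy as jnp

# An arbitrary 2 x 3 matrix used only for illustration.
A = jnp.arange(6.0).reshape(2, 3)

# The identity must match the dimension being contracted:
# a 2 x 2 identity on the left, a 3 x 3 identity on the right.
left = jnp.identity(2) @ A
right = A @ jnp.identity(3)

assert bool(jnp.allclose(left, A))
assert bool(jnp.allclose(right, A))
```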

Return type

ndarray

Examples

>>> np.identity(3)
array([[1.,  0.,  0.],
       [0.,  1.,  0.],
       [0.,  0.,  1.]])
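The example above is NumPy's (`np`); with `jax.numpy` the same call returns a `jax.Array`. One JAX-specific consequence is that array shapes must be known when a function is traced, so inside `jax.jit` the size `n` has to be marked static. The helper `scaled_identity` below is hypothetical and only illustrates that pattern.

```python
from functools import partial

import jax
import jax.numpy as jnp


# n determines the output shape, so it is declared static;
# scale may be an ordinary traced value.
@partial(jax.jit, static_argnums=0)
def scaled_identity(n, scale):
    # Hypothetical helper: returns scale * I_n.
    return scale * jnp.identity(n)


out = scaled_identity(3, 2.0)
print(out)
```

Each distinct value of `n` triggers a separate compilation, which is the usual trade-off for static arguments.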