jax.numpy.identity
jax.numpy.identity(n, dtype=None)

Return the identity array.
LAX-backend implementation of numpy.identity(). Original docstring below.

The identity array is a square array with ones on the main diagonal and zeros everywhere else.
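The definition above can be checked directly. The following is a minimal sketch that builds the same square array by comparing row and column indices, assuming only public jax.numpy functions (arange, identity, eye, array_equal). The helper name identity_from_indices is hypothetical and exists only for illustration; it is not how JAX itself implements identity.

```python
import jax.numpy as jnp


def identity_from_indices(n, dtype=None):
    """Hypothetical helper: n x n array with ones where row index == column index."""
    rows = jnp.arange(n)[:, None]  # row indices as a column vector, shape (n, 1)
    cols = jnp.arange(n)[None, :]  # column indices as a row vector, shape (1, n)
    mask = rows == cols            # broadcasts to an (n, n) boolean array
    # Python float maps to JAX's default floating dtype
    # (float32 unless 64-bit mode is enabled).
    return mask.astype(float if dtype is None else dtype)


# Usage: the sketch agrees with jnp.identity and jnp.eye.
I = identity_from_indices(4)
print(jnp.array_equal(I, jnp.identity(4)), jnp.array_equal(I, jnp.eye(4)))
```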
- Parameters
n (int): Number of rows (and columns) in the n x n output.
dtype (data-type, optional): Data-type of the output. Defaults to float.
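As a short illustration of the dtype parameter, the sketch below compares the default result with an explicitly typed one. It assumes JAX's usual behaviour that dtype=None yields the default floating-point dtype (float32 unless 64-bit mode is enabled); the variable names are arbitrary.

```python
import jax.numpy as jnp

# dtype=None: JAX's default floating-point dtype.
I_float = jnp.identity(3)

# An explicit dtype sets the element type of the result.
I_int = jnp.identity(3, dtype=jnp.int32)

print(I_float.dtype, I_int.dtype)
```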
- Returns
out: n x n array with its main diagonal set to one and all other elements set to zero.
- Return type
ndarray
Examples
>>> np.identity(3)
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])
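The NumPy doctest above translates directly to jax.numpy. The sketch below also checks the defining property (multiplying by the identity leaves a matrix unchanged) and shows one assumption worth stating: because n determines the output shape, it must be a static, compile-time value when the call sits inside jax.jit. The helper scaled_identity is hypothetical and only illustrates the static_argnums pattern.

```python
from functools import partial

import jax
import jax.numpy as jnp

# JAX counterpart of the NumPy example.
I = jnp.identity(3)

# Multiplying by the identity leaves A unchanged: A @ I == I @ A == A.
A = jnp.arange(9.0).reshape(3, 3)
print(jnp.allclose(A @ I, A), jnp.allclose(I @ A, A))


@partial(jax.jit, static_argnums=0)
def scaled_identity(n, scale):
    """Hypothetical helper returning scale * I_n.

    n fixes the output shape, so it is marked static; scale may be traced.
    """
    return scale * jnp.identity(n)


print(scaled_identity(4, 2.5))
```

Marking n as static means each distinct n triggers a separate compilation, which is the usual trade-off for shape-determining arguments under jit.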