jax.numpy.identity
- jax.numpy.identity(n, dtype=None)
Return the identity array.
LAX-backend implementation of numpy.identity(). Original docstring below.
The identity array is a square array with ones on the main diagonal.
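As a quick illustration of that definition, the sketch below builds a 3 x 3 identity with jax.numpy, checks that its values match NumPy's reference `np.identity`, and confirms that an entry is 1 exactly where the row index equals the column index. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# A 3 x 3 identity: ones on the main diagonal, zeros elsewhere.
I3 = jnp.identity(3)
print(I3)
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]

# The values agree with NumPy's reference implementation.
assert np.array_equal(np.asarray(I3), np.identity(3))

# Equivalently, an entry is 1 where row index == column index, else 0.
rows, cols = jnp.indices((3, 3))
assert bool(jnp.all(I3 == (rows == cols)))
```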
- Parameters:
n (int) – Number of rows (and columns) in n x n output.
dtype (data-type, optional) – Data-type of the output. Defaults to float.
- Returns:
out – n x n array with its main diagonal set to one, and all other elements 0.
- Return type:
ndarray
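The sketch below exercises the parameters and return value under a few assumptions worth stating. With `dtype=None`, JAX uses its default floating-point type, which is float32 unless 64-bit mode (`jax_enable_x64`) is enabled, so it can differ from NumPy's float64. The returned JAX array is immutable, so a modified copy is built with `.at[...].set(...)`. Because `n` fixes the output shape, it must be a concrete Python int inside `jax.jit`, which is why the hypothetical helper `scaled_identity` marks it static.

```python
import functools

import jax
import jax.numpy as jnp

n = 4
# dtype=None: JAX's default floating dtype (float32 unless x64 mode is on).
I_default = jnp.identity(n)
assert jnp.issubdtype(I_default.dtype, jnp.floating)

# An explicit dtype overrides the default.
I_int = jnp.identity(n, dtype=jnp.int32)
assert I_int.dtype == jnp.int32

# Multiplying by the identity leaves a matrix unchanged.
x = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)
assert bool(jnp.array_equal(x @ jnp.identity(n, dtype=x.dtype), x))

# JAX arrays are immutable: .at[...].set(...) returns a modified copy.
I_mod = I_default.at[0, n - 1].set(5.0)
assert float(I_default[0, n - 1]) == 0.0  # original is unchanged
assert float(I_mod[0, n - 1]) == 5.0

# Hypothetical helper: n determines the output shape, so it is marked
# static; a traced n would not be usable as a shape under jit.
@functools.partial(jax.jit, static_argnums=0)
def scaled_identity(n, s):
    return s * jnp.identity(n)

print(scaled_identity(3, 2.0))
```

Each distinct value of `n` triggers a separate compilation, which is the usual trade-off for shape-determining arguments under `jax.jit`.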