jax.numpy.ndim

jax.numpy.ndim(a)

Return the number of dimensions of an array.

Parameters

a (array_like) – Input array. If it is not already an ndarray, a conversion is attempted.

Returns

number_of_dimensions – The number of dimensions in a. Scalars are zero-dimensional.
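A minimal sketch, assuming `jax` is installed and `jax.numpy` is imported as `jnp` (the examples on this page use `np` instead). It shows that a Python scalar and a 0-d JAX array both report zero dimensions, while higher-rank arrays report their rank. The inputs are JAX arrays and scalars rather than nested lists, so the sketch does not depend on how a particular JAX version treats list inputs.

```python
import jax.numpy as jnp

# A Python scalar is treated as a 0-d array: zero dimensions.
scalar_ndim = jnp.ndim(3.0)

# An explicit 0-d JAX array is also zero-dimensional.
zero_d_ndim = jnp.ndim(jnp.array(3.0))

# 1-d, 2-d, and 3-d arrays report their rank.
vector_ndim = jnp.ndim(jnp.arange(5))
matrix_ndim = jnp.ndim(jnp.ones((2, 3)))
cube_ndim = jnp.ndim(jnp.zeros((2, 3, 4)))

print(scalar_ndim, zero_d_ndim, vector_ndim, matrix_ndim, cube_ndim)  # 0 0 1 2 3
```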

Return type

int
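Because the result is a plain Python `int`, it can drive ordinary Python control flow even inside `jax.jit`, where array shapes (and therefore the number of dimensions) are static. The function `flatten_matrices` below is a hypothetical illustration, not part of the JAX API.

```python
import jax
import jax.numpy as jnp

@jax.jit
def flatten_matrices(x):
    # Shapes are static under jit, so jnp.ndim returns a Python int
    # even for traced inputs, and this `if` is resolved at trace time.
    if jnp.ndim(x) == 2:
        return x.reshape(-1)
    return x

print(flatten_matrices(jnp.ones((2, 3))).shape)  # (6,)
print(flatten_matrices(jnp.ones(4)).shape)       # (4,)
```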

See also

ndarray.ndim – Equivalent attribute.

shape() – Dimensions of an array.

ndarray.shape – Dimensions of an array.
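A short sketch, assuming `jax.numpy` is imported as `jnp`, showing that the function, the `ndim` attribute, and the length of the shape (via `jnp.shape` or the `shape` attribute) all give the same count.

```python
import jax.numpy as jnp

x = jnp.ones((4, 5))

# Four equivalent ways to get the number of dimensions.
via_function = jnp.ndim(x)
via_attribute = x.ndim
via_shape_function = len(jnp.shape(x))
via_shape_attribute = len(x.shape)

print(via_function, via_attribute, via_shape_function, via_shape_attribute)  # 2 2 2 2
```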

Examples

>>> np.ndim([[1,2,3],[4,5,6]])
2
>>> np.ndim(np.array([[1,2,3],[4,5,6]]))
2
>>> np.ndim(1)
0