jax.numpy.shape

jax.numpy.shape(a)

Return the shape of an array.

Parameters:

a (array_like) – Input array.

Returns:

shape – The elements of the shape tuple give the lengths of the corresponding array dimensions.

Return type:

tuple of ints
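The return value described above is an ordinary Python tuple with one int per dimension. The sketch below shows this for JAX arrays of rank 0, 1, and 2. It passes `jax.numpy` arrays rather than Python lists because recent JAX releases may reject plain list inputs to `jnp.shape`. The variable names are illustrative only.

```python
import jax.numpy as jnp

# Arrays of rank 0, 1, and 2 created directly with jax.numpy.
scalar = jnp.array(5.0)
vector = jnp.arange(4)
matrix = jnp.eye(3)

# jnp.shape returns a plain Python tuple of ints, one entry per dimension.
shapes = [jnp.shape(scalar), jnp.shape(vector), jnp.shape(matrix)]
print(shapes)  # [(), (4,), (3, 3)]
```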

See also

len

len(a) is equivalent to np.shape(a)[0] for N-D arrays with N>=1.
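A minimal sketch of the `len` equivalence, assuming a standard JAX install. The N>=1 condition matters: a 0-d array has an empty shape tuple, so there is no first axis to report, and `len()` raises `TypeError`.

```python
import jax.numpy as jnp

x = jnp.zeros((5, 2))
# For arrays with at least one dimension, len() gives the first axis length.
first_axis_matches = len(x) == jnp.shape(x)[0] == 5

# A 0-d array has shape (), so len() has nothing to return and raises.
s = jnp.zeros(())
try:
    len(s)
    len_raised = False
except TypeError:
    len_raised = True

print(first_axis_matches, jnp.shape(s), len_raised)  # True () True
```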

ndarray.shape

Equivalent array method.
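For JAX arrays the `.shape` attribute and the function report the same tuple. The sketch below also derives the rank and element count from that tuple and checks them against the `ndim` and `size` attributes; the manual product loop is only for illustration.

```python
import jax.numpy as jnp

x = jnp.ones((2, 3, 4))

# The attribute and the function report the same tuple.
same = x.shape == jnp.shape(x)

# Rank and element count follow directly from the shape tuple.
ndim = len(x.shape)
size = 1
for d in x.shape:
    size *= d

print(same, ndim, size)  # True 3 24
```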

Examples

>>> import numpy as np
>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 3]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
>>> a = np.array([(1, 2), (3, 4), (5, 6)],
...              dtype=[('x', 'i4'), ('y', 'i4')])
>>> np.shape(a)
(3,)
>>> a.shape
(3,)
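The examples above use NumPy and include a structured dtype, which JAX arrays do not support. As a JAX-specific sketch, shapes are static under `jax.jit`: even though the argument is a tracer inside the compiled function, `jnp.shape(x)` returns concrete Python ints that can drive reshapes. The function name `merge_trailing_axes` is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def merge_trailing_axes(x):
    # Inside jit, x is a tracer, but its shape is known at trace time,
    # so jnp.shape(x) is an ordinary tuple of Python ints.
    leading = jnp.shape(x)[0]
    # Collapse every axis after the first into one.
    return x.reshape(leading, -1)

y = merge_trailing_axes(jnp.ones((2, 3, 4)))
print(y.shape)  # (2, 12)
```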