jax.numpy.shape

jax.numpy.shape(a)

Return the shape of an array.
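A minimal usage sketch, assuming JAX is installed and imported under the conventional `jnp` alias. Passing a JAX array returns its dimensions as a tuple, and a zero-dimensional array gives an empty tuple.

```python
import jax.numpy as jnp

x = jnp.zeros((2, 3, 4))
print(jnp.shape(x))               # (2, 3, 4)
print(jnp.shape(jnp.array(5.0)))  # ()
```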

Parameters

a (array_like) – Input array.

Returns

shape – The elements of the shape tuple give the lengths of the corresponding array dimensions.

Return type

tuple of ints
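Because the result is an ordinary tuple of Python ints, it stays usable inside `jax.jit`, where array values are abstract during tracing but shapes are static. The sketch below uses a hypothetical helper, `row_means`, that divides each row sum by the column count read from `jnp.shape`. It is an illustration, not part of the API.

```python
import jax
import jax.numpy as jnp

@jax.jit
def row_means(x):
    # During tracing x is an abstract tracer, but its shape is static,
    # so n_rows and n_cols are plain Python ints rather than traced values.
    n_rows, n_cols = jnp.shape(x)
    return jnp.sum(x, axis=1) / n_cols

print(row_means(jnp.arange(6.0).reshape(2, 3)))  # [1. 4.]
```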

See also

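len() – For an array with at least one dimension, len(a) equals shape(a)[0]. (The older alen() was deprecated and later removed from NumPy.)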

ndarray.shape

Equivalent array attribute.
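As a quick check of the equivalence with the array attribute, the following sketch (again assuming `jnp` is `jax.numpy`) compares the function with `x.shape` and with `len(x)`, which reports only the leading dimension.

```python
import jax.numpy as jnp

x = jnp.ones((4, 2))
# The function and the attribute return the same tuple.
assert jnp.shape(x) == x.shape == (4, 2)
# len() reports only the size of the leading dimension.
assert len(x) == jnp.shape(x)[0] == 4
```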

Examples

>>> import numpy as np
>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 2]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
>>> a = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')])
>>> np.shape(a)
(2,)
>>> a.shape
(2,)