jax.numpy.size
- jax.numpy.size(a, axis=None)
Return the number of elements along a given axis.
- Parameters:
a (array_like) – Input data.
axis (int, optional) – Axis along which the elements are counted. If None (the default), the total number of elements is returned.
- Returns:
element_count – Number of elements along the specified axis.
- Return type:
int
See also
- shape: dimensions of array
- ndarray.shape: dimensions of array
- ndarray.size: number of elements in array
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> jnp.size(a)
6
>>> jnp.size(a, 1)
3
>>> jnp.size(a, 0)
2
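The example above can be extended to a higher-rank array and to use inside a jitted function. The following is a minimal sketch, assuming a recent JAX release; the helper name `mean_of_squares` is hypothetical and only serves as illustration. Because array shapes are static while JAX traces a function, the element count should be available as an ordinary number inside `jax.jit`.

```python
import jax
import jax.numpy as jnp

# A rank-3 array with shape (2, 3, 4).
a = jnp.arange(24).reshape(2, 3, 4)

# Total number of elements (axis=None is the default).
total = jnp.size(a)

# Number of elements along each axis in turn; this matches a.shape.
per_axis = [jnp.size(a, axis=i) for i in range(a.ndim)]


@jax.jit
def mean_of_squares(x):
    # Hypothetical helper: the shape of x is known at trace time,
    # so the element count can be used as a plain divisor here.
    n = jnp.size(x)
    return jnp.sum(x ** 2) / n


result = mean_of_squares(jnp.array([1.0, 2.0, 3.0]))

print(total)      # 24
print(per_axis)   # [2, 3, 4]
print(result)     # approximately 4.6667
```

Calling `jnp.size(a, axis=i)` for every axis reproduces `a.shape`, which is often the more direct way to obtain all dimensions at once.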