jax.numpy.size

jax.numpy.size(a, axis=None)

Return the number of elements along a given axis.

Parameters
  • a (array_like) – Input data.

  • axis (int, optional) – Axis along which the elements are counted. If None (the default), the total number of elements is returned.
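The two modes of `axis` can be checked against the array's shape. The sketch below assumes a JAX installation that provides `jax.numpy` as `jnp`. It shows that the default count is the product of all dimensions, while an integer axis (including a negative one) gives the length of that dimension, mirroring `a.shape[axis]`.

```python
import math

import jax.numpy as jnp

a = jnp.array([[1, 2, 3], [4, 5, 6]])

# axis=None (the default): the product of all dimensions.
total = jnp.size(a)
assert total == math.prod(a.shape)  # 2 * 3 == 6

# Integer axis: the length of that single dimension.
per_axis = [jnp.size(a, axis) for axis in range(a.ndim)]
assert per_axis == list(a.shape)  # [2, 3]

# Negative axes count from the end, just like a.shape[-1].
assert jnp.size(a, -1) == a.shape[-1]  # 3
```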

Returns

element_count – Number of elements along the specified axis, or the total number of elements if axis is None.

Return type

int
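Because array shapes in JAX are static during tracing, the count should be a plain Python `int` even inside a `jax.jit`-compiled function, so it can be used in ordinary Python arithmetic there. The sketch below illustrates this under that assumption. The function `mean_of_all` is a hypothetical helper written for this example, not part of the `jax.numpy` API.

```python
import jax
import jax.numpy as jnp


@jax.jit
def mean_of_all(x):
    # Hypothetical helper: x's shape is known at trace time, so
    # jnp.size(x) is a concrete Python int here, not a traced value.
    n = jnp.size(x)
    assert isinstance(n, int)
    return jnp.sum(x) / n


x = jnp.arange(6.0).reshape(2, 3)
result = mean_of_all(x)  # (0 + 1 + 2 + 3 + 4 + 5) / 6 == 2.5
```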

See also

  • shape – dimensions of array

  • ndarray.shape – dimensions of array

  • ndarray.size – number of elements in array

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> jnp.size(a)
6
>>> jnp.size(a, 1)
3
>>> jnp.size(a, 0)
2