jax.numpy.isscalar
- jax.numpy.isscalar(element)
Return True if the input is a scalar.
JAX implementation of numpy.isscalar(). JAX’s implementation differs from NumPy’s in that it considers zero-dimensional arrays to be scalars; see the Note below for more details.
- Parameters:
element (Any) – input object to check; any type is valid input.
- Returns:
True if element is a scalar value or an array-like object with zero dimensions, False otherwise.
- Return type:
bool
Note
JAX and NumPy differ in their representation of scalar values. NumPy has special scalar objects (e.g. np.int32(0)) which are distinct from zero-dimensional arrays (e.g. np.array(0)), and numpy.isscalar() returns True for the former and False for the latter.

JAX does not define special scalar objects, but rather represents scalars as zero-dimensional arrays. As such,
jax.numpy.isscalar() returns True for both scalar objects (e.g. 0.0 or np.float32(0.0)) and array-like objects with zero dimensions (e.g. jnp.array(0.0), np.array(0.0)).

One reason for the different conventions in isscalar is to maintain JIT-invariance, i.e. the property that the result of a function should not change when it is JIT-compiled. Because scalar inputs are converted to zero-dimensional JAX arrays at JIT boundaries, the semantics of numpy.isscalar() are such that the result changes under JIT:

>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)
By treating zero-dimensional arrays as scalars, jax.numpy.isscalar() avoids this issue:

>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)
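To see why JIT-invariance matters in practice, here is a minimal sketch of code that branches on isscalar. The helper `scale_if_scalar` is hypothetical (it doubles scalar inputs and sums array inputs) and exists only for illustration; the point is that the branch chosen by `jnp.isscalar` is the same with and without `jax.jit`, while the branch chosen by `np.isscalar` flips.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scale_if_scalar(x, isscalar):
    # Hypothetical helper: double scalars, sum everything else.
    # isscalar inspects only the type/shape of x, so it returns a plain
    # Python bool even while tracing, and an ordinary `if` is allowed.
    if isscalar(x):
        return x * 2
    return jnp.sum(x)


jnp_version = lambda x: scale_if_scalar(x, jnp.isscalar)
np_version = lambda x: scale_if_scalar(x, np.isscalar)

# With jnp.isscalar, eager and JIT-compiled results agree: 2.0 and 2.0.
print(float(jnp_version(1.0)), float(jax.jit(jnp_version)(1.0)))

# With np.isscalar, the traced input is a zero-dimensional array rather than
# a scalar object, so the other branch runs under JIT: 2.0 and 1.0.
print(float(np_version(1.0)), float(jax.jit(np_version)(1.0)))
```

Because the check depends only on static information (type and number of dimensions), it never needs the concrete value of the input, which is what makes it safe to use for Python-level control flow inside a jitted function.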
Examples
In JAX, both scalars and zero-dimensional array-like objects are considered scalars:
>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True
Arrays with one or more dimensions are not considered scalars:

>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False
Compare this to numpy.isscalar(), which returns True for scalar-typed objects, and False for all arrays, even those with zero dimensions:

>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False
In JAX, as in NumPy, objects which are not array-like are not considered scalars:
>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(tuple())
False
>>> jnp.isscalar(slice(10))
False