jax.dtypes.issubdtype
- jax.dtypes.issubdtype(a, b)
Returns True if the first argument is a typecode lower than or equal to the second in the type hierarchy.
This is like numpy.issubdtype(), but can handle dtype extensions such as jax.dtypes.bfloat16 and jax.dtypes.prng_key.
- Parameters:
a (DTypeLike | ExtendedDType | None)
b (DTypeLike | ExtendedDType | None)
- Return type:
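
The following is a minimal usage sketch, not taken from the JAX documentation itself. It assumes a recent JAX release in which jax.random.key() returns a typed key array; the variable names are illustrative only. It shows ordinary hierarchy checks alongside the two extension cases mentioned above, bfloat16 and PRNG key dtypes.

```python
import jax
import jax.numpy as jnp
from jax import dtypes

# Ordinary NumPy-style hierarchy checks.
float_is_floating = dtypes.issubdtype(jnp.float32, jnp.floating)
int_is_floating = dtypes.issubdtype(jnp.int32, jnp.floating)

# bfloat16 is a dtype extension; this check treats it as a floating type.
bf16_is_floating = dtypes.issubdtype(jnp.bfloat16, jnp.floating)

# Typed PRNG keys carry an extended dtype that sits under dtypes.prng_key.
# Assumption: jax.random.key is available in the installed JAX version.
key = jax.random.key(0)
key_is_prng = dtypes.issubdtype(key.dtype, dtypes.prng_key)

print(float_is_floating, int_is_floating, bf16_is_floating, key_is_prng)
```

A check such as the last one can be used to tell typed key arrays apart from ordinary numeric arrays before passing them to code that expects one or the other.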