jax.dtypes module


bfloat16

bfloat16 floating-point values: a 16-bit format with the same 8-bit exponent width as float32 and 7 explicit mantissa bits.
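The following minimal sketch assumes jax and numpy are installed. It creates a bfloat16 array, inspects the format through jnp.finfo, and upcasts the values to float32. The sample values were chosen because bfloat16 can represent them exactly.

```python
import jax.numpy as jnp
import numpy as np
from jax import dtypes

# Build a bfloat16 array; 1.0, 2.5 and 3.0 are exactly representable.
x = jnp.array([1.0, 2.5, 3.0], dtype=dtypes.bfloat16)
print(x.dtype)  # bfloat16

# Same exponent width as float32, but far fewer mantissa bits.
info = jnp.finfo(dtypes.bfloat16)
print(info.bits, info.nexp, info.nmant)  # 16 8 7

# Upcast to float32 when more precision is needed for later arithmetic.
y = x.astype(jnp.float32)
```

Trading mantissa bits for exponent range is why bfloat16 is common for training: overflow behaves like float32, although precision is lower.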

canonicalize_dtype(dtype[, allow_extended_dtype])

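Convert a dtype to its canonical form based on config.x64_enabled. When 64-bit mode is disabled (the default), 64-bit types map to their 32-bit counterparts; for example, float64 becomes float32.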
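The sketch below shows how a few NumPy dtypes canonicalize. It assumes the default configuration, where jax_enable_x64 is False; with 64-bit mode enabled, the 64-bit types would pass through unchanged.

```python
import numpy as np
from jax import dtypes

# Assumes jax_enable_x64 is False (JAX's default).
for dt in (np.float64, np.int64, np.complex128, np.float32, np.int8):
    print(np.dtype(dt), "->", dtypes.canonicalize_dtype(dt))
# float64 -> float32
# int64 -> int32
# complex128 -> complex64
# float32 -> float32
# int8 -> int8
```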

float0

Zero-size dtype used for the tangents and gradients of integer and boolean values, such as the gradient that jax.grad returns for an integer argument when allow_int=True.
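Here is one way float0 shows up in practice. The sketch differentiates a toy function with respect to both a float argument and an integer argument. The function scale is a hypothetical example and is not part of JAX. Passing allow_int=True lets jax.grad accept the integer argument, and the gradient it returns for that argument has dtype float0.

```python
import jax

def scale(x, n):
    # Hypothetical toy function: x is a float, n is an integer.
    return x * n

# Differentiate with respect to both arguments; allow_int=True permits the integer one.
dx, dn = jax.grad(scale, argnums=(0, 1), allow_int=True)(2.0, 3)
print(dx)                             # 3.0
print(dn.dtype == jax.dtypes.float0)  # True
print(jax.dtypes.float0.itemsize)     # 0
```

Since float0 values take up no storage, gradients of integer inputs can be represented without allocating meaningless zeros.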

issubdtype(a, b)

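Return True if the first argument is a dtype or scalar type lower than or equal to the second in the type hierarchy. Like numpy.issubdtype, but it also handles JAX dtype extensions such as bfloat16 and prng_key.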
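A short sketch of hierarchy checks follows. It assumes jax and numpy are installed and uses NumPy's abstract scalar classes (np.floating, np.signedinteger, np.integer) as the second argument. Notice that bfloat16 counts as a floating type, and that a type is considered a subtype of itself.

```python
import numpy as np
from jax import dtypes

print(dtypes.issubdtype(dtypes.bfloat16, np.floating))  # True
print(dtypes.issubdtype(np.int8, np.signedinteger))     # True
print(dtypes.issubdtype(np.uint8, np.signedinteger))    # False
print(dtypes.issubdtype(np.float32, np.integer))        # False
print(dtypes.issubdtype(np.float32, np.float32))        # True: equal types count
```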

prng_key()

Scalar class for PRNG key dtypes. Typed key arrays created with jax.random.key have a dtype that is a subdtype of it.
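The most common use of prng_key is as the second argument to issubdtype, to tell a typed key apart from a legacy raw key. The sketch below assumes the default PRNG implementation (threefry). Under that implementation a legacy key from jax.random.PRNGKey is a plain uint32 array of shape (2,).

```python
import jax
from jax import dtypes

# New-style typed key: its dtype is an extended PRNG key dtype.
key = jax.random.key(0)
print(dtypes.issubdtype(key.dtype, dtypes.prng_key))  # True

# Legacy raw key: an ordinary uint32 array (shape (2,) under threefry).
raw = jax.random.PRNGKey(0)
print(raw.dtype, raw.shape)                           # uint32 (2,)
print(dtypes.issubdtype(raw.dtype, dtypes.prng_key))  # False
```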

result_type(*args[, return_weak_type_flag])

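Apply JAX's type-promotion rules to the arguments and return the resulting dtype. With return_weak_type_flag=True, it also returns whether the result is weakly typed.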
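The following sketch exercises a few promotion rules. It assumes jax and numpy are installed, and its results hold whether or not 64-bit mode is enabled. Python scalars such as 2.0 are weakly typed, so a Python float does not widen a float16 operand. When return_weak_type_flag=True, the call returns a (dtype, is_weak) pair.

```python
import numpy as np
from jax import dtypes

print(dtypes.result_type(np.int8, np.uint8))     # int16
print(dtypes.result_type(np.int32, np.float32))  # float32
print(dtypes.result_type(np.float16, 2.0))       # float16: the Python float is weak

# Second element of the pair reports whether the result is weakly typed.
print(dtypes.result_type(2.0, return_weak_type_flag=True)[1])              # True
print(dtypes.result_type(np.float32, 2.0, return_weak_type_flag=True)[1])  # False
```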

scalar_type_of(x)

Return the scalar type associated with a JAX value.
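A minimal sketch, assuming jax is installed: scalar_type_of maps a value's dtype onto the matching Python builtin type (bool, int, float, or complex). JAX-specific floating types such as bfloat16 map to float.

```python
import jax.numpy as jnp
from jax import dtypes

print(dtypes.scalar_type_of(jnp.arange(3)))                    # <class 'int'>
print(dtypes.scalar_type_of(jnp.ones(2, dtype=jnp.bfloat16)))  # <class 'float'>
print(dtypes.scalar_type_of(jnp.array(True)))                  # <class 'bool'>
print(dtypes.scalar_type_of(1j))                               # <class 'complex'>
```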