jax.dtypes module

bfloat16

bfloat16 floating-point values
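As a quick sketch (illustrative only, not part of the reference), `bfloat16` can be passed anywhere JAX accepts a dtype. Inspecting it with `jnp.finfo` shows a 16-bit type with 7 explicit mantissa bits, so its machine epsilon is 2**-7.

```python
import jax.numpy as jnp
from jax import dtypes

# Create an array stored in bfloat16.
x = jnp.array([1.0, 2.5, 3.0], dtype=dtypes.bfloat16)
print(x.dtype)  # bfloat16

# bfloat16 keeps float32's 8-bit exponent but has only 7 explicit mantissa bits.
info = jnp.finfo(dtypes.bfloat16)
print(info.bits, float(info.eps))  # 16 0.0078125
```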

canonicalize_dtype(dtype[, allow_opaque_dtype])

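Convert a dtype to its canonical dtype based on config.x64_enabled. When 64-bit mode is disabled (the default), 64-bit types such as float64 and int64 map to their 32-bit counterparts.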
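A minimal sketch of canonicalization, assuming the default configuration in which 64-bit mode is off. The flag is read here through `jax.config.jax_enable_x64`; the mapping in the trailing comment only holds when that flag is `False`.

```python
import jax
import numpy as np
from jax import dtypes

# 64-bit mode is disabled unless turned on, e.g. with
# jax.config.update("jax_enable_x64", True) at program startup.
print("x64 enabled:", jax.config.jax_enable_x64)

for dt in (np.float64, np.int64, np.complex128, np.float32):
    canonical = dtypes.canonicalize_dtype(dt)
    print(np.dtype(dt).name, "->", canonical.name)
# With x64 disabled (the default):
#   float64 -> float32, int64 -> int32, complex128 -> complex64, float32 -> float32
```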

float0

Dtype of the tangents and cotangents of integer- and boolean-valued inputs under differentiation.
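A minimal sketch of where `float0` appears in practice: forward-mode differentiation with `jax.jvp` through a function that takes an integer argument. The integer primal must be paired with a tangent of dtype `float0`, built here with `np.zeros`; the function `scale` is a hypothetical example.

```python
import jax
import numpy as np
from jax import dtypes

def scale(x, n):
    # x is a float (differentiable); n is an integer (not differentiable).
    return x * n

x, n = 2.0, 3
# An integer primal must be paired with a float0 tangent; float0 arrays hold no data.
n_dot = np.zeros((), dtype=dtypes.float0)

y, y_dot = jax.jvp(scale, (x, n), (1.0, n_dot))
print(y, y_dot)  # 6.0 3.0
```

Passing an ordinary float tangent for `n` instead would make `jax.jvp` raise a `TypeError` about mismatched primal and tangent dtypes.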

issubdtype(a, b)

Return True if the first argument is a dtype or scalar type equal to or lower than the second in the type hierarchy.
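A short illustration (a sketch, using NumPy's abstract scalar types as the second argument). Unlike a plain NumPy check, JAX's `issubdtype` also places JAX's extended float types such as `bfloat16` under `np.floating`.

```python
import numpy as np
from jax import dtypes

print(dtypes.issubdtype(np.float32, np.floating))       # True
print(dtypes.issubdtype(np.int8, np.floating))          # False
print(dtypes.issubdtype(np.int8, np.signedinteger))     # True
# JAX's extended float types are recognized as floating.
print(dtypes.issubdtype(dtypes.bfloat16, np.floating))  # True
```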

result_type(*args[, return_weak_type_flag])

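Apply JAX's dtype promotion rules to the arguments and return the resulting dtype. With return_weak_type_flag=True, return a (dtype, weak_type) pair indicating whether the result is weakly typed.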
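A few promotion cases as a sketch. These show where JAX's rules differ from NumPy's: integers promote to any floating type, `bfloat16` and `float16` meet at `float32`, and Python scalars are weakly typed. The `float32` in the last comment assumes 64-bit mode is disabled.

```python
import numpy as np
from jax import dtypes

# Integers promote to any floating type, even a narrower one.
print(dtypes.result_type(np.int32, np.float16))         # float16
# bfloat16 and float16 have no common 16-bit supertype, so they meet at float32.
print(dtypes.result_type(dtypes.bfloat16, np.float16))  # float32
# Python scalars are weakly typed and defer to the other argument's dtype.
print(dtypes.result_type(np.int8, 2))                   # int8

# Ask for the weak-type flag as well as the dtype.
dt, weak = dtypes.result_type(1.0, return_weak_type_flag=True)
print(dt, weak)  # float32 True (with 64-bit mode disabled)
```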

scalar_type_of(x)

Return the Python scalar type (bool, int, float, or complex) associated with a JAX value.
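A minimal sketch showing that the result is a Python builtin type rather than a NumPy or JAX dtype, which can be handy when constructing a Python scalar of the matching kind.

```python
import jax.numpy as jnp
from jax import dtypes

print(dtypes.scalar_type_of(jnp.arange(3)))             # <class 'int'>
print(dtypes.scalar_type_of(jnp.ones(2, jnp.float16)))  # <class 'float'>
print(dtypes.scalar_type_of(jnp.array([True, False])))  # <class 'bool'>
print(dtypes.scalar_type_of(jnp.array(1 + 2j)))         # <class 'complex'>
```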