jax.lib module

The jax.lib package is a set of internal tools and types for bridging between JAX’s Python frontend and its XLA backend.

jax.lib.xla_bridge

default_backend()

Returns the platform name of the default XLA backend.
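A minimal sketch of querying the default platform. It uses the public `jax.default_backend()` and `jax.devices()` functions instead of importing `jax.lib.xla_bridge` directly, on the assumption that they report the same platform in your JAX version. The exact string depends on your installation (for example "cpu", "gpu" or "tpu").

```python
import jax

# Name of the platform JAX uses when no device is specified explicitly.
backend_name = jax.default_backend()

# Devices returned by jax.devices() belong to the default backend,
# so they report the same platform name.
first_device = jax.devices()[0]

print(f"default backend: {backend_name}")
print(f"first device: {first_device} (platform={first_device.platform})")
```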

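get_backend([platform])

Returns the XLA backend (client) for platform, or the default backend if platform is omitted.

Parameters: platform – optional name of the platform to look up, for example "cpu".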
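The sketch below requests backend objects explicitly. It assumes `get_backend` can be imported from `jax.extend.backend` in newer JAX releases and falls back to `jax.lib.xla_bridge` otherwise; check which location your version provides. The returned client exposes its platform name and a device count.

```python
import jax

try:
    # Assumption: newer JAX releases expose backend helpers under jax.extend.
    from jax.extend.backend import get_backend
except ImportError:
    # Older releases: the jax.lib.xla_bridge function documented here.
    from jax.lib.xla_bridge import get_backend

# With no argument, the default backend is returned.
default = get_backend()

# The CPU backend is normally always available, so request it by name.
cpu = get_backend("cpu")

print(default.platform, default.device_count())
print(cpu.platform, cpu.device_count())
```

Requesting "cpu" by name is useful when a GPU or TPU backend is the default but some work should stay on the host.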

get_compile_options(num_replicas, num_partitions)

Returns the compile options to use, as derived from flag values.
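A hedged sketch of building compile options for a single-device computation. Where this helper lives varies between JAX releases: the code assumes it is available from the private module `jax._src.compiler` in recent versions (private, so it may change without notice) and otherwise falls back to `jax.lib.xla_bridge`. It then reads the replica and partition counts back from the returned options.

```python
try:
    # Assumption: recent JAX versions keep this helper in a private module.
    from jax._src.compiler import get_compile_options
except ImportError:
    # Older versions: the jax.lib.xla_bridge function documented here.
    from jax.lib.xla_bridge import get_compile_options

# One replica and one partition: a single-device computation.
options = get_compile_options(num_replicas=1, num_partitions=1)

# The counts are stored on the executable build options.
build = options.executable_build_options
print(build.num_replicas, build.num_partitions)
```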

jax.lib.xla_client