jax.lib package

The jax.lib package is a set of internal tools and types for bridging between JAX’s Python frontend and its XLA backend.

jax.lib.xla_bridge

default_backend()

Returns the platform name of the default XLA backend.
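A minimal sketch, assuming a recent JAX install where this function is also re-exported at the top level as jax.default_backend. It queries the default platform and checks that it matches the platform of the devices returned by jax.devices(), which lists devices of the default backend.

```python
import jax

# Platform name of the default XLA backend, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()

# jax.devices() with no argument lists devices of the default backend,
# so their platform string should agree with default_backend().
first_device = jax.devices()[0]
print(backend, first_device.platform)
```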

device_count([backend])

Returns the total number of devices across all processes. On single-process platforms this equals local_device_count().
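A short sketch, assuming the top-level alias jax.device_count. It compares the count against the device list, which also spans all processes, and shows the optional backend argument.

```python
import jax

# Total number of devices on the default backend, across all processes.
n_total = jax.device_count()

# jax.devices() also spans all processes, so the two should agree.
n_listed = len(jax.devices())

# The backend argument restricts the count to one platform; the CPU
# backend is available in every standard JAX install.
n_cpu = jax.device_count("cpu")
print(f"{n_total} device(s) on {jax.default_backend()!r}, {n_cpu} CPU device(s)")
```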

get_backend([platform])

get_compile_options(num_replicas, num_partitions)

Returns the compile options to use, as derived from flag values.

local_device_count([backend])

Returns the number of devices addressable by this process.
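A sketch contrasting local and global counts, assuming the top-level aliases jax.local_device_count and jax.local_devices. Addressable devices are a subset of all devices, and in a single-process run the two counts coincide.

```python
import jax

n_local = jax.local_device_count()
n_total = jax.device_count()

# Devices this process can address directly (e.g. to place data on them).
local = jax.local_devices()
print(f"process {jax.process_index()}: {n_local} of {n_total} device(s) are local")
```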

process_index([backend])

Returns the integer process index of this process.
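A minimal sketch, assuming the top-level aliases jax.process_index and jax.process_count. Indices run from 0 to process_count() - 1, and a common pattern is to let only process 0 emit logs or write checkpoints. The log message is illustrative only.

```python
import jax

idx = jax.process_index()
count = jax.process_count()

# Gate side effects (logging, checkpoint writes) on a single process so
# that a multi-process job does not duplicate them.
is_main = idx == 0
if is_main:
    print(f"process {idx} of {count}: writing logs")
```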

jax.lib.xla_client

jax.lib.xla_extension

Device

A descriptor of an available device.
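A sketch of inspecting device descriptors, assuming a JAX version that exposes the class as jax.Device. Each descriptor carries an integer id, a platform string, a hardware description in device_kind, and the index of the owning process.

```python
import jax

devices = jax.devices()
for d in devices:
    # id is unique per device; process_index names the process that owns it.
    print(d.id, d.platform, d.device_kind, d.process_index)
```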

CpuDevice

GpuDevice

TpuDevice
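CpuDevice, GpuDevice and TpuDevice correspond to platforms. The sketch below selects devices by their platform attribute instead of checking isinstance against a particular subclass. This rests on an assumption: the concrete subclasses may not be exposed the same way in every jaxlib version. It then places an array on a CPU device explicitly.

```python
import jax
import jax.numpy as jnp

# The CPU backend is always available, even when a GPU or TPU is the default.
cpu_devices = jax.devices("cpu")

# Explicitly place an array on the first CPU device.
x = jax.device_put(jnp.ones(3), cpu_devices[0])
print(x.devices())
```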