jax.lib.xla_bridge.get_compile_options

jax.lib.xla_bridge.get_compile_options(num_replicas, num_partitions, device_assignment=None, use_spmd_partitioning=True)

Returns the XLA compile options to use, derived from the current flag values.

Parameters
  • num_replicas (int) – Number of replicas for which to compile.

  • num_partitions (int) – Number of partitions for which to compile.

  • device_assignment – Optional tuple of integers indicating the assignment of logical replicas to physical devices (default inherited from xla_client.CompileOptions). Must be consistent with num_replicas and num_partitions.

  • use_spmd_partitioning (bool) – Whether XLA uses SPMD partitioning (True) or MPMD partitioning (False).

Return type

CompileOptions
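The requirement that device_assignment be consistent with num_replicas and num_partitions can be made concrete with a small pure-Python sketch. It assumes the assignment is read as a grid with one row per replica and one column per partition, and that a flat tuple of integers is accepted only when num_partitions is 1. The helper normalize_device_assignment is hypothetical and not part of JAX; it does not call get_compile_options and only illustrates the shape check.

```python
from typing import Optional, Sequence, Tuple, Union

# Hypothetical helper, not part of JAX: it only illustrates the shape
# constraint that device_assignment is assumed to satisfy.
DeviceIds = Union[Sequence[int], Sequence[Sequence[int]]]


def normalize_device_assignment(
    num_replicas: int,
    num_partitions: int,
    device_assignment: Optional[DeviceIds] = None,
) -> Optional[Tuple[Tuple[int, ...], ...]]:
    """Return device_assignment as a num_replicas x num_partitions grid.

    Assumption: row r lists the physical device IDs for replica r,
    with one entry per partition.
    """
    if device_assignment is None:
        # Nothing to check; the default comes from xla_client.CompileOptions.
        return None

    rows = list(device_assignment)
    # Assumption: a flat sequence of ints is only valid with a single
    # partition, in which case it is treated as a single column.
    if rows and all(isinstance(d, int) for d in rows):
        if num_partitions != 1:
            raise ValueError("flat device_assignment requires num_partitions == 1")
        rows = [[d] for d in rows]

    grid = tuple(tuple(row) for row in rows)
    # One row per replica.
    if len(grid) != num_replicas:
        raise ValueError(
            f"device_assignment has {len(grid)} rows, "
            f"expected num_replicas={num_replicas}"
        )
    # One column per partition in every row.
    for row in grid:
        if len(row) != num_partitions:
            raise ValueError(
                f"device_assignment row {row} has {len(row)} entries, "
                f"expected num_partitions={num_partitions}"
            )
    return grid
```

For example, normalize_device_assignment(2, 1, (0, 1)) returns ((0,), (1,)), while normalize_device_assignment(4, 1, (0, 1)) raises ValueError because two rows cannot cover four replicas.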