jax.experimental.multihost_utils.global_array_to_host_local_array


jax.experimental.multihost_utils.global_array_to_host_local_array(global_inputs, global_mesh, pspecs)

Converts a global jax.Array to a host-local jax.Array.
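To make "host-local" concrete: along each array dimension that a PartitionSpec maps to mesh axes, a host holds only the slice owned by its own devices, so that dimension shrinks by the fraction of those mesh-axis positions the host's devices occupy. The pure-Python helper below is a hypothetical illustration of that shape arithmetic (`host_local_shape` is not a JAX API). It assumes every sharded dimension divides evenly across its mesh axes, and it represents a PartitionSpec as a plain tuple of axis names.

```python
import math
from typing import Dict, Optional, Sequence, Tuple, Union

# One entry per array dimension: None (not sharded), one mesh axis name,
# or a tuple of mesh axis names.
AxisSpec = Optional[Union[str, Tuple[str, ...]]]


def host_local_shape(
    global_shape: Sequence[int],
    spec: Sequence[AxisSpec],
    global_axis_sizes: Dict[str, int],
    local_axis_sizes: Dict[str, int],
) -> Tuple[int, ...]:
    """Hypothetical helper: host-local shape implied by a PartitionSpec-like spec."""
    # Dimensions not mentioned in the spec are unsharded.
    padded = list(spec) + [None] * (len(global_shape) - len(spec))
    out = []
    for dim, axes in zip(global_shape, padded):
        if axes is None:
            out.append(dim)  # unsharded dimension: each host holds it whole
            continue
        names = (axes,) if isinstance(axes, str) else tuple(axes)
        n_global = math.prod(global_axis_sizes[n] for n in names)
        n_local = math.prod(local_axis_sizes[n] for n in names)
        if dim % n_global:
            raise ValueError(f"dimension {dim} not divisible by {n_global}")
        # Size of one shard, times the number of shards this host owns.
        out.append(dim // n_global * n_local)
    return tuple(out)


# 2 hosts x 4 devices: mesh axis "x" spans all 8 devices, 4 of them on each host.
print(host_local_shape((16, 4), ("x", None), {"x": 8}, {"x": 4}))
```

With the assumed 2-host layout, a globally shaped (16, 4) array sharded as P("x") corresponds to an (8, 4) host-local array on each host.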

You can use this function when transitioning to jax.Array. Using jax.Array with pjit has the same semantics as using GDA (GlobalDeviceArray) with pjit: all jax.Array inputs to pjit should be globally shaped, and the output of pjit will also be a globally shaped jax.Array.

You can use this function to convert the globally shaped jax.Array output of pjit back to host-local values, so that the transition to jax.Array can be a mechanical change. Example usage:

>>> from jax.experimental import multihost_utils
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(
...     host_local_inputs, mesh, in_pspecs)
>>>
>>> with mesh:
...     global_out = pjitted_fun(global_inputs)
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(
...     global_out, mesh, out_pspecs)

Here host_local_inputs, in_pspecs, out_pspecs and pjitted_fun are placeholders for the caller's own data, partition specs and pjit-compiled function, and the same mesh is passed to both conversion functions.
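A self-contained sketch of the same round trip that runs in a single process follows. It assumes jax.process_count() == 1, where the host-local and global shapes coincide, so it exercises the call sequence rather than a real multi-host layout. It also swaps pjit under `with mesh:` for jax.jit with an explicit out_shardings, and `double` is a hypothetical stand-in for pjitted_fun.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over every device visible to this process.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
pspec = P("x")  # shard the leading dimension over the "x" mesh axis

# Host-local data; the leading dimension divides evenly across the devices.
n = len(devices)
host_local_inputs = np.arange(2 * n * 4, dtype=np.float32).reshape(2 * n, 4)

# Host-local -> global before calling the sharded function.
global_inputs = multihost_utils.host_local_array_to_global_array(
    host_local_inputs, mesh, pspec)

# Hypothetical stand-in for pjitted_fun: jax.jit with an explicit output sharding.
double = jax.jit(lambda x: x * 2, out_shardings=NamedSharding(mesh, pspec))
global_out = double(global_inputs)

# Global -> host-local, so per-host code downstream sees local values again.
host_local_output = multihost_utils.global_array_to_host_local_array(
    global_out, mesh, pspec)
print(host_local_output.shape)
```

On several hosts the same calls apply, but each host passes only its own slice of the data and gets back only its own slice of the result.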

Parameters:
  • global_inputs (Any) – A pytree of global jax.Array objects.

  • global_mesh (Mesh) – A jax.sharding.Mesh object.

  • pspecs (Any) – A pytree of jax.sharding.PartitionSpec objects describing how the arrays in global_inputs are sharded over global_mesh.
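Because global_inputs and pspecs are both pytrees, a whole structure such as a parameter dict can be converted in one call. The sketch below assumes a single process (so local and global shapes match), and the dict keys "weights" and "bias" are hypothetical. It builds the global arrays with jax.device_put and gives pspecs the same structure as the inputs, with one PartitionSpec per leaf.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
n = len(devices)

# Hypothetical parameter pytree.
params = {
    "weights": np.ones((2 * n, 4), dtype=np.float32),
    "bias": np.zeros((4,), dtype=np.float32),
}
# Same structure as params: shard "weights" over "x", replicate "bias".
pspecs = {"weights": P("x"), "bias": P()}

# Place each leaf with its matching sharding to get global jax.Arrays.
global_inputs = jax.tree_util.tree_map(
    lambda x, spec: jax.device_put(x, NamedSharding(mesh, spec)), params, pspecs)

# Convert the whole pytree back to host-local arrays in one call.
host_local = multihost_utils.global_array_to_host_local_array(
    global_inputs, mesh, pspecs)
print({k: v.shape for k, v in host_local.items()})
```

The result keeps the dict structure of global_inputs, with each leaf replaced by its host-local jax.Array.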