jax.xla_computation

jax.xla_computation(fun, static_argnums=(), axis_env=None, in_parts=None, out_parts=None, backend=None, tuple_args=False, instantiate_const_outputs=None, return_shape=False, donate_argnums=())

Creates a function that produces its XLA computation given example args.

Parameters:
  • fun (Callable) – Function from which to form XLA computations.

  • static_argnums (int | Iterable[int]) – See the jax.jit() docstring.

  • axis_env (Sequence[tuple[AxisName, int]] | None) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap(). See the examples below.

  • in_parts – Optional, how each argument to fun should be partitioned or replicated. This is used to specify partitioned XLA computations, see sharded_jit for more info.

  • out_parts – Optional, how each output of fun should be partitioned or replicated. This is used to specify partitioned XLA computations, see sharded_jit for more info.

  • backend (str | None) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • tuple_args (bool) – Optional bool, defaults to False. If True, the resulting XLA computation will have a single tuple argument that is unpacked into the specified function arguments. If None, tupling will be enabled when there are more than 100 arguments, since some platforms have limits on argument arity.

  • instantiate_const_outputs (bool | None) – Deprecated argument, does nothing.

  • return_shape (bool) – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the XLA computation and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape, dtype, and named_shape attributes representing the corresponding types of the output leaves.

  • donate_argnums (int | Iterable[int]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example by recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.

Return type:

Callable

Returns:

A wrapped version of fun that, when applied to example arguments, returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods like as_hlo_text, as_serialized_hlo_module_proto, and as_hlo_dot_graph. If the argument return_shape is True, then the wrapped function returns a pair where the first element is the XLA Computation and the second element is a pytree representing the structure, shapes, dtypes, and named shapes of the output of fun.

Concrete example arguments are not always necessary. For those arguments not indicated by static_argnums, any object with shape and dtype attributes is acceptable (excepting namedtuples, which are treated as Python containers).

For example:

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)
>>> print(c.as_hlo_text())  
HloModule xla_computation_f.6

ENTRY xla_computation_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  cosine.3 = f32[] cosine(parameter.1)
  sine.4 = f32[] sine(cosine.3)
  ROOT tuple.5 = (f32[]) tuple(sine.4)
}
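
The other extraction methods mentioned above can be called on the same computation object. As a sketch (results unused here), the unoptimized HLO can also be obtained as a Graphviz dot string or as a serialized HloModule proto:

>>> dot_graph = c.as_hlo_dot_graph()
>>> module_proto = c.as_serialized_hlo_module_proto()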

Alternatively, the assignment to c above could be written:

>>> import types
>>> import numpy as np
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> c = jax.xla_computation(f)(scalar)
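
Setting return_shape=True additionally returns the shape/dtype structure of the output of f. A minimal sketch, assuming the default 32-bit precision (the exact repr may differ):

>>> c, out_shape = jax.xla_computation(f, return_shape=True)(3.)
>>> (out_shape.shape, out_shape.dtype)
((), dtype('float32'))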

Here’s an example that involves a parallel collective and axis name:

>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
>>> print(c.as_hlo_text())  
HloModule jaxpr_computation.9
primitive_computation.3 {
  parameter.4 = s32[] parameter(0)
  parameter.5 = s32[] parameter(1)
  ROOT add.6 = s32[] add(parameter.4, parameter.5)
}
ENTRY jaxpr_computation.9 {
  tuple.1 = () tuple()
  parameter.2 = s32[] parameter(0)
  all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
  ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}

Notice the replica_groups that were generated: with axis_env=[('i', 4)] there are four replicas, and the single group {0,1,2,3} reduces over all of them. Here’s an example that generates more interesting replica_groups:

>>> from jax import lax
>>> def g(x):
...   rowsum = lax.psum(x, 'i')
...   colsum = lax.psum(x, 'j')
...   allsum = lax.psum(x, ('i', 'j'))
...   return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.)
>>> print(c.as_hlo_text())  
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
ENTRY jaxpr_computation__1.19 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
  all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
  all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
  ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}
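
With axis_env=[('i', 4), ('j', 2)] there are eight replicas in total, and the groups above show that replica r carries coordinates (r // 2, r % 2) along ('i', 'j'): the psum over 'i' reduces across replicas that share a 'j' coordinate, the psum over 'j' reduces across replicas that share an 'i' coordinate, and the psum over ('i', 'j') reduces across all eight.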