jax.xla_computation
- jax.xla_computation(fun, static_argnums=(), axis_env=None, in_parts=None, out_parts=None, backend=None, tuple_args=False, instantiate_const_outputs=None, return_shape=False, donate_argnums=())
Creates a function that produces its XLA computation given example args.
- Parameters
  - fun (Callable) – Function from which to form XLA computations.
  - static_argnums (Union[int, Iterable[int]]) – See the jax.jit() docstring.
  - axis_env (Optional[Sequence[Tuple[Hashable, int]]]) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap(). See the examples below.
  - in_parts – Optional, how each argument to fun should be partitioned or replicated. This is used to specify partitioned XLA computations; see sharded_jit for more info.
  - out_parts – Optional, how each output of fun should be partitioned or replicated. This is used to specify partitioned XLA computations; see sharded_jit for more info.
  - backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
  - tuple_args (bool) – Optional bool, defaults to False. If True, the resulting XLA computation will have a single tuple argument that is unpacked into the specified function arguments. If None, tupling will be enabled when there are more than 100 arguments, since some platforms have limits on argument arity. See the sketch after this parameter list.
  - instantiate_const_outputs (Optional[bool]) – Deprecated argument, does nothing.
  - return_shape (bool) – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the XLA computation and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape, dtype, and named_shape attributes representing the corresponding types of the output leaves.
  - donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.
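As an illustration of tuple_args, here is a minimal sketch (not part of the original docstring; the helper function add is made up for illustration). With tuple_args=True the lowered computation receives a single tuple parameter that XLA unpacks into the individual arguments:

>>> import jax
>>> def add(x, y): return x + y
>>> c = jax.xla_computation(add, tuple_args=True)(1., 2.)
>>> # The entry computation now declares one tuple parameter holding both
>>> # inputs rather than two separate parameters; inspect it with
>>> # print(c.as_hlo_text()).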
- Return type
  Callable
- Returns
  A wrapped version of fun that when applied to example arguments returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods like as_hlo_text, as_serialized_hlo_module_proto, and as_hlo_dot_graph. If the argument return_shape is True, then the wrapped function returns a pair where the first element is the XLA Computation and the second element is a pytree representing the structure, shapes, dtypes, and named shapes of the output of fun.

  Concrete example arguments are not always necessary. For those arguments not indicated by static_argnums, any object with shape and dtype attributes is acceptable (excepting namedtuples, which are treated as Python containers).
For example:
>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)
>>> print(c.as_hlo_text())
HloModule xla_computation_f.6

ENTRY xla_computation_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  cosine.3 = f32[] cosine(parameter.1)
  sine.4 = f32[] sine(cosine.3)
  ROOT tuple.5 = (f32[]) tuple(sine.4)
}
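When return_shape=True, the wrapped function also reports the output structure. A minimal sketch reusing f from above (the variable names comp and out_shape are ours; the printed dtype assumes the default float32 configuration):

>>> comp, out_shape = jax.xla_computation(f, return_shape=True)(3.)
>>> out_shape.shape, out_shape.dtype
((), dtype('float32'))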
Alternatively, the assignment to c above could be written:

>>> import types
>>> import numpy as np
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> c = jax.xla_computation(f)(scalar)
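jax.ShapeDtypeStruct is another object carrying shape and dtype attributes, so it can stand in for a concrete argument in the same way (a sketch, not part of the original docstring):

>>> scalar = jax.ShapeDtypeStruct((), np.dtype(np.float32))
>>> c = jax.xla_computation(f)(scalar)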
Here’s an example that involves a parallel collective and axis name:
>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
>>> print(c.as_hlo_text())
HloModule jaxpr_computation.9

primitive_computation.3 {
  parameter.4 = s32[] parameter(0)
  parameter.5 = s32[] parameter(1)
  ROOT add.6 = s32[] add(parameter.4, parameter.5)
}

ENTRY jaxpr_computation.9 {
  tuple.1 = () tuple()
  parameter.2 = s32[] parameter(0)
  all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
  ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}
Notice the replica_groups that were generated. Here’s an example that generates more interesting replica_groups:

>>> from jax import lax
>>> def g(x):
...   rowsum = lax.psum(x, 'i')
...   colsum = lax.psum(x, 'j')
...   allsum = lax.psum(x, ('i', 'j'))
...   return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.)
>>> print(c.as_hlo_text())
HloModule jaxpr_computation__1.19

[removed uninteresting text here]

ENTRY jaxpr_computation__1.19 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
  all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
  all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
  ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}
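Beyond as_hlo_text, the returned computation exposes the other extraction methods named in the Returns section above; a brief sketch (variable names are ours, outputs not shown):

>>> dot_source = c.as_hlo_dot_graph()                 # Graphviz DOT text for the HLO graph
>>> proto_bytes = c.as_serialized_hlo_module_proto()  # serialized HloModuleProto bytes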