jax.lax.with_sharding_constraint
- jax.lax.with_sharding_constraint(x, shardings)
Mechanism to constrain the sharding of an Array inside a jitted computation.
This is a strict constraint for the GSPMD partitioner, not a hint. For examples of how to use this function, see Distributed arrays and automatic parallelization.
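A minimal sketch of constraining an intermediate value inside `jax.jit`, assuming a recent JAX release that provides `jax.sharding.Mesh`, `NamedSharding`, and `PartitionSpec`. The mesh axis name `"x"` and the function `f` are illustrative only. The mesh spans whatever `jax.devices()` returns, so the sketch also runs on a single CPU, where the constraint is trivially satisfied.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all available devices (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Split the leading array dimension across the illustrative "x" mesh axis.
row_sharding = NamedSharding(mesh, P("x"))

@jax.jit
def f(a):
    b = jnp.sin(a) * 2.0
    # Strictly require the intermediate `b` to use row_sharding.
    b = jax.lax.with_sharding_constraint(b, row_sharding)
    return b + 1.0

# Leading dimension is a multiple of the device count so it divides evenly.
a = jnp.arange(8 * devices.size * 4, dtype=jnp.float32).reshape(8 * devices.size, 4)
out = f(a)
```

Because the constraint is strict, the partitioner must give `b` the requested layout, inserting data movement if necessary; the shardings of other intermediates are still chosen by the partitioner.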
- Parameters:
x – PyTree of jax.Arrays which will have their shardings constrained
shardings – PyTree of sharding specifications. Valid values are the same as for the in_shardings argument of jax.experimental.pjit().
- Returns:
x_with_shardings – PyTree of jax.Arrays with the specified sharding constraints.
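Since both `x` and `shardings` are PyTrees, several arrays can be constrained in one call. The sketch below assumes a `shardings` tree with exactly the same structure as `x` (one sharding per leaf); the dictionary keys `"w"` and `"b"`, the function `update`, and the `0.9` scaling are hypothetical placeholders, not part of the API.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

@jax.jit
def update(params):
    # Hypothetical update step; only the PyTree structure matters here.
    new = jax.tree_util.tree_map(lambda p: p * 0.9, params)
    # `shardings` mirrors the structure of `new`: one sharding per leaf.
    shardings = {
        "w": NamedSharding(mesh, P("x", None)),  # rows split over "x"
        "b": NamedSharding(mesh, P()),           # fully replicated
    }
    return jax.lax.with_sharding_constraint(new, shardings)

params = {"w": jnp.ones((4 * n, 3)), "b": jnp.zeros((3,))}
out = update(params)
```

The returned value has the same PyTree structure as the input, with each leaf carrying its requested sharding constraint.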