jax.lax.with_sharding_constraint


jax.lax.with_sharding_constraint(x, shardings)

Mechanism to constrain the sharding of an Array inside a jitted computation.

This is a strict constraint for the GSPMD partitioner and not a hint. For examples of how to use this function, see Distributed arrays and automatic parallelization.
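A minimal sketch of the basic pattern follows. It assumes a 1-D device mesh with a single axis named "x" and uses jax.sharding.Mesh, NamedSharding and PartitionSpec to build the constraint. The names row_sharding and scale_and_constrain are hypothetical and chosen only for illustration. Inside the jitted function, the intermediate b is required to have its rows split across the "x" axis. The leading dimension is a multiple of the device count, so the sketch also runs on a single CPU device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all available devices, with one axis named "x".
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
n = len(jax.devices())

# Hypothetical sharding: split dimension 0 over "x", replicate dimension 1.
row_sharding = NamedSharding(mesh, P("x", None))

@jax.jit
def scale_and_constrain(a):
    b = a * 2.0
    # Strict constraint: the partitioner must lay out `b` as `row_sharding`.
    return jax.lax.with_sharding_constraint(b, row_sharding)

# The leading dimension is a multiple of the device count, so it divides evenly.
a = jnp.arange(8.0 * n).reshape(4 * n, 2)
out = scale_and_constrain(a)
print(out.sharding)
```

The constraint affects only the layout of the array. The computed values are the same as they would be without it.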

Parameters:
  • x – PyTree of jax.Arrays whose shardings will be constrained

  • shardings – PyTree of sharding specifications. Valid values are the same as for the in_shardings argument of jax.experimental.pjit().
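Because both x and shardings are PyTrees, a single call can constrain several arrays at once. The sketch below passes a dict of arrays together with a dict of shardings that has the same keys. One leaf is row-sharded and the other is fully replicated with an empty PartitionSpec. The mesh axis name "x" and the names params, shardings and update are hypothetical assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all available devices, with one axis named "x".
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
n = len(jax.devices())

# Hypothetical PyTree of arrays and a PyTree of shardings with the same structure.
params = {
    "w": jnp.ones((4 * n, 3)),
    "b": jnp.zeros((3,)),
}
shardings = {
    "w": NamedSharding(mesh, P("x", None)),  # rows split over "x"
    "b": NamedSharding(mesh, P()),           # fully replicated
}

@jax.jit
def update(p):
    p = jax.tree_util.tree_map(lambda t: t + 1.0, p)
    # One call constrains every leaf. The result keeps the structure of `p`.
    return jax.lax.with_sharding_constraint(p, shardings)

new_params = update(params)
```

The return value is a PyTree with the same structure as x, which matches the Returns entry below.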

Returns:

PyTree of jax.Arrays with specified sharding constraints.

Return type:

x_with_shardings