jax.device_put

jax.device_put(x, device=None, *, src=None)

Transfers x to device.

Parameters:
  • x – An array, scalar, or (nested) standard Python container thereof.

  • device (None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind) – The (optional) Device or Sharding, or a (nested) standard Python container of Shardings that is a tree prefix of x, representing the device(s) to which x should be transferred. If given, the result is committed to the device(s).
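A minimal sketch of the accepted forms of device. It assumes a recent JAX release in which jax.Array exposes .devices() and jax.sharding provides Mesh, NamedSharding, PartitionSpec and SingleDeviceSharding. The mesh axis name "data" and the dictionary keys are arbitrary illustrations, not names the API requires. The array length is a multiple of the device count, so the example runs on any number of devices, including a single CPU.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec, SingleDeviceSharding

devices = jax.devices()
x = np.arange(8 * len(devices), dtype=np.float32)

# 1. A single Device: the result is committed to that device.
y = jax.device_put(x, devices[0])

# 2. A Sharding: split x along its only axis across every available device.
#    "data" is an illustrative mesh axis name.
mesh = Mesh(np.array(devices), ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))
z = jax.device_put(x, sharding)

# 3. A container of Shardings that is a tree prefix of the input: the single
#    sharding under "params" applies to every leaf beneath that key.
tree = {"params": {"w": x, "b": x[:4]}, "step": np.int32(0)}
placement = {
    "params": SingleDeviceSharding(devices[0]),
    "step": SingleDeviceSharding(devices[0]),
}
placed = jax.device_put(tree, placement)
```

The third form lets one sharding cover a whole subtree instead of repeating it for every leaf; the container only has to match x down to the level where each sharding is given.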

Returns:

A copy of x that resides on device.

If device is None, this operation behaves like the identity function when x already resides on a device; otherwise, it transfers the data to the default device, and the result is uncommitted.
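A short sketch of the device=None behavior, assuming the default JAX configuration (no jax_default_device override), in which the default device is jax.devices()[0]. The example only checks values and placement; it does not inspect whether the result is committed.

```python
import numpy as np
import jax

host = np.arange(3, dtype=np.int32)

# A NumPy array is not on any JAX device yet, so device_put copies it to the
# default device; the resulting jax.Array is uncommitted.
a = jax.device_put(host)

# A jax.Array already lives on a device; with device=None it is passed
# through with the same values and the same placement.
b = jax.device_put(a)
```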

For more details, see the FAQ on data placement.

This function is always asynchronous: it returns immediately rather than blocking the calling Python thread until any transfers complete.
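Because the call returns before the copy is guaranteed to be finished, code that needs the data to be resident (for example, when benchmarking) can wait on the result explicitly. A minimal sketch using jax.Array.block_until_ready(), which blocks and then returns the array itself; the array size is arbitrary, and on a CPU backend the two timings may be nearly identical.

```python
import time

import numpy as np
import jax

x = np.ones((1024, 1024), dtype=np.float32)

start = time.perf_counter()
# device_put hands back a jax.Array without waiting for the copy to finish.
y = jax.device_put(x, jax.devices()[0])
returned_after = time.perf_counter() - start

# Wait until the transfer has actually completed before measuring further.
y = y.block_until_ready()
finished_after = time.perf_counter() - start

print(f"returned after {returned_after:.6f}s, ready after {finished_after:.6f}s")
```

For a pytree of arrays, jax.block_until_ready(tree) waits on every leaf in the same way.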

Parameters:

  • src (None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind)