jax.dlpack.from_dlpack



jax.dlpack.from_dlpack(external_array, device=None, copy=None)

Returns a jax.Array representation of a DLPack tensor.

The returned Array shares memory with external_array if no device transfer or copy was requested.
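A minimal sketch of the default call, assuming NumPy 1.22 or newer (whose arrays implement __dlpack__ and __dlpack_device__). It checks only values, dtype, and shape rather than inspecting memory addresses.

```python
import numpy as np
import jax
import jax.dlpack
import jax.numpy as jnp

# A NumPy array is a DLPack producer: it has __dlpack__ and __dlpack_device__.
host = np.linspace(0.0, 1.0, 5, dtype=np.float32)

# No device and no copy requested: the result stays on the CPU device
# the NumPy buffer came from.
arr = jax.dlpack.from_dlpack(host)

print(type(arr), arr.dtype, arr.shape)
print(jnp.sum(arr))
```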

Parameters:
  • external_array – An array object that has __dlpack__ and __dlpack_device__ methods, or a DLPack tensor on either CPU or GPU (legacy API).

  • device (Device | Sharding | None) – An optional Device (or Sharding) specifying where the returned array should be placed. If given, the result is committed to that device. If unspecified, the array is unpacked onto the same device it originated from. Placing it on a device other than the source of external_array requires a copy, so copy must be True or None.

  • copy (bool | None) – An optional boolean controlling whether a copy is performed. If copy=True, a copy is always performed, even when unpacking onto the same device. If copy=False, a copy is never performed, and an error is raised if one would be needed. If copy=None, a copy may be performed if needed for a device transfer.
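Any object can be passed as external_array if it implements the two protocol methods named above. The sketch below uses a hypothetical HostBuffer class (not part of JAX or NumPy) that simply delegates both methods to a NumPy array it owns; a real producer would export its own memory instead.

```python
import numpy as np
import jax
import jax.dlpack


class HostBuffer:
    """Hypothetical DLPack producer that delegates to a NumPy array it owns."""

    def __init__(self, data):
        self._data = np.ascontiguousarray(data)

    def __dlpack__(self, **kwargs):
        # Forward whatever keyword arguments the consumer passes
        # (e.g. stream) to NumPy's own implementation.
        return self._data.__dlpack__(**kwargs)

    def __dlpack_device__(self):
        # Report the same (device_type, device_id) pair as NumPy, i.e. the CPU.
        return self._data.__dlpack_device__()


buf = HostBuffer(np.arange(6, dtype=np.int32).reshape(2, 3))
arr = jax.dlpack.from_dlpack(buf)
print(arr)
```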
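The device and copy parameters can be combined. This sketch assumes a JAX version whose signature matches the one shown above and a CPU backend reachable through jax.devices("cpu"). It commits the result to an explicit device and forces a copy, so later in-place writes to the NumPy source do not reach the JAX array.

```python
import numpy as np
import jax
import jax.dlpack

host = np.zeros(4, dtype=np.float32)

# Assumption: a CPU backend is available; take its first device.
cpu0 = jax.devices("cpu")[0]

# device=cpu0 commits the result to cpu0. The source already lives on the
# CPU, so no transfer is needed, but copy=True still forces a fresh buffer.
snapshot = jax.dlpack.from_dlpack(host, device=cpu0, copy=True)
snapshot.block_until_ready()  # wait until the copy has completed

# Because a copy was made, mutating the NumPy source afterwards
# leaves the JAX array untouched.
host[0] = 42.0
print(snapshot.devices(), host[0], snapshot[0])
```

Per the parameter descriptions, copy=False combined with a device different from the source should raise an error instead of silently copying.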

Returns:

A jax.Array

Note

While JAX arrays are always immutable, DLPack buffers cannot be marked as immutable, and code outside JAX may mutate them in place. If a JAX array is constructed from a DLPack buffer and the buffer is later modified in place, using the associated JAX array may lead to undefined behavior.
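On the JAX side, immutability means a JAX array rejects item assignment, and functional updates such as .at[...].set(...) return a new array. The sketch imports with copy=True as one assumed way to sidestep the aliasing hazard described in the note. This choice is an illustration, not a requirement stated here.

```python
import numpy as np
import jax
import jax.dlpack

host = np.arange(3, dtype=np.float32)

# copy=True: the JAX array owns its buffer, so in-place writes to `host`
# cannot affect it.
arr = jax.dlpack.from_dlpack(host, copy=True)

# JAX arrays do not support in-place item assignment.
rejected = False
try:
    arr[0] = 10.0
except TypeError:
    rejected = True

# Functional update: returns a new array and leaves `arr` unchanged.
updated = arr.at[0].set(10.0)
print(rejected, arr, updated)
```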