jax.dlpack.to_dlpack



jax.dlpack.to_dlpack(x, stream=None, src_device=None, dl_device=None, max_version=None, copy=None)

Returns a DLPack tensor that encapsulates an Array x.

Parameters:
  • x (Array) – an Array, on either CPU or GPU.

  • stream (int | Any | None) – optional platform-dependent stream to wait on until the buffer is ready. This corresponds to the stream argument to __dlpack__ documented in https://dmlc.github.io/dlpack/latest/python_spec.html.

  • src_device (xla_client.Device | None) – either a CPU or GPU Device.

  • dl_device (tuple[DLDeviceType, int] | None) – a tuple of (dl_device_type, local_hardware_id) in DLPack format e.g. as produced by __dlpack_device__.

  • max_version (tuple[int, int] | None) – the maximum DLPack version that the consumer (i.e. caller of __dlpack__) supports in the form of a 2-tuple of (major, minor). This function is not guaranteed to return a capsule of version max_version.

  • copy (bool | None) – whether to copy the input. If copy=True, the function must always copy. If copy=False, it must never copy and must raise an error when a copy would be necessary. If copy=None, it must avoid a copy if possible but may copy if needed.
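The dl_device tuple uses the same (device_type, device_id) layout that an array's __dlpack_device__ method reports. The sketch below places an array on the first CPU device (so the result does not depend on which accelerators are installed) and prints that tuple. The numeric device-type codes in the comments come from the DLPack C header, not from this page.

```python
import jax
import jax.numpy as jnp

# Pin the array to the first CPU device so the output is the same on
# machines with or without a GPU.
cpu = jax.devices("cpu")[0]
x = jax.device_put(jnp.ones((2, 2)), cpu)

# A (device_type, device_id) pair in DLPack format: the same kind of
# tuple that the dl_device parameter accepts.
dl_device = x.__dlpack_device__()
device_type, device_id = dl_device

# In the DLPack C header, kDLCPU == 1 and kDLCUDA == 2.
print(int(device_type), device_id)
```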

Returns:

A DLPack PyCapsule object.
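A minimal sketch of obtaining the capsule and checking which DLPack flavor was produced. It reads the capsule name through CPython's PyCapsule_GetName via ctypes. This is a diagnostic trick, not something JAX provides. Because jax.dlpack.to_dlpack may be deprecated or absent in newer JAX releases, the code falls back to the array's own __dlpack__ method (the protocol described in the linked DLPack Python specification) when to_dlpack is unavailable.

```python
import ctypes
import warnings

import jax
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

# Tolerate JAX versions where to_dlpack is deprecated or removed: in that
# case fall back to calling the DLPack protocol method on the array.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    to_dlpack = getattr(jax.dlpack, "to_dlpack", None)
    capsule = to_dlpack(x) if to_dlpack is not None else x.__dlpack__()

# Read the capsule name with the CPython C API. The DLPack Python spec names
# unversioned capsules "dltensor" and DLPack 1.0 capsules "dltensor_versioned".
get_name = ctypes.pythonapi.PyCapsule_GetName
get_name.restype = ctypes.c_char_p
get_name.argtypes = [ctypes.py_object]
name = get_name(capsule)

print(type(capsule).__name__, name)
```

In practice the capsule is normally handed straight to another framework's DLPack importer, which consumes it. Per the DLPack Python specification, a consumer renames a consumed capsule (for example, to "used_dltensor") so that it cannot be imported twice.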

Note

While JAX arrays are always immutable, DLManagedTensor buffers cannot be marked as immutable, so code outside JAX can mutate them in place. Mutating a DLPack buffer derived from a JAX array may lead to undefined behavior when the associated JAX array is used afterward. Once JAX supports DLManagedTensorVersioned (DLPack 1.0), it will be possible to mark a buffer as read-only.
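If a consumer needs to write to the data, one defensive pattern is to give it an explicit copy instead of a DLPack view of JAX's buffer. The sketch below assumes the consumer is NumPy and that the data fits in host memory. Writes then land in memory that JAX does not own, so the original array is unaffected.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.float32)

# np.array(..., copy=True) always allocates a fresh, writable host buffer,
# so writes to `owned` can never reach the memory backing `x`.
owned = np.array(x, copy=True)
owned[0] = 100.0

print(x[0], owned[0])
```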