jax.Device


class jax.Device

A descriptor of an available device.

Subclasses represent specific types of devices, e.g. CPUs and GPUs, and may have additional properties specific to that device type.
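As a minimal sketch, the devices visible to a program can be enumerated with `jax.devices()`, and each returned object is a `jax.Device` exposing the attributes listed below. The helper name `describe_devices` is hypothetical and used only for illustration.

```python
import jax


def describe_devices():
    """Hypothetical helper: summarize every device visible to this program."""
    summaries = []
    for d in jax.devices():  # each element is a jax.Device (or a subclass)
        summaries.append(
            (
                d.id,             # integer ID of the device
                d.platform,       # platform name string, e.g. "cpu"
                d.device_kind,    # human-readable device description string
                d.process_index,  # index of the process that owns the device
            )
        )
    return summaries


for summary in describe_devices():
    print(summary)
```

On a CPU-only machine this typically prints a single entry whose platform is `"cpu"`; the exact `device_kind` string depends on the backend and hardware.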

__init__(*args, **kwargs)

Methods

__init__(*args, **kwargs)

addressable_memories(self)

default_memory(self)

get_stream_for_external_ready_events(self)

live_buffers(self)

memory(self, kind)

memory_stats(self)

Returns memory statistics for this device keyed by name.

transfer_from_outfeed(self, arg0)

transfer_to_infeed(self, arg0)
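A minimal sketch of reading `memory_stats()`, assuming it returns a dict on backends that report allocator statistics (e.g. GPU or TPU) and `None` on backends that do not (e.g. CPU). The key `"bytes_in_use"` is an assumption about what such backends report, and the helper `bytes_in_use` is hypothetical.

```python
import jax


def bytes_in_use(device):
    """Hypothetical helper: bytes currently allocated on `device`, if reported."""
    stats = device.memory_stats()  # dict keyed by statistic name, or None
    if stats is None:
        # The backend does not expose allocator statistics (assumed for CPU).
        return None
    # "bytes_in_use" is an assumed key; .get() returns None if it is absent.
    return stats.get("bytes_in_use")


for d in jax.devices():
    print(d, bytes_in_use(d))
```

Returning `None` rather than raising keeps the helper usable across backends with different levels of memory reporting.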

Attributes

client

device_kind

host_id

Deprecated; use process_index instead.

id

Integer ID of this device.

local_hardware_id

Opaque hardware ID, e.g., the CUDA device number.

platform

process_index

Integer index of this device's process.

task_id

Deprecated; use process_index instead.
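Because `host_id` and `task_id` are deprecated in favor of `process_index`, a sketch of selecting the devices owned by the current process might compare each device's `process_index` with `jax.process_index()`. This should match `jax.local_devices()`. The helper name `this_process_devices` is hypothetical.

```python
import jax


def this_process_devices():
    """Hypothetical helper: devices whose owning process is the current one."""
    me = jax.process_index()  # index of the process running this code
    return [d for d in jax.devices() if d.process_index == me]


# Compare against the built-in list of process-local devices.
local_ids = sorted(d.id for d in jax.local_devices())
print(sorted(d.id for d in this_process_devices()) == local_ids)
```

In a single-process program every device belongs to process 0, so the filter keeps all devices. The comparison only becomes meaningful in multi-process setups.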