jax.Device


class jax.Device

A descriptor of an available device.

Subclasses are used to represent specific types of devices, e.g. CPUs, GPUs. Subclasses may have additional properties specific to that device type.
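As a minimal sketch of how device descriptors are usually obtained, the snippet below enumerates the devices visible to the current process with `jax.devices()` and reads a few of the attributes documented on this page. The exact values printed depend on the installed backend and hardware.

```python
import jax

# Enumerate every device of the default backend and inspect a few
# of the attributes documented on this page.
for d in jax.devices():
    print(d.id, d.platform, d.device_kind, d.process_index)

first = jax.devices()[0]
# Concrete device objects (CPU, GPU, TPU, ...) are instances of jax.Device.
print(isinstance(first, jax.Device))
```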

__init__(*args, **kwargs)

Methods

__init__(*args, **kwargs)

Attributes

addressable_memories

Returns all the memories that a device can address.
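A short sketch, assuming a JAX release that exposes memory spaces and in which `addressable_memories` is called as a method (as in recent jaxlib versions). Each returned memory object is assumed to carry a string `kind` attribute:

```python
import jax

d = jax.devices()[0]
# addressable_memories() returns the memory spaces this device can
# address; collect their kind strings for inspection.
kinds = [m.kind for m in d.addressable_memories()]
print(kinds)
```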

client

The client (backend) object that owns this device.

default_memory

Returns the default memory of a device.
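Under the same assumption that `default_memory` and `addressable_memories` are methods returning memory objects with a `kind` attribute, a quick sanity check is that the default memory is one of the memories the device can address:

```python
import jax

d = jax.devices()[0]
default_kind = d.default_memory().kind
addressable_kinds = {m.kind for m in d.addressable_memories()}
# The default memory is expected to be among the addressable memories.
print(default_kind, default_kind in addressable_kinds)
```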

device_kind

String describing the kind of device, e.g., its hardware model name.
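Because `device_kind` is a plain string, it can serve as a grouping key. The sketch below, offered only as an illustration, groups device IDs by kind, which can help confirm that a job is running on a homogeneous set of devices:

```python
import collections

import jax

# Map each device_kind string to the IDs of the devices of that kind.
by_kind = collections.defaultdict(list)
for d in jax.devices():
    by_kind[d.device_kind].append(d.id)

for kind, ids in by_kind.items():
    print(f"{kind}: {ids}")
```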

get_stream_for_external_ready_events

host_id

Deprecated; use process_index instead.

id

Integer ID of this device.
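A minimal sketch that uses the integer `id` to build a lookup table from ID to device, assuming IDs are unique among the devices returned by `jax.devices()`:

```python
import jax

# Index devices by their integer ID (assumed unique within the backend).
devices_by_id = {d.id: d for d in jax.devices()}
print(sorted(devices_by_id))
```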

live_buffers

local_hardware_id

Opaque hardware ID, e.g., the CUDA device number.
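To contrast the opaque `local_hardware_id` with the integer `id`, the sketch below prints both for each device attached to this process. It assumes `local_hardware_id` is either an integer or `None` on backends that do not report one; since the value is opaque, it should only be displayed or passed through, not interpreted.

```python
import jax

for d in jax.local_devices():
    # local_hardware_id is assumed to be an int, or None when the
    # backend does not expose a hardware-level identifier.
    print(d.id, d.local_hardware_id)
```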

memory

memory_stats

Returns memory statistics for this device keyed by name.
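A defensive sketch for reading memory statistics. It assumes `memory_stats` is called as a method that returns either a dict or `None` (some backends, such as CPU, may report nothing). The key names `bytes_in_use` and `bytes_limit` are common on accelerators but are treated here as optional.

```python
import jax

d = jax.devices()[0]
stats = d.memory_stats()  # assumed: a dict, or None if unsupported
if stats is None:
    print(f"{d.platform} does not report memory statistics")
else:
    # Key names vary by backend, so look them up without assuming presence.
    used = stats.get("bytes_in_use")
    limit = stats.get("bytes_limit")
    print(f"in use: {used}, limit: {limit}")
```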

platform

String naming the platform this device belongs to, e.g., "cpu", "gpu", or "tpu".
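The platform string makes simple placement policies easy to express. This sketch is one possible policy, not a recommended default: it prefers any non-CPU device of the default backend and falls back to the first CPU device, then places an array there with `jax.device_put`:

```python
import jax
import jax.numpy as jnp

# Prefer an accelerator if the default backend has one; otherwise use a CPU.
accelerators = [d for d in jax.devices() if d.platform != "cpu"]
target = accelerators[0] if accelerators else jax.devices("cpu")[0]

x = jax.device_put(jnp.ones(3), target)
print(target.platform, x.devices())
```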

process_index

Integer index of this device's process.
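In multi-process jobs, `process_index` identifies which process a device is attached to. As a sketch, filtering the global device list by the current process index should select the same devices that `jax.local_devices()` returns:

```python
import jax

# Keep only the devices attached to the current process.
mine = [d for d in jax.devices() if d.process_index == jax.process_index()]
print(len(mine), jax.local_device_count())
```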

task_id

Deprecated; use process_index instead.

transfer_from_outfeed

transfer_to_infeed