xma.accelerator

class Accelerator(*values) [source]

Bases: Enum

cpu = 'cpu'
cuda = 'cuda'
static get_accelerator() -> Accelerator [source]
static get_current_device() -> int | str [source]
static get_kernel_backend() -> KernelBackend [source]
static get_sm_count(device: device | None = None) -> int [source]
rocm = 'rocm'
static synchronize() -> None [source]
tpu = 'tpu'
trainium = 'trainium'
class KernelBackend(*values) [source]

Bases: Enum

cuda = 'cuda'
get_compatible_accelerator() -> Accelerator [source]
nki = 'nki'
pallas = 'pallas'
rocm = 'rocm'
torch = 'torch'
triton = 'triton'
verify_accelerator() -> bool [source]