Source code for xma.accelerator

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

from enum import Enum
from functools import lru_cache

import torch

from .utils import is_torch_neuronx_available, is_torch_xla_available


if is_torch_xla_available():
    from torch_xla.core.xla_model import wait_device_ops as xla_wait_device_ops
    from torch_xla.core.xla_model import xla_device


_IS_ROCM_AVAILABLE = torch.version.hip is not None


class KernelBackend(Enum):
    cuda = "cuda"
    nki = "nki"
    pallas = "pallas"
    rocm = "rocm"
    torch = "torch"
    triton = "triton"

    def get_compatible_accelerator(self) -> Accelerator:
        found_accelerator = Accelerator.get_accelerator()

        if self == KernelBackend.torch or (
            self == KernelBackend.triton and found_accelerator in [Accelerator.cuda, Accelerator.rocm]
        ):
            return found_accelerator

        mapping = {
            KernelBackend.cuda: Accelerator.cuda,
            KernelBackend.nki: Accelerator.trainium,
            KernelBackend.pallas: Accelerator.tpu,
            KernelBackend.rocm: Accelerator.rocm,
        }

        return mapping.get(self, None)

    def verify_accelerator(self) -> bool:
        expected_accelerator = self.get_compatible_accelerator()
        found_accelerator = Accelerator.get_accelerator()
        return expected_accelerator == found_accelerator
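
# Illustrative note (not part of the original module source): a quick sketch of how the
# backend/accelerator compatibility check above behaves. The concrete outcomes below are
# assumptions for a ROCm machine, shown for orientation only.
#
#     KernelBackend.triton.get_compatible_accelerator()  # -> Accelerator.rocm (triton runs on CUDA and ROCm)
#     KernelBackend.rocm.get_compatible_accelerator()    # -> Accelerator.rocm
#     KernelBackend.cuda.verify_accelerator()            # -> False (expects Accelerator.cuda, found rocm)
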
class Accelerator(Enum):
    cpu = "cpu"
    cuda = "cuda"
    rocm = "rocm"
    tpu = "tpu"
    trainium = "trainium"

    @staticmethod
    @lru_cache
    def get_accelerator() -> Accelerator:
        if is_torch_xla_available():
            accelerator = Accelerator.tpu
        elif is_torch_neuronx_available():
            accelerator = Accelerator.trainium
        elif torch.cuda.is_available():
            accelerator = Accelerator.rocm if _IS_ROCM_AVAILABLE else Accelerator.cuda
        else:
            accelerator = Accelerator.cpu

        return accelerator

    @staticmethod
    def get_current_device() -> int | str:
        accelerator = Accelerator.get_accelerator()

        if accelerator in [Accelerator.cuda, Accelerator.rocm]:
            device = torch.cuda.current_device()
        elif accelerator == Accelerator.tpu:
            device = xla_device()
        elif accelerator == Accelerator.trainium:
            device = torch.neuron.current_device()
        elif accelerator == Accelerator.cpu:
            device = "cpu"

        return device

    @staticmethod
    @lru_cache
    def get_kernel_backend() -> KernelBackend:
        accelerator = Accelerator.get_accelerator()

        if accelerator == Accelerator.cuda:
            kernel_backend = KernelBackend.rocm if _IS_ROCM_AVAILABLE else KernelBackend.cuda
        elif accelerator == Accelerator.tpu:
            kernel_backend = KernelBackend.pallas
        elif accelerator == Accelerator.trainium:
            kernel_backend = KernelBackend.nki
        else:
            kernel_backend = KernelBackend.triton

        return kernel_backend

    @staticmethod
    def synchronize() -> None:
        accelerator = Accelerator.get_accelerator()

        if accelerator == Accelerator.cuda:
            torch.cuda.synchronize()
        elif accelerator == Accelerator.tpu:
            xla_wait_device_ops()

    @staticmethod
    def get_sm_count(device: torch.device | None = None) -> int:
        if device is None:
            accelerator = Accelerator.get_accelerator()
        else:
            accelerator = Accelerator(device.type)

        # TODO clean this up
        if accelerator == Accelerator.cuda:
            sm_count = torch.cuda.get_device_properties(device).multi_processor_count
        elif device.type == "xpu":
            sm_count = torch.xpu.get_device_properties(device).gpu_subslice_count

        return sm_count
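
# Illustrative usage sketch (not part of the original module): how the API above might be
# exercised at runtime. The expected results in the comments are assumptions for a machine
# with a single CUDA GPU; they will differ on ROCm, TPU, Trainium, or CPU-only hosts.
#
#     from xma.accelerator import Accelerator, KernelBackend
#
#     accelerator = Accelerator.get_accelerator()      # -> Accelerator.cuda (result is cached)
#     device = Accelerator.get_current_device()        # -> 0, the active CUDA device index
#     backend = Accelerator.get_kernel_backend()       # -> KernelBackend.cuda
#     backend.verify_accelerator()                     # -> True, backend matches the detected accelerator
#     Accelerator.synchronize()                        # blocks until pending device work finishes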