Source code for xma.custom_op

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

import inspect
from typing import Any, Callable, Iterable, Sequence

import torch

from .accelerator import Accelerator, KernelBackend
from .constants import LIBRARY_NAME
from .counters import increment_counter


[docs] def ctx_needs_gradients(ctx) -> bool: return any(ctx.needs_input_grad)
[docs] def ctx_save_for_backward(ctx, *args) -> None: if ctx_needs_gradients(ctx): ctx.save_for_backward(*args)
[docs] class CustomOp(torch.autograd.Function):
[docs] @classmethod def run(cls, kernel_backend: KernelBackend | None = None, **kwargs) -> Any: if kernel_backend is None: kernel_backend = Accelerator.get_kernel_backend() else: assert kernel_backend.verify_accelerator() if kernel_backend is None: raise ValueError("code is not supposed to reach here! kernel_backend was not inferrable") increment_counter(cls._get_key(kernel_backend)) output = ( cls.forward_backward_torch(**kwargs) if kernel_backend == KernelBackend.torch else cls.apply(*tuple(kwargs.values()), kernel_backend) ) return output
[docs] @staticmethod def forward(ctx, *args, kernel_backend: KernelBackend) -> Any: raise NotImplementedError
[docs] @staticmethod def backward(ctx, *grad_outputs) -> Any: raise NotImplementedError
[docs] @staticmethod def forward_backward_torch(*args, **kwargs) -> Any: raise NotImplementedError
@classmethod def _get_key(cls, kernel_backend: KernelBackend) -> str: return f"{cls.__name__}-{kernel_backend.value}"
[docs] def xma_op( mutates_args: str | Iterable[str] = None, device_types: str | Sequence[str] | None = None, schema: str | None = None, fake_func: Callable | None = None, ) -> Callable: def _inner(func: Callable): custom_op = torch.library.custom_op( f"{LIBRARY_NAME}::{func.__name__}", func, mutates_args=mutates_args, device_types=device_types, schema=schema, ) if fake_func is not None: custom_op.register_fake(fake_func) def _run(*args, **kwargs): return custom_op(*args, **kwargs) _run.__signature__ = inspect.signature(func) _run.__name__ = func.__name__ return _run return _inner