Source code for xma.xtuner.tuner

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

import inspect
from collections import defaultdict
from typing import Any, Callable

import torch

from ..accelerator import Accelerator
from ..utils import get_boolean_env_variable
from .cache import get_xtune_cache
from .config import XTuneConfig
from .parameter import XTuneParameter


_XTUNE_PRINT_AUTOTUNING = get_boolean_env_variable("XTUNE_PRINT_AUTOTUNING", False)
_SEPARATOR = "."
_DEFAULT_WARMUP_ITERATIONS = 5
_BENCHMARK_ITERATIONS = 10


class XTunedFunction:
    def __init__(
        self,
        function: Callable,
        configs: list[XTuneConfig],
        triggers: set[str],
        warmup_iterations: int,
        benchmark_iterations: int,
        functional_triggers: dict[str, Callable] = {},
        reset_to_zero: dict = {},
    ) -> XTunedFunction:
        assert len(configs) > 0, "no xtune config is passed"

        self.function = function
        self.configs = configs
        self.warmup_iterations = warmup_iterations
        self.benchmark_iterations = benchmark_iterations
        self.signature = inspect.getfullargspec(function)
        self.xtuneable_parameters = set(self.configs[0].get_key_values().keys())

        self._setup_trigger_map(triggers)

        for config in self.configs:
            assert (
                set(config.get_key_values().keys()) == self.xtuneable_parameters
            ), "xtune configs don't match the expected function signature"

        self.functional_triggers = functional_triggers
        self.reset_to_zero = reset_to_zero

        self.filename = inspect.stack()[2].filename.rsplit("xma", 1)[1][1:]
        self.function_hash = f"{self.filename}->{function.__name__}"
        self.function_cache = {}

    def __call__(self, *args, **kwargs) -> Any:
        override_xtune_parameters = self._can_override_variables(*args, **kwargs)
        lookup_key = self._get_lookup_key(*args, **kwargs)

        best_config = self.function_cache.get(lookup_key, None)

        if best_config is None:
            # bypass xtune for single config
            if len(self.configs) == 1:
                best_config = self.configs[0]
                best_time = 0
            else:
                best_config, best_time, _ = self._xtune(*args, **kwargs)

            self.function_cache[lookup_key] = best_config
            get_xtune_cache().add_config(function_hash=self.function_hash, lookup_key=lookup_key, config=best_config)

            if _XTUNE_PRINT_AUTOTUNING and (
                not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
            ):
                print(
                    f"config {best_config} achieved the best time ({best_time} sec) for {lookup_key} for "
                    f"function {self.function.__name__}"
                )

        return self.function(
            **self._get_function_arguments(
                config=best_config, args=args, kwargs=kwargs, override_allowed=override_xtune_parameters
            )
        )

    def _can_override_variables(self, *args, **kwargs) -> bool:
        num_xtune_parameters_found = 0
        num_specified_parameters_found = 0

        for i in range(len(args)):
            variable_name = self.signature.args[i]
            is_tuneable_variable = variable_name in self.xtuneable_parameters

            if isinstance(args[i], XTuneParameter):
                assert is_tuneable_variable, "argument with XTuneParameter() value should be a tuned parameter"
                num_xtune_parameters_found += 1
            elif is_tuneable_variable:
                num_specified_parameters_found += 1

        # accessing kwargs.items() breaks torch.compile in backwards of a custom autograd function
        for variable_name in kwargs:
            is_tuneable_variable = variable_name in self.xtuneable_parameters

            if isinstance(kwargs.get(variable_name), XTuneParameter):
                assert is_tuneable_variable, "argument with XTuneParameter() value should be a tuned parameter"
                num_xtune_parameters_found += 1
            elif is_tuneable_variable:
                num_specified_parameters_found += 1

        n = len(self.xtuneable_parameters)

        if num_xtune_parameters_found == 0:
            assert num_specified_parameters_found in [0, n]
            return num_specified_parameters_found == n

        assert (
            num_specified_parameters_found == 0
        ), "if one tuneable parameter is specified, all others must be specified"
        assert (
            num_xtune_parameters_found == n
        ), "all tuneable parameters should be set to XTuneParameter() if even one is set to XTuneParameter()"

        return False

    def _get_function_arguments(self, config: XTuneConfig, args: list, kwargs: dict, override_allowed: bool) -> dict:
        # copy the best_config first so we can override with args or kwargs
        result = {variable_name: value for variable_name, value in config.get_key_values().items()}

        for i in range(len(args)):
            variable_name = self.signature.args[i]
            if override_allowed or variable_name not in result:
                result[variable_name] = args[i]

        # accessing kwargs.items() breaks torch.compile in backwards of a custom autograd function
        for variable_name in kwargs:
            if override_allowed or variable_name not in result:
                result[variable_name] = kwargs.get(variable_name)

        return result

    @torch.compiler.set_stance("force_eager")
    @torch.inference_mode()
    def _xtune(self, *args, **kwargs) -> tuple[XTuneConfig, float, list[tuple[XTuneConfig, float]]]:
        best_config = None
        best_time = float("inf")
        timed_configs = []

        for config in self.configs:
            if not config.is_condition_valid(
                **self._get_function_arguments(
                    config=XTuneConfig({}), args=args, kwargs=kwargs, override_allowed=False
                )
            ):
                if _XTUNE_PRINT_AUTOTUNING:
                    print(f"Skipping config {config} for function {self.function.__name__}")

                continue

            if _XTUNE_PRINT_AUTOTUNING:
                print(f"Autotuning function {self.function.__name__} with config {config}")

            elapsed_time = self._run_benchmark(
                **self._get_function_arguments(config=config, args=args, kwargs=kwargs, override_allowed=False),
            )

            timed_configs.append((config, elapsed_time))

            if elapsed_time < best_time:
                best_config = config
                best_time = elapsed_time

        assert best_config is not None, "no best_config found, check that at least 1 xtune config is valid"

        return best_config, best_time, timed_configs

    def _get_lookup_key(self, *args, **kwargs) -> Any:
        lookup_key = []

        def _maybe_add_key(variable_name: str, value) -> None:
            if variable_name not in self.variable_name_trigger_map:
                return

            triggers = self.variable_name_trigger_map[variable_name]

            if isinstance(value, torch.Tensor):
                for func_name, func in triggers:
                    if func is None:
                        assert len(triggers) == 1
                        func = lambda tensor: (tensor.dtype, tensor.size(), tensor.stride())

                    lookup_key.append(f"{variable_name}.{func_name} = {func(value)}")
            else:
                assert len(triggers) == 1
                func_name, func = triggers[0]

                assert (
                    func is None
                ), f"trigger ({variable_name}) is not a tensor and shouldn't have a functional trigger"

                lookup_key.append(f"{variable_name} = {value}")

        for i, value in enumerate(args):
            variable_name = self.signature.args[i]
            _maybe_add_key(variable_name, value)

        for variable_name in kwargs:
            _maybe_add_key(variable_name, kwargs[variable_name])

        # now run the functional triggers
        if len(self.functional_triggers) > 0:
            kwargs = self._get_function_arguments(
                config=XTuneConfig({}), args=args, kwargs=kwargs, override_allowed=False
            )

            for variable_name, func in self.functional_triggers.items():
                lookup_key.append(f"{variable_name} = {func(**kwargs)}")

        return str(lookup_key)[1:-1]

    def _run_benchmark(self, **kwargs: dict) -> float:
        Accelerator.synchronize()

        for _ in range(self.warmup_iterations):
            self.function(**kwargs)

        # TODO generalize to device Event
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        if len(self.reset_to_zero) > 0:
            elapsed_time = 0

            for _ in range(self.benchmark_iterations):
                start.record()
                self.function(**kwargs)
                end.record()

                Accelerator.synchronize()
                elapsed_time += start.elapsed_time(end)

                for variable_name, function in self.reset_to_zero.items():
                    if function is None or function(**kwargs):
                        variable = kwargs[variable_name]
                        assert isinstance(variable, torch.Tensor)
                        variable.zero_()
        else:
            start.record()

            for _ in range(self.benchmark_iterations):
                self.function(**kwargs)

            end.record()
            Accelerator.synchronize()

            elapsed_time = start.elapsed_time(end)

        return elapsed_time / self.benchmark_iterations

    def _setup_trigger_map(self, triggers: set[str]) -> None:
        assert isinstance(triggers, set), "triggers should be a set"

        self.variable_name_trigger_map = defaultdict(list)

        for trigger in triggers:
            variable_name, func_name, func = self._parse_trigger(trigger)
            self.variable_name_trigger_map[variable_name].append((func_name, func))

        # filter to remove all triggers if None, this is useful for Tensor based triggers
        for variable_name in self.variable_name_trigger_map:
            if ("info", None) in self.variable_name_trigger_map[variable_name]:
                self.variable_name_trigger_map[variable_name] = [("info", None)]

            assert (
                variable_name in self.signature.args
            ), f"unexpected variable_name ({variable_name}) found in triggers"

        for variable_name in self.xtuneable_parameters:
            assert (
                variable_name not in self.variable_name_trigger_map
            ), "trigger can't be an instance of XTuneParameter"

    def _parse_trigger(self, trigger: str) -> tuple[str, str, Callable]:
        split_trigger = trigger.split(_SEPARATOR)
        variable_name = split_trigger[0]

        if len(split_trigger) == 1:
            func_name = "info"
            func = None
        elif len(split_trigger) == 2:
            func_name = split_trigger[1]

            if func_name == "dtype":
                func = lambda tensor: tensor.dtype
            elif func_name in ["size()", "shape"]:
                func = lambda tensor: tensor.size()
            elif func_name == "stride()":
                func = lambda tensor: tensor.stride()
            elif func_name.startswith("size"):
                dim = int(func_name[5:][:-1])
                func = lambda tensor: tensor.size(dim)
            elif func_name.startswith("shape"):
                dim = int(func_name[6:][:-1])
                func = lambda tensor: tensor.size(dim)
            elif func_name.startswith("stride"):
                dim = int(func_name[7:][:-1])
                func = lambda tensor: tensor.stride(dim)
            else:
                raise ValueError(f"unexpected trigger found ({trigger})")

        return variable_name, func_name, func

    def __repr__(self):
        return f"""XTunedFunction(
            function_cache = {self.function_cache}
            configs = {self.configs}
            warmup iterations = {self.warmup_iterations}
            benchmark iterations = {self.benchmark_iterations}
            xtuneable parameters = {self.xtuneable_parameters}
            functional triggers = {self.functional_triggers}
            reset to zero = {self.reset_to_zero}
            function hash = {self.function_hash}
        )"""
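# Illustrative note (not part of the original module): based on _parse_trigger and
# _setup_trigger_map above, a trigger set for a hypothetical tensor argument "x"
# could plausibly look like:
#
#     triggers = {"x.dtype", "x.size(0)", "x.stride()"}
#
# A bare "x" (no attribute) records dtype, size and stride together, while
# "x.dtype" or "x.size(0)" re-triggers autotuning only when that specific
# property of the tensor argument changes.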
def xtune(
    configs: list[XTuneConfig],
    triggers: set[str] = set(),
    functional_triggers: dict[str, Callable] = {},
    warmup_iterations: int = _DEFAULT_WARMUP_ITERATIONS,
    benchmark_iterations: int = _BENCHMARK_ITERATIONS,
    reset_to_zero: dict = {},
) -> XTunedFunction:
    """
    autotuner for any function or kernel

    :param configs: list of configs to autotune over
    :type configs: list[XTuneConfig]
    :param triggers: change in these parameters will trigger autotuning
    :type triggers: set[str]
    :param functional_triggers: mapping from key to function; a change in the function's output triggers autotuning
    :type functional_triggers: dict[str, Callable]
    :param warmup_iterations: iterations for warmup. Defaults to 5.
    :type warmup_iterations: int
    :param benchmark_iterations: iterations for benchmarking. Defaults to 10.
    :type benchmark_iterations: int
    :param reset_to_zero: a dictionary mapping tensor argument names to an optional callable condition. The
        specified tensors are zeroed out after each benchmark iteration if the condition (when provided) returns True.
    :type reset_to_zero: dict
    :return: autotuned version of the function
    :rtype: XTunedFunction
    """

    def inner(function: Callable) -> Callable:
        return XTunedFunction(
            function=function,
            configs=configs,
            triggers=triggers,
            warmup_iterations=warmup_iterations,
            benchmark_iterations=benchmark_iterations,
            functional_triggers=functional_triggers,
            reset_to_zero=reset_to_zero,
        )

    return inner
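# A minimal usage sketch (illustrative only, not part of this module). The function
# "scale", its arguments, and the BLOCK_SIZE values below are hypothetical; the sketch
# assumes XTuneConfig accepts a dict of tuneable key-values, as suggested by
# get_key_values() above.
#
#     @xtune(
#         configs=[XTuneConfig({"BLOCK_SIZE": 64}), XTuneConfig({"BLOCK_SIZE": 128})],
#         triggers={"x.dtype", "x.size(0)"},
#     )
#     def scale(x: torch.Tensor, factor: float, BLOCK_SIZE: int = XTuneParameter()) -> torch.Tensor:
#         return x * factor
#
#     # the first call with a new (dtype, size(0)) combination benchmarks every config;
#     # later calls reuse the cached best config for that lookup key
#     y = scale(torch.randn(1024, device="cuda"), 2.0)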