# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
import inspect
from collections import defaultdict
from typing import Any, Callable
import torch
from tqdm import tqdm
from ..utils import device_synchronize, get_boolean_env_variable  # device_synchronize (used in _run_benchmark) is assumed to live in ..utils
from .cache import get_cutotune_cache
from .config import CutoTuneConfig
from .parameter import CutoTuneParameter
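# setting DEBUG_CUTOTUNE=1 enables a tqdm progress bar during tuning and prints the best config found (rank 0 only)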
_DEBUG_CUTOTUNE = get_boolean_env_variable("DEBUG_CUTOTUNE", False)
_SEPARATOR = "."
_DEFAULT_WARMUP_ITERATIONS = 5
_BENCHMARK_ITERATIONS = 10
class _CutoTune:
def __init__(
self,
function: Callable,
configs: list[CutoTuneConfig],
triggers: set[str],
warmup_iterations: int,
benchmark_iterations: int,
functional_triggers: dict[str, Callable] = {},
reset_to_zero: dict = {},
) -> None:
assert len(configs) > 0, "no cutotune config is passed"
self.function = function
self.configs = configs
self.warmup_iterations = warmup_iterations
self.benchmark_iterations = benchmark_iterations
self.signature = inspect.getfullargspec(function)
self.cutotuneable_parameters = set(self.configs[0].get_key_values().keys())
self._setup_trigger_map(triggers)
for config in self.configs:
assert (
set(config.get_key_values().keys()) == self.cutotuneable_parameters
), "cutotune configs don't match the expected function signature"
self.functional_triggers = functional_triggers
self.reset_to_zero = reset_to_zero
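        # stack()[2] is the module that applied the @cutotune decorator (0 = __init__, 1 = the decorator's inner function);
        # the absolute path prefix up to the package root ("xma") is stripped so the filename is relative to the package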
self.filename = inspect.stack()[2].filename.rsplit("xma", 1)[1][1:]
self.function_hash = f"{self.filename}->{function.__name__}"
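        # in-memory cache mapping a lookup key (built from the triggers) to the best config found for it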
self.function_cache = {}
def __call__(self, *args, **kwargs) -> Any:
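        # 1) decide whether caller-supplied values may override the tuned config
        # 2) build the lookup key from the trigger values and reuse a cached config if one exists
        # 3) otherwise benchmark all configs, cache the winner and call the function with it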
override_cutotune_parameters = self._check_all_or_no_args_are_cutotune_parameters(*args, **kwargs)
lookup_key = self._get_lookup_key(*args, **kwargs)
best_config = self.function_cache.get(lookup_key, None)
if best_config is None:
# bypass cutotune for single config
if len(self.configs) == 1:
best_config = self.configs[0]
best_time = 0
else:
best_config, best_time, _ = self._cutotune(*args, **kwargs)
self.function_cache[lookup_key] = best_config
get_cutotune_cache().add_config(
function_hash=self.function_hash, lookup_key=lookup_key, config=best_config
)
if _DEBUG_CUTOTUNE and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0):
print(
f"config {best_config} achieved the best time ({best_time} sec) for {lookup_key} for "
f"function {self.function.__name__}"
)
output = self.function(
**self._get_function_arguments(
config=best_config,
args=args,
kwargs=kwargs,
override_allowed=override_cutotune_parameters,
)
)
return output
def _check_all_or_no_args_are_cutotune_parameters(self, *args, **kwargs) -> bool:
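        # returns True when the caller passed concrete values for every cutotuneable parameter
        # (i.e. no CutoTuneParameter placeholders), in which case the caller's values override the tuned config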
num_cutotune_overrideables = 0
for i in range(len(args)):
variable_name = self.signature.args[i]
if isinstance(args[i], CutoTuneParameter) and variable_name in self.cutotuneable_parameters:
num_cutotune_overrideables += 1
# accessing kwargs.items() breaks torch.compile in backwards of a custom autograd function
for variable_name in kwargs:
if (
isinstance(kwargs.get(variable_name), CutoTuneParameter)
and variable_name in self.cutotuneable_parameters
):
num_cutotune_overrideables += 1
assert num_cutotune_overrideables in [
0,
len(self.cutotuneable_parameters),
], f"invalid number of CutoTuneParameter arguments, should be either 0 or {len(self.cutotuneable_parameters)}"
return num_cutotune_overrideables == 0
def _get_function_arguments(
self, config: CutoTuneConfig, args: list, kwargs: dict, override_allowed: bool
) -> dict:
# copy the best_config first so we can override with args or kwargs
result = {variable_name: value for variable_name, value in config.get_key_values().items()}
for i in range(len(args)):
variable_name = self.signature.args[i]
if override_allowed or variable_name not in result:
result[variable_name] = args[i]
# accessing kwargs.items() breaks torch.compile in backwards of a custom autograd function
for variable_name in kwargs:
if override_allowed or variable_name not in result:
result[variable_name] = kwargs.get(variable_name)
return result
@torch.compiler.set_stance("force_eager")
@torch.inference_mode()
def _cutotune(self, *args, **kwargs) -> tuple[CutoTuneConfig, float, list[tuple[CutoTuneConfig, float]]]:
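        # benchmark every config whose condition holds for the current arguments and return the fastest one,
        # along with its time and the full list of (config, time) pairs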
best_config = None
best_time = float("inf")
configs = tqdm(self.configs) if _DEBUG_CUTOTUNE else self.configs
timed_configs = []
for config in configs:
if not config.is_condition_valid(
**self._get_function_arguments(
config=CutoTuneConfig({}), args=args, kwargs=kwargs, override_allowed=False
)
):
continue
elapsed_time = self._run_benchmark(
**self._get_function_arguments(config=config, args=args, kwargs=kwargs, override_allowed=False),
)
timed_configs.append((config, elapsed_time))
if elapsed_time < best_time:
best_config = config
best_time = elapsed_time
assert best_config is not None, "no best_config found, check that at least 1 cutotune config is valid"
return best_config, best_time, timed_configs
def _get_lookup_key(self, *args, **kwargs) -> Any:
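        # build a string key from the trigger functions applied to the current arguments (e.g. dtype/size/stride
        # of tensor triggers) plus any functional triggers, so tuning results are reused per argument "shape"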
lookup_key = []
def _maybe_add_key(variable_name: str, value) -> None:
if variable_name not in self.variable_name_trigger_map:
return
triggers = self.variable_name_trigger_map[variable_name]
if isinstance(value, torch.Tensor):
for func_name, func in triggers:
if func is None:
assert len(triggers) == 1
func = lambda tensor: (tensor.dtype, tensor.size(), tensor.stride())
lookup_key.append(f"{variable_name}.{func_name} = {func(value)}")
else:
assert len(triggers) == 1
func_name, func = triggers[0]
assert (
func is None
), f"trigger ({variable_name}) is not a tensor and shouldn't have a functional trigger"
lookup_key.append(f"{variable_name} = {value}")
for i, value in enumerate(args):
variable_name = self.signature.args[i]
_maybe_add_key(variable_name, value)
for variable_name in kwargs:
_maybe_add_key(variable_name, kwargs[variable_name])
# now run the functional triggers
if len(self.functional_triggers) > 0:
kwargs = self._get_function_arguments(
config=CutoTuneConfig({}), args=args, kwargs=kwargs, override_allowed=False
)
for variable_name, func in self.functional_triggers.items():
lookup_key.append(f"{variable_name} = {func(**kwargs)}")
return str(lookup_key)[1:-1]
def _run_benchmark(self, **kwargs: dict) -> float:
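        # run warmup iterations first, then time benchmark_iterations runs with CUDA events; when reset_to_zero
        # is set, each iteration is timed individually so the listed tensors can be zeroed between runs
        # outside the timed region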
device_synchronize()
for _ in range(self.warmup_iterations):
self.function(**kwargs)
# TODO generalize to device Event
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
if len(self.reset_to_zero) > 0:
elapsed_time = 0
for _ in range(self.benchmark_iterations):
start.record()
self.function(**kwargs)
end.record()
device_synchronize()
elapsed_time += start.elapsed_time(end)
for variable_name, function in self.reset_to_zero.items():
if function is None or function(**kwargs):
variable = kwargs[variable_name]
assert isinstance(variable, torch.Tensor)
variable.zero_()
else:
start.record()
for _ in range(self.benchmark_iterations):
self.function(**kwargs)
end.record()
device_synchronize()
elapsed_time = start.elapsed_time(end)
return elapsed_time / self.benchmark_iterations
def _setup_trigger_map(self, triggers: set[str]) -> None:
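        # map each trigger's variable name to its (func_name, func) pairs and validate the trigger names
        # against the function signature; cutotuneable parameters themselves can't be triggers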
assert isinstance(triggers, set), "triggers should be a set"
self.variable_name_trigger_map = defaultdict(list)
for trigger in triggers:
variable_name, func_name, func = self._parse_trigger(trigger)
self.variable_name_trigger_map[variable_name].append((func_name, func))
# filter to remove all triggers if None, this is useful for Tensor based triggers
for variable_name in self.variable_name_trigger_map:
if ("info", None) in self.variable_name_trigger_map[variable_name]:
self.variable_name_trigger_map[variable_name] = [("info", None)]
assert (
variable_name in self.signature.args
), f"unexpected variable_name ({variable_name}) found in triggers"
for variable_name in self.cutotuneable_parameters:
assert (
variable_name not in self.variable_name_trigger_map
), "trigger can't be an instance of CutoTuneParameter"
def _parse_trigger(self, trigger: str) -> tuple[str, str, Callable]:
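        # a trigger is either a bare variable name ("x") or "<name>.<attr>" where attr is one of
        # dtype, size(), shape, stride(), size(i), shape(i) or stride(i)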
split_trigger = trigger.split(_SEPARATOR)
variable_name = split_trigger[0]
if len(split_trigger) == 1:
func_name = "info"
func = None
elif len(split_trigger) == 2:
func_name = split_trigger[1]
if func_name == "dtype":
func = lambda tensor: tensor.dtype
elif func_name in ["size()", "shape"]:
func = lambda tensor: tensor.size()
elif func_name == "stride()":
func = lambda tensor: tensor.stride()
elif func_name.startswith("size"):
dim = int(func_name[5:][:-1])
func = lambda tensor: tensor.size(dim)
elif func_name.startswith("shape"):
dim = int(func_name[6:][:-1])
func = lambda tensor: tensor.size(dim)
elif func_name.startswith("stride"):
dim = int(func_name[7:][:-1])
func = lambda tensor: tensor.stride(dim)
else:
                raise ValueError(f"unexpected trigger found ({trigger})")
return variable_name, func_name, func
    def __repr__(self) -> str:
        return str(self.function_cache)
def cutotune(
configs: list[CutoTuneConfig],
triggers: set[str] = set(),
functional_triggers: dict[str, Callable] = {},
warmup_iterations: int = _DEFAULT_WARMUP_ITERATIONS,
benchmark_iterations: int = _BENCHMARK_ITERATIONS,
reset_to_zero: dict = {},
) -> _CutoTune:
def inner(function: Callable) -> Callable:
return _CutoTune(
function=function,
configs=configs,
triggers=triggers,
warmup_iterations=warmup_iterations,
benchmark_iterations=benchmark_iterations,
functional_triggers=functional_triggers,
reset_to_zero=reset_to_zero,
)
return inner
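# ----------------------------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not executed here). It assumes CutoTuneConfig accepts a
# dict of parameter values, as suggested by CutoTuneConfig({}) above; the scale() function, its
# kernel body and the BLOCK_SIZE parameter are hypothetical.
#
# @cutotune(
#     configs=[CutoTuneConfig({"BLOCK_SIZE": 64}), CutoTuneConfig({"BLOCK_SIZE": 128})],
#     triggers={"x.dtype", "x.size(0)"},
# )
# def scale(x: torch.Tensor, factor: float, BLOCK_SIZE: int) -> torch.Tensor:
#     ...  # launch a kernel parameterized by BLOCK_SIZE
#
# # passing CutoTuneParameter() lets the tuner pick BLOCK_SIZE; passing concrete values for all
# # cutotuneable parameters makes the caller's values override the tuned config
# y = scale(x, factor=2.0, BLOCK_SIZE=CutoTuneParameter())
# ----------------------------------------------------------------------------------------------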