# Source code for xma.cutotune.cache

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import os

import yaml

from ..utils import get_boolean_env_variable
from .config import CutoTuneConfig


# Gate for reading the on-disk cache when _CutoTuneCache is constructed. The second argument is
# presumably the value used when LOAD_CUTOTUNE_CACHE is unset (TODO confirm in ..utils).
_LOAD_CUTOTUNE_CACHE = get_boolean_env_variable("LOAD_CUTOTUNE_CACHE", True)
# Path of the YAML cache: "cache.yml" in the directory one level above the folder holding this file.
_CUTOTUNE_CACHE_FILENAME = os.path.join(os.path.split(os.path.dirname(__file__))[0], "cache.yml")


class _CutoTuneCache:
    """Two-level store of tuned configs: ``function_hash -> lookup_key -> CutoTuneConfig``.

    On construction the store is pre-populated from ``_CUTOTUNE_CACHE_FILENAME`` when
    ``_LOAD_CUTOTUNE_CACHE`` is truthy and that file exists; ``save`` writes it back to disk.
    """

    def __init__(self) -> None:
        self.cache = {}

        if _LOAD_CUTOTUNE_CACHE and os.path.exists(_CUTOTUNE_CACHE_FILENAME):
            # Context manager closes the handle deterministically instead of leaking it.
            with open(_CUTOTUNE_CACHE_FILENAME, "r", encoding="utf-8") as f:
                cache = yaml.load(f, yaml.SafeLoader)

            # An empty (or comment-only) YAML file loads as None; treat it as an empty cache
            # rather than failing while iterating None in _deserialize.
            self.cache = self._deserialize(cache or {})

    def add_config(self, function_hash: str, lookup_key: str, config: CutoTuneConfig) -> None:
        """Store ``config`` under ``(function_hash, lookup_key)``, replacing any previous entry.

        Only the in-memory cache is updated; call ``save`` to persist it.
        """
        self.cache.setdefault(function_hash, {})[lookup_key] = config

    def get_config(self, function_hash: str, lookup_key: str) -> "CutoTuneConfig | None":
        """Return the config stored under ``(function_hash, lookup_key)``, or None if either key is missing."""
        return self.cache.get(function_hash, {}).get(lookup_key)

    def save(self) -> None:
        """Overwrite ``_CUTOTUNE_CACHE_FILENAME`` with the YAML form of the whole cache."""
        # Render to a string first: if serialization fails, the existing file has not been
        # opened for writing yet, so it is never left truncated or half-written.
        text = yaml.dump(self._serialize(self.cache))

        with open(_CUTOTUNE_CACHE_FILENAME, "w", encoding="utf-8") as f:
            f.write(text)

    def _serialize(self, x: dict) -> dict:
        """Convert the cache into plain nested dicts suitable for ``yaml.dump``.

        Each ``CutoTuneConfig`` is replaced by a fresh dict copy of ``config.get_key_values()``.
        """
        result = {}

        for function_hash, function_cache in x.items():
            result[function_hash] = {
                lookup_key: dict(config.get_key_values().items()) for lookup_key, config in function_cache.items()
            }

        return result

    def _deserialize(self, x: dict) -> dict:
        """Inverse of ``_serialize``: rebuild ``CutoTuneConfig`` objects from loaded YAML data.

        ``x`` is expected to be the nested mapping written by ``save``
        (``function_hash -> lookup_key -> {key: value}``).
        """
        result = {}

        for function_hash, function_cache in x.items():
            result[function_hash] = {
                lookup_key: CutoTuneConfig(dict(config.items())) for lookup_key, config in function_cache.items()
            }

        return result


# Process-wide cache instance; created lazily by get_cutotune_cache().
_CUTOTUNE_CACHE = None


def get_cutotune_cache() -> _CutoTuneCache:
    """Return the shared ``_CutoTuneCache``, constructing it on the first call.

    Construction is deferred until first use, so the on-disk YAML cache is only read
    when a cache is actually requested.
    """
    global _CUTOTUNE_CACHE

    if _CUTOTUNE_CACHE is None:
        _CUTOTUNE_CACHE = _CutoTuneCache()

    return _CUTOTUNE_CACHE