Source code for xma.functional.rmsnorm

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch

from ..accelerator import KernelBackend
from .fused_residual_add_rmsnorm import fused_residual_add_rmsnorm


def rmsnorm(
    x: torch.Tensor,
    weight: torch.Tensor | None,
    eps: float | None = None,
    memory_efficient: bool = False,
    deterministic: bool = False,
    *,
    kernel_backend: KernelBackend | None = None,
) -> torch.Tensor:
    """
    Apply RMSNorm to ``x``.

    This is a thin wrapper over :func:`fused_residual_add_rmsnorm`: it calls the
    fused op with ``residual=None`` and ``multiplier=None`` and returns only the
    first element of the pair that op yields.

    :param x: input activation
    :type x: torch.Tensor
    :param weight: RMSNorm weight (may be None)
    :type weight: torch.Tensor | None
    :param eps: epsilon, forwarded unchanged to the fused op. Defaults to None.
    :type eps: float | None
    :param memory_efficient: when False, RMSNorm's denominator is cached in the
        forward. Defaults to False.
    :type memory_efficient: bool
    :param deterministic: whether to use the deterministic backward. Defaults to False.
    :type deterministic: bool
    :param kernel_backend: backend selection, forwarded unchanged to the fused op.
        Defaults to None.
    :type kernel_backend: KernelBackend | None
    :return: output tensor
    :rtype: torch.Tensor
    """

    # No residual stream and no multiplier here, so the second element of the
    # returned pair is not needed and is discarded.
    normalized, _ = fused_residual_add_rmsnorm(
        x=x,
        weight=weight,
        residual=None,
        multiplier=None,
        eps=eps,
        memory_efficient=memory_efficient,
        deterministic=deterministic,
        kernel_backend=kernel_backend,
    )

    return normalized