xma.functional.fused_residual_add_rmsnorm
- fused_residual_add_rmsnorm(x: Tensor, residual: Tensor | None, weight: Tensor | None, eps: float | None, multiplier: float | None = None, memory_efficient: bool = False, deterministic: bool = False, *, kernel_backend: KernelBackend | None = None) -> tuple[Tensor, Tensor | None]
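Fused residual add and RMSNorm: adds the residual stream to the (optionally scaled) input activation and applies RMSNorm to the sum in a single fused operation.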
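To make the data flow concrete, here is a minimal, unfused pure-PyTorch sketch of what the fused operation is assumed to compute. The helper name reference_residual_add_rmsnorm is hypothetical and not part of xma. Two behaviors are assumptions rather than documented facts: when residual is None the returned residual is None (inferred from the Tensor | None return type), and eps=None falls back to the dtype's machine epsilon, mirroring torch.nn.functional.rms_norm.

```python
from __future__ import annotations

import torch


def reference_residual_add_rmsnorm(
    x: torch.Tensor,
    residual: torch.Tensor | None,
    weight: torch.Tensor | None,
    eps: float | None,
    multiplier: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Unfused sketch of the assumed semantics (hypothetical helper, not part of xma)."""
    if multiplier is not None:
        # Scale the incoming activation before the residual add.
        x = x * multiplier
    if residual is not None:
        # Residual add; the sum is returned as the updated residual stream.
        x = x + residual
        residual = x
    if eps is None:
        # Assumption: fall back to the dtype's machine epsilon, as
        # torch.nn.functional.rms_norm does when eps is None.
        eps = torch.finfo(x.dtype).eps
    # RMSNorm over the last dimension: x * rsqrt(mean(x^2) + eps).
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = x * rstd
    if weight is not None:
        # Per-channel scale applied after normalization.
        out = out * weight
    return out, residual
```

For example, reference_residual_add_rmsnorm(x, r, w, 1e-5, multiplier=0.5) returns RMSNorm(0.5 * x + r) scaled by w, together with 0.5 * x + r as the new residual stream. A fused kernel avoids materializing these intermediates in separate passes over memory.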
- Parameters:
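x (torch.Tensor) – input activation
residual (torch.Tensor | None) – residual stream activation added to x
weight (torch.Tensor | None) – RMSNorm weight (per-channel scale)
eps (float | None) – epsilon added to the mean square for numerical stability
multiplier (float | None) – if not None, x is multiplied by multiplier before the residual add. Defaults to None.
memory_efficient (bool) – if False, RMSNorm's denominator is cached in the forward pass for reuse in the backward pass; if True, it is not cached, trading extra recomputation for lower memory use. Defaults to False.
deterministic (bool) – whether to use a deterministic backward pass. Defaults to False.
kernel_backend (KernelBackend | None) – kernel backend used to run the computation. Defaults to None.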
- Returns:
the normalized output activations and the updated residual stream
- Return type:
tuple[Tensor, Tensor | None]
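The memory_efficient flag describes a classic activation-memory trade-off. The sketch below illustrates it for the RMSNorm part only, using a hypothetical torch.autograd.Function named SketchRMSNorm; it is not xma's kernel, and it assumes the cached "denominator" is the per-row reciprocal RMS (rstd). With memory_efficient=False, rstd is saved in the forward pass; with memory_efficient=True, only the inputs are saved and rstd is recomputed in the backward pass. The residual add and the deterministic flag are not modeled here.

```python
from __future__ import annotations

import torch


class SketchRMSNorm(torch.autograd.Function):
    """Hypothetical RMSNorm autograd function illustrating the memory_efficient trade-off."""

    @staticmethod
    def forward(ctx, x, weight, eps, memory_efficient):
        rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        if memory_efficient:
            # Save only the inputs; rstd is recomputed in backward.
            ctx.save_for_backward(x, weight)
        else:
            # Cache rstd (one value per row) so backward can skip recomputing it.
            ctx.save_for_backward(x, weight, rstd)
        ctx.eps = eps
        ctx.memory_efficient = memory_efficient
        return x * rstd * weight

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.memory_efficient:
            x, weight = ctx.saved_tensors
            rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + ctx.eps)
        else:
            x, weight, rstd = ctx.saved_tensors
        x_hat = x * rstd
        g = grad_out * weight
        # Gradient through x * rstd(x): rstd * (g - x_hat * mean(g * x_hat)).
        grad_x = rstd * (g - x_hat * (g * x_hat).mean(dim=-1, keepdim=True))
        # Weight gradient: reduce over all leading (token) dimensions.
        grad_w = (grad_out * x_hat).reshape(-1, x.size(-1)).sum(dim=0)
        # eps and memory_efficient are not tensors, so they get no gradient.
        return grad_x, grad_w, None, None
```

Both settings produce the same gradients; they differ only in whether one extra value per row is kept alive between forward and backward or recomputed from x.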