xma.functional.fused_residual_add_rmsnorm

fused_residual_add_rmsnorm(x: Tensor, residual: Tensor | None, weight: Tensor | None, eps: float | None, multiplier: float | None = None, memory_efficient: bool = False, deterministic: bool = False, *, kernel_backend: KernelBackend | None = None) → tuple[Tensor, Tensor | None]

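Fused residual add and RMSNorm computation: adds the residual to the input activation and applies RMSNorm to the sum in a single fused operation.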
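The following is a minimal, unfused reference sketch of the math for a single row. It is written in plain Python rather than PyTorch so that it runs anywhere. The helper name residual_add_rmsnorm_row is hypothetical and is not part of xma. Two behaviours are assumptions rather than documented facts: that multiplier scales x before the residual add, and that a None residual means nothing is added and None is returned as the residual.

```python
import math
from typing import Optional


def residual_add_rmsnorm_row(
    x: list[float],
    residual: Optional[list[float]],
    weight: Optional[list[float]],
    eps: float,
    multiplier: Optional[float] = None,
) -> tuple[list[float], Optional[list[float]]]:
    """Hypothetical unfused reference for one row (not the xma kernel)."""
    # Assumption: multiplier scales x before the residual is added.
    if multiplier is not None:
        x = [v * multiplier for v in x]
    # Assumption: with residual=None, nothing is added and None is returned.
    if residual is not None:
        x = [a + b for a, b in zip(x, residual)]
        residual = x  # the updated residual stream is the pre-norm sum
    # RMSNorm: divide by sqrt(mean(x^2) + eps), then optionally scale by weight.
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [v * rstd for v in x]
    if weight is not None:
        out = [o * w for o, w in zip(out, weight)]
    return out, residual


out, new_residual = residual_add_rmsnorm_row([1.0, 2.0], [2.0, 2.0], None, 0.0)
print(new_residual)  # [3.0, 4.0]
```

With eps = 0 and no weight, each output row has a mean square of exactly 1. The fused kernel presumably avoids materialising the intermediate sum separately from the normalisation, but the sketch only illustrates the arithmetic.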

Parameters:
  • x (torch.Tensor) – input activation

  • residual (torch.Tensor | None) – residual activation

  • weight (torch.Tensor | None) – RMSNorm weight

  • eps (float | None) – epsilon added to the mean of squares for numerical stability

  • multiplier (float | None) – if not None, x is multiplied by multiplier before the computation. Defaults to None.

  • memory_efficient (bool) – if False, the RMSNorm denominator is cached during the forward pass for use in the backward pass; if True, it is not cached, which reduces memory usage. Defaults to False.

  • deterministic (bool) – whether to use a deterministic backward pass. Defaults to False.

  • kernel_backend (KernelBackend | None) – kernel backend used to run the computation
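To show what the memory_efficient trade-off involves, here is a plain-Python sketch of the RMSNorm backward pass for one row. It assumes that the cached "denominator" is the per-row reciprocal RMS (rstd). It also assumes that, when nothing is cached, the backward pass recomputes rstd from x. Both are assumptions about xma's kernels, which may differ. The function names are hypothetical. The gradient follows from y_i = w_i x_i r with r = (mean(x^2) + eps)^(-1/2), giving grad_x = r (g*w) - r^3 x mean(g*w*x).

```python
import math
from typing import Optional


def rmsnorm_forward(x: list[float], weight: list[float], eps: float) -> tuple[list[float], float]:
    """Hypothetical reference forward; also returns rstd so it can be cached."""
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    y = [v * rstd * w for v, w in zip(x, weight)]
    return y, rstd


def rmsnorm_backward(
    grad_y: list[float],
    x: list[float],
    weight: list[float],
    eps: float,
    cached_rstd: Optional[float] = None,
) -> tuple[list[float], list[float]]:
    """Hypothetical reference backward for one row."""
    # Analogue of memory_efficient=False: reuse the rstd stored in the forward.
    # Analogue of memory_efficient=True: nothing stored, so recompute it from x.
    if cached_rstd is None:
        cached_rstd = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    r = cached_rstd
    gw = [g * w for g, w in zip(grad_y, weight)]
    mean_gwx = sum(a * b for a, b in zip(gw, x)) / len(x)
    # dL/dx_j = r * g_j * w_j - r^3 * x_j * mean(g * w * x)
    grad_x = [r * a - r ** 3 * v * mean_gwx for a, v in zip(gw, x)]
    # dL/dw_j = g_j * x_j * r
    grad_w = [g * v * r for g, v in zip(grad_y, x)]
    return grad_x, grad_w
```

Both paths produce identical gradients. The difference is only whether one scalar per row is kept alive between forward and backward, or whether an extra reduction over the row is paid in the backward pass.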

Returns:

the output activations (after RMSNorm) and the updated residual stream

Return type:

tuple[Tensor, Tensor | None]