# Source code for xma.functional.p_norm
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
import torch
from ...accelerator import Accelerator, KernelBackend
from ...utils import is_triton_available
# triton is an optional dependency: pull in the fused kernel only when it is installed,
# so importing this module never fails on triton-less environments.
# (indentation restored — the import must sit inside the guard.)
if is_triton_available():
    from .triton_implementation import p_norm_triton
# [docs]
def p_norm(
    x: torch.Tensor,
    multiplier: float | None = None,
    p: int | str = 2,
    output_dtype: torch.dtype = torch.float32,
    *,
    kernel_backend: KernelBackend | None = None,
) -> torch.Tensor:
    """
    computes the p-norm of each row of a 2D tensor

    :param x: input activation; must be 2D (batch, hidden)
    :param multiplier: if not None, pre-multiplies `x` with `multiplier`. Defaults to None.
    :type multiplier: float | None
    :param p: norm type. can be integer >= 1 or `inf`
    :type p: int | str
    :param output_dtype: output dtype
    :type output_dtype: torch.dtype
    :param kernel_backend: KernelBackend
    :type kernel_backend: KernelBackend | None
    :raises NotImplementedError: if `kernel_backend` is not cuda, triton or torch
    :return: per-row norms of shape (batch,) in `output_dtype`
    :rtype: torch.Tensor
    """
    assert x.dim() == 2

    # resolve the backend from the active accelerator when the caller did not pick one
    if kernel_backend is None:
        kernel_backend = Accelerator.get_kernel_backend()
    else:
        assert kernel_backend.verify_accelerator()

    if kernel_backend in [KernelBackend.cuda, KernelBackend.triton]:
        B = x.size(0)
        # the triton kernel takes the inf-norm as a flag, with `p` unset in that case
        is_p_inf = p == "inf"
        y = torch.empty(B, device=x.device, dtype=output_dtype)
        p_norm_triton(x=x, y=y, multiplier=multiplier, p=None if is_p_inf else p, is_p_inf=is_p_inf)
    elif kernel_backend == KernelBackend.torch:
        # skip the multiply when it would be a no-op (None or 1)
        if multiplier not in [None, 1]:
            x = x * multiplier

        if p == "inf":
            # inf-norm is the max absolute value along the row
            y = x.abs().amax(dim=-1)
        else:
            # torch.norm is deprecated; torch.linalg.vector_norm is the supported API
            y = torch.linalg.vector_norm(x, ord=p, dim=-1)

        y = y.to(output_dtype)
    else:
        raise NotImplementedError(f"unexpected kernel_backend ({kernel_backend})")

    return y