Source code for xma.functional.softmax
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch
import torch.nn.functional as F

from ...accelerator import KernelBackend
from ...custom_op import CustomOp, ctx_save_for_backward
from ...utils import empty_like_contiguous, is_triton_available

if is_triton_available():
    from .triton_implementation import softmax_backward_triton, softmax_forward_triton

class _Softmax(CustomOp):
    @staticmethod
    def forward_backward_torch(x: torch.Tensor, logits_multiplier: float | None) -> torch.Tensor:
        # compute in fp32 for numerical stability, then cast back to the input dtype
        dtype = x.dtype
        x = x.float()

        if logits_multiplier is not None:
            x = x * logits_multiplier

        x = F.softmax(x, dim=-1)
        x = x.to(dtype)

        return x
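
    # Equivalent math (m = logits_multiplier; softmax over the last dim):
    #     y_i = exp(m * x_i) / sum_j exp(m * x_j)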
    @staticmethod
    def forward(ctx, x: torch.Tensor, logits_multiplier: float | None, kernel_backend: KernelBackend) -> torch.Tensor:
        assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]

        y = empty_like_contiguous(x)
        # both allowed backends currently dispatch to the Triton kernel
        softmax_forward_triton(x=x, y=y, logits_multiplier=logits_multiplier)

        # the backward pass only needs the output y, not the input x
        ctx_save_for_backward(ctx, y)
        ctx.logits_multiplier = logits_multiplier

        return y
    @staticmethod
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None, ...]:
        y = ctx.saved_tensors[0]

        dx = empty_like_contiguous(y)
        softmax_backward_triton(y=y, dy=dy, dx=dx, logits_multiplier=ctx.logits_multiplier)

        # None gradients for the logits_multiplier and kernel_backend arguments
        return dx, None, None
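
# Illustrative only (not part of the library): a pure-torch sketch of the gradient that
# softmax_backward_triton is expected to compute, assuming the standard softmax
# Jacobian-vector product with the logits multiplier folded in via the chain rule.
def _softmax_backward_reference(
    y: torch.Tensor, dy: torch.Tensor, logits_multiplier: float | None
) -> torch.Tensor:
    # dz = y * (dy - <dy, y>) is the JVP of softmax; dx = m * dz by the chain rule
    dx = y * (dy - (dy * y).sum(dim=-1, keepdim=True))
    if logits_multiplier is not None:
        dx = dx * logits_multiplier
    return dx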

def softmax(
    x: torch.Tensor, logits_multiplier: float | None = None, *, kernel_backend: KernelBackend | None = None
) -> torch.Tensor:
    """computes the softmax activation over the last dimension

    :param x: input activation tensor
    :type x: torch.Tensor
    :param logits_multiplier: pre-multiplies `x` with `logits_multiplier` before computing softmax.
        Defaults to None.
    :type logits_multiplier: float | None
    :param kernel_backend: the kernel backend to use. Defaults to None.
    :type kernel_backend: KernelBackend | None
    :return: output tensor
    :rtype: torch.Tensor
    """
    # if 1D -> make 2D
    is_flat = x.dim() == 1
    if is_flat:
        x = x[None, ...]

    x = _Softmax.run(x=x, logits_multiplier=logits_multiplier, kernel_backend=kernel_backend)

    # convert back to 1D
    if is_flat:
        x = x.squeeze(0)

    return x
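
# Example usage (a minimal sketch, not part of the module): assumes a CUDA device with
# Triton available, and compares against the pure-torch path up to kernel-level numerics.
#
#     x = torch.randn(4, 8, device="cuda")
#     y = softmax(x, logits_multiplier=2.0, kernel_backend=KernelBackend.triton)
#     y_ref = F.softmax(2.0 * x.float(), dim=-1).to(x.dtype)
#     assert torch.allclose(y, y_ref, atol=1e-5)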