Source code for xma.functional.cross_entropy

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch
import torch.nn.functional as F

from ...accelerator import KernelBackend
from ...custom_op import CustomOp, ctx_needs_gradients, ctx_save_for_backward
from ...utils import empty_like_contiguous, get_num_elements_and_hidden_size, is_triton_available


if is_triton_available():
    from .triton_implementation import cross_entropy_forward_backward_triton


class _CrossEntropy(CustomOp):
    @staticmethod
    def forward_backward_torch(
        x: torch.Tensor, labels: torch.Tensor, reduction: str = "mean", logits_multiplier: float | None = None
    ) -> torch.Tensor:
        # reference implementation in plain PyTorch: compute the loss in full precision
        x = x.float()

        if logits_multiplier not in [None, 1]:
            x = x * logits_multiplier

        return F.cross_entropy(x, labels, reduction=reduction)

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        labels: torch.Tensor,
        reduction: str,
        logits_multiplier: float | None,
        kernel_backend: KernelBackend,
    ) -> torch.Tensor:
        assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]

        # the Triton kernel fuses the forward and backward passes: it writes the scalar
        # loss and, when gradients are needed, the gradient w.r.t. the logits in one launch
        loss = torch.zeros((), device=x.device, dtype=torch.float32)
        x_grad = empty_like_contiguous(x) if ctx_needs_gradients(ctx) else None

        cross_entropy_forward_backward_triton(
            x=x, labels=labels, loss=loss, x_grad=x_grad, logits_multiplier=logits_multiplier, reduction=reduction
        )

        ctx_save_for_backward(ctx, x_grad)

        return loss

    @staticmethod
    def backward(ctx, output_grad: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None]:
        x_grad = ctx.saved_tensors[0]
        # the gradient was already computed in forward, only scale it by the incoming
        # gradient of the scalar loss
        x_grad *= output_grad

        return x_grad, *[None] * 4


def cross_entropy(
    x: torch.Tensor,
    labels: torch.Tensor,
    reduction: str = "mean",
    logits_multiplier: float | None = None,
    *,
    kernel_backend: KernelBackend | None = None,
) -> torch.Tensor:
    """
    cross entropy loss

    :param x: logits
    :type x: torch.Tensor
    :param labels: labels
    :type labels: torch.Tensor
    :param reduction: reduction method: "sum" or "mean". Defaults to "mean".
    :type reduction: str
    :param logits_multiplier: multiplier applied to the logits before computing the loss, None implies 1. Defaults to None.
    :type logits_multiplier: float | None
    :param kernel_backend: kernel backend to use
    :type kernel_backend: KernelBackend | None
    :return: loss
    :rtype: torch.Tensor
    """

    assert reduction in ["sum", "mean"]
    assert x.dim() == 2, "x should be 2 dimensional"
    assert labels.dim() == 1, "labels should be 1 dimensional"
    assert (
        labels.size(0) == get_num_elements_and_hidden_size(x)[0]
    ), "x and labels have different number of elements along the batch dimension"

    x = _CrossEntropy.run(
        x=x, labels=labels, reduction=reduction, logits_multiplier=logits_multiplier, kernel_backend=kernel_backend
    )

    return x
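
A minimal usage sketch (illustrative, not part of the module source). The import paths below are assumed from the module layout shown above and may differ in the actual package; it also assumes a CUDA device and that the Triton backend is available::

    import torch

    # assumed import locations, inferred from "xma.functional.cross_entropy"
    # and "from ...accelerator import KernelBackend" above
    from xma.accelerator import KernelBackend
    from xma.functional import cross_entropy

    logits = torch.randn(8, 32000, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    labels = torch.randint(0, 32000, (8,), device="cuda")

    # fused path: the gradient w.r.t. the logits is produced in the forward kernel,
    # backward only scales it by the incoming gradient of the scalar loss
    loss = cross_entropy(logits, labels, reduction="mean", kernel_backend=KernelBackend.triton)
    loss.backward()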