Source code for xma.utils.debugging

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch


class _PrintGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, name: str) -> torch.Tensor:
        ctx.name = name
        return x

    @staticmethod
    def backward(ctx, output_grad: torch.Tensor) -> tuple[torch.Tensor, None]:
        print(f"gradient for {ctx.name} = {output_grad}")
        return output_grad, None