xma.utils.debugging

print_gradient(x: Tensor, name: str) → Tensor

Prints the gradient of x during the backward pass. Use only for debugging.

Parameters:
  • x (torch.Tensor) – input tensor

  • name (str) – label attached to the printed gradient (typically the name of the tensor)

Returns:

the output tensor, which is the same as the input tensor

Return type:

torch.Tensor
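
Example:

The following is a minimal sketch of how a function with this behavior could be written and used. It is not taken from the xma source: the names `_PrintGradient` and `print_gradient_sketch` are hypothetical, and it assumes the function acts as an identity in the forward pass and prints the incoming gradient in the backward pass, as the description above suggests. The actual implementation and print format may differ.

```python
import torch


class _PrintGradient(torch.autograd.Function):
    """Hypothetical helper: identity in forward, prints the gradient in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, name: str) -> torch.Tensor:
        # Remember the label so backward can include it in the message.
        ctx.name = name
        # Return a view so the output has the same values as the input.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        print(f"gradient of {ctx.name}: {grad_output}")
        # One return value per forward input: the gradient passes through
        # unchanged for x, and the string argument gets None.
        return grad_output, None


def print_gradient_sketch(x: torch.Tensor, name: str) -> torch.Tensor:
    """Assumed stand-in for print_gradient with the same signature."""
    return _PrintGradient.apply(x, name)


# Usage: wrap a tensor anywhere in the graph to inspect the gradient reaching it.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = print_gradient_sketch(x, "x")
loss = (y * y).sum()
loss.backward()  # the gradient with respect to y (here 2 * y) is printed
```

Because the wrapper leaves values and gradients unchanged, it can be inserted around any intermediate tensor while debugging and removed afterwards without affecting results.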