Source code for xma.torch_utils

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch
import torch.nn.functional as F


[docs] def sigmoid(x: torch.Tensor) -> torch.Tensor: return F.sigmoid(x.float()).type_as(x)
[docs] def tanh(x: torch.Tensor) -> torch.Tensor: return F.tanh(x.float()).type_as(x)
class _ClipGradients(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, gradient_clipping: float) -> torch.Tensor: ctx.gradient_clipping = gradient_clipping return x @staticmethod def backward(ctx, x_grad: torch.Tensor) -> tuple[torch.Tensor, None]: gradient_clipping = ctx.gradient_clipping x_grad = x_grad.clip(-gradient_clipping, gradient_clipping) return x_grad, None
[docs] def clip_gradients(x: torch.Tensor, gradient_clipping: float | None) -> torch.Tensor: return x if gradient_clipping is None else _ClipGradients.apply(x, gradient_clipping)