Source code for xma.functional.rnn

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch

from ...accelerator import KernelBackend
from ...custom_op import CustomOp, ctx_save_for_backward
from ...torch_utils import clip_gradients, tanh
from ...utils import empty_like_contiguous, is_triton_available, zeros_like_contiguous


if is_triton_available():
    from .triton_implementation import rnn_backward_triton, rnn_forward_triton


def _get_num_heads(x: torch.Tensor, W: torch.Tensor, run_check: bool) -> tuple[int, int, int]:
    Nx = x.size(-2)
    Nw = W.size(0)
    N = max(Nx, Nw)

    if run_check:
        assert N % Nx == 0
        assert N % Nw == 0

    return Nx, Nw, N


def _get_backward_tensor(y: torch.Tensor, Nx: int, N: int) -> torch.Tensor:
    if Nx == N:
        dx = empty_like_contiguous(y)
    else:
        x_shape = list(y.size())
        x_shape[-2] = Nx
        dx = torch.zeros(x_shape, device=y.device, dtype=torch.float32)

    return dx
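

# Illustration (not from the original source): with Nx = 2 input heads and Nw = 4 weight heads,
# the helpers above give N = max(Nx, Nw) = 4, so each input head is shared by Gx = N // Nx = 2
# output heads while each weight head is used Gw = N // Nw = 1 time, e.g.:
#
#     x = torch.randn(1, 8, 2, 16)  # (B, S, Nx, H)
#     W = torch.randn(4, 16, 16)  # (Nw, H, H)
#     assert _get_num_heads(x=x, W=W, run_check=True) == (2, 4, 4)
#
# When Nx < N, `_get_backward_tensor` returns a zero-initialized float32 buffer so that the
# gradients flowing back from the output heads that share an input head can be accumulated into it.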


class _RNN(CustomOp):
    @staticmethod
    def forward_backward_torch(
        x: torch.Tensor,
        W: torch.Tensor,
        h0: torch.Tensor | None,
        gradient_clipping: float | None,
        cu_seqlens: torch.Tensor | None,
        max_seqlen: int | None,
    ) -> torch.Tensor:
        Nx, Nw, N = _get_num_heads(x=x, W=W, run_check=False)

        y_shape = list(x.size())
        y_shape[-2] = N
        y = torch.empty(y_shape, device=x.device, dtype=x.dtype)

        if cu_seqlens is None:
            B, S, _, H = x.size()
        else:
            B = cu_seqlens.size(0) - 1
            S = max_seqlen.item() if isinstance(max_seqlen, torch.Tensor) else max_seqlen
            H = x.size(-1)

        Gx = N // Nx
        Gw = N // Nw

        # broadcast the input heads and the weight heads to the common head count N
        x = x.repeat_interleave(Gx, dim=-2)
        W = W.repeat_interleave(Gw, dim=0)[None, ...]

        if h0 is None:
            h0 = torch.zeros(B, N, H, device=x.device, dtype=x.dtype)

        if cu_seqlens is not None:
            h0 = h0.clone()
            start = cu_seqlens[:-1]
            end = cu_seqlens[1:]

        for s in range(S):
            if cu_seqlens is None:
                h = h0[..., None, :] @ W + x[:, s, :, None, :]
            else:
                # only advance the sequences that still have a token at step s
                offset = start + s
                unfinished = offset < end
                offset_unfinished = offset[unfinished]

                h = h0[unfinished, :, None, :] @ W + x[offset_unfinished, :, None, :]

            h = tanh(h)
            h = h.squeeze(-2)
            h = clip_gradients(h, gradient_clipping)

            if cu_seqlens is None:
                y[:, s] = h
                h0 = h
            else:
                y[offset_unfinished] = h
                h0[unfinished] = h

        return y

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        W: torch.Tensor,
        h0: torch.Tensor | None,
        gradient_clipping: float | None,
        cu_seqlens: torch.Tensor | None,
        max_seqlen: int | None,
        kernel_backend: KernelBackend,
    ) -> torch.Tensor:
        assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]

        Nx, _, N = _get_num_heads(x=x, W=W, run_check=False)
        y_shape = list(x.size())
        y_shape[-2] = N

        y = torch.empty(y_shape, device=x.device, dtype=x.dtype)

        rnn_forward_triton(
            x=x,
            W=W,
            h0=h0,
            y=y,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        ctx_save_for_backward(ctx, W, y, h0, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        ctx.gradient_clipping = gradient_clipping
        ctx.Nx = Nx

        return y

    @staticmethod
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None, ...]:
        W, y, h0, cu_seqlens = ctx.saved_tensors
        Nx = ctx.Nx
        N = y.size(-2)

        dx = _get_backward_tensor(y=y, Nx=Nx, N=N)
        dW = zeros_like_contiguous(W, dtype=torch.float32)
        dh0 = empty_like_contiguous(h0) if h0 is not None and h0.requires_grad else None

        rnn_backward_triton(
            W=W,
            y=y,
            h0=h0,
            dy=dy,
            dx=dx,
            dW=dW,
            dh0=dh0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            gradient_clipping=ctx.gradient_clipping,
        )

        dx = dx.type_as(y)
        dW = dW.type_as(W)

        return dx, dW, dh0, *[None] * 4


def rnn(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_state: torch.Tensor | None = None,
    gradient_clipping: float | None = None,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: int | None = None,
    *,
    kernel_backend: KernelBackend | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    computes the multihead RNN recurrent update over the sequence length: `tanh(input_state @ weight + input)`

    :param input: input tensor of shape (B, S, Nx, H), where Nx is the number of input heads and H is the head
        dimension. If `cu_seqlens` is passed, the expected shape is (T, Nx, H).
    :type input: torch.Tensor
    :param weight: weight tensor of shape (Nw, H, H)
    :type weight: torch.Tensor
    :param input_state: starting state of shape (B, N, H), where N = max(Nx, Nw). None means the starting state
        is a zero tensor. Defaults to None.
    :type input_state: torch.Tensor | None
    :param gradient_clipping: gradient clipping for the state gradient in the backward pass; None implies no
        clipping. Defaults to None.
    :type gradient_clipping: float | None
    :param cu_seqlens: cumulative sequence lengths (must contain 0 as the first element). Defaults to None.
    :type cu_seqlens: torch.Tensor | None
    :param max_seqlen: max sequence length in the batch. Defaults to None.
    :type max_seqlen: int | None
    :param kernel_backend: kernel backend to use
    :type kernel_backend: KernelBackend | None
    :return: output tensor of shape (B, S, N, H) if `cu_seqlens` is None else (T, N, H), and the output state of
        shape (B, N, H).
    :rtype: tuple[torch.Tensor, torch.Tensor]
    """

    assert input.dim() == 3 + (cu_seqlens is None)

    if cu_seqlens is None:
        assert max_seqlen is None
        B, _, _, H = input.size()
    else:
        assert max_seqlen is not None
        assert cu_seqlens.dim() == 1
        B = cu_seqlens.size(0) - 1
        H = input.size(-1)

    _, Nw, N = _get_num_heads(x=input, W=weight, run_check=True)

    assert weight.size() == (Nw, H, H)
    if input_state is not None:
        assert input_state.size() == (B, N, H)

    if gradient_clipping is not None and gradient_clipping < 0:
        gradient_clipping = -gradient_clipping

    input = _RNN.run(
        x=input,
        W=weight,
        h0=input_state,
        gradient_clipping=gradient_clipping,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        kernel_backend=kernel_backend,
    )

    input_state = input[:, -1] if cu_seqlens is None else input[cu_seqlens[1:] - 1]

    return input, input_state
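

# A minimal usage sketch (not part of the original source); the shapes, device, backend choice and
# the int32 `cu_seqlens` below are illustrative assumptions.
if __name__ == "__main__":
    if torch.cuda.is_available() and is_triton_available():
        B, S, N, H = 2, 128, 4, 64

        # dense batch: input is (B, S, Nx, H), weight is (Nw, H, H)
        x = torch.randn(B, S, N, H, device="cuda", requires_grad=True)
        W = torch.randn(N, H, H, device="cuda", requires_grad=True)

        y, h = rnn(x, W, kernel_backend=KernelBackend.triton)
        assert y.size() == (B, S, N, H) and h.size() == (B, N, H)

        # packed variable-length batch: input is (T, Nx, H) with cumulative sequence lengths
        cu_seqlens = torch.tensor([0, 100, 228, 256], device="cuda", dtype=torch.int32)
        xp = torch.randn(256, N, H, device="cuda", requires_grad=True)

        yp, hp = rnn(xp, W, cu_seqlens=cu_seqlens, max_seqlen=128, kernel_backend=KernelBackend.triton)
        assert yp.size() == (256, N, H) and hp.size() == (3, N, H)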