# Source code for xma.layers.m2rnn.op

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from functools import partial

import torch

from ...accelerator import KernelBackend
from ...custom_op import CustomOp, ctx_save_for_backward
from ...torch_utils import clip_gradients, tanh
from ...utils import empty_like_contiguous, is_triton_available, zeros_like_contiguous
from .utils import _get_num_heads


if is_triton_available():
    from .triton_implementation import m2rnn_backward_triton, m2rnn_forward_triton


class _M2RNN(CustomOp):
    """Custom autograd op for the M2RNN recurrence.

    The recurrence per timestep (see ``forward_backward_torch``) is::

        h_t = f_t * h_{t-1} + (1 - f_t) * tanh(h_{t-1} @ W + k_t v_t^T)
        y_t = q_t @ h_t

    ``forward``/``backward`` dispatch to fused Triton kernels;
    ``forward_backward_torch`` is a pure-PyTorch reference implementation of
    the same recurrence.
    """

    @staticmethod
    def forward_backward_torch(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        W: torch.Tensor,
        xf: torch.Tensor,
        h0: torch.Tensor | None,
        gradient_clipping: float | None,
        cu_seqlens: torch.Tensor | None,
        max_seqlen: int | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pure-PyTorch reference implementation of the M2RNN recurrence.

        Supports both the padded layout (``q`` is ``(B, S, Nq, K)``,
        ``cu_seqlens is None``) and the packed variable-length layout
        (``q`` is ``(T, Nq, K)`` with ``cu_seqlens``/``max_seqlen`` given).

        Returns the per-step output ``y`` (shape ``(B, S, N, V)`` or
        ``(T, N, V)``) and the final state ``h0`` of shape ``(B, N, K, V)``.
        Gradients flowing through the state are clipped via
        ``clip_gradients`` when ``gradient_clipping`` is not None (forward
        values are unaffected).
        """
        Nq, Nk, Nv, Nw, Nxf, N = _get_num_heads(q=q, k=k, v=v, W=W, xf=xf, run_check=False)

        V = v.size(-1)

        if cu_seqlens is None:
            B, S, _, K = q.size()
            y = torch.empty(B, S, N, K, V, device=q.device, dtype=q.dtype)
        else:
            B = cu_seqlens.size(0) - 1
            # max_seqlen may arrive as a 0-dim tensor; unwrap to a python int
            S = max_seqlen.item() if isinstance(max_seqlen, torch.Tensor) else max_seqlen
            T, _, K = q.size()

            y = torch.empty(T, N, K, V, device=q.device, dtype=q.dtype)

        # None starting state means zero state
        if h0 is None:
            h0 = torch.zeros(B, N, K, V, device=k.device, dtype=k.dtype)

        # GQA-style broadcasting: each input may have fewer heads than N;
        # G* is the repeat factor needed to expand it to N heads.
        Gq = N // Nq
        Gk = N // Nk
        Gv = N // Nv

        Gw = N // Nw
        Gxf = N // Nxf

        q = q.repeat_interleave(Gq, dim=-2)
        k = k.repeat_interleave(Gk, dim=-2)
        v = v.repeat_interleave(Gv, dim=-2)
        W = W.repeat_interleave(Gw, dim=0)
        xf = xf.repeat_interleave(Gxf, dim=-1)

        # (B, S, N, K, V) = (B, S, N, K, 1) * (B, S, N, 1, V)
        x = k[..., None] * v[..., None, :]
        W = W[None, ...]

        if cu_seqlens is not None:
            # clone: the packed path updates h0 in place per step
            h0 = h0.clone()
            start = cu_seqlens[:-1]
            end = cu_seqlens[1:]

        for s in range(S):
            if cu_seqlens is None:
                f = xf[:, s, :, None, None]
                # (B, N, K, V) = (B, N, K, V) @ (1, N, V, V) + (B, N, K, V)
                h = h0 @ W + x[:, s]
            else:
                # packed layout: step s of sequence b lives at token
                # cu_seqlens[b] + s; mask out sequences shorter than s.
                offset = start + s
                unfinished = offset < end
                offset_unfinished = offset[unfinished]

                f = xf[offset_unfinished, :, None, None]
                # (B, N, K, V) = (B, N, K, V) @ (1, N, V, V) + (B, N, K, V)
                h = h0[unfinished] @ W + x[offset_unfinished]

            h = tanh(h)

            # gated update: forget-gate f interpolates old state and candidate
            if cu_seqlens is None:
                h = f * h0 + (1 - f) * h
            else:
                h = f * h0[unfinished] + (1 - f) * h

            # clips only the backward gradient of the state; forward is identity
            h = clip_gradients(h, gradient_clipping)

            if cu_seqlens is None:
                y[:, s] = h
                h0 = h
            else:
                y[offset_unfinished] = h
                h0[unfinished] = h

        # project the per-step states with the queries:
        # (..., 1, K) @ (..., K, V) -> (..., 1, V), then drop the unit dim
        y = q[..., None, :] @ y
        y = y.squeeze(-2)

        return y, h0

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        W: torch.Tensor,
        xf: torch.Tensor,
        h0: torch.Tensor | None,
        gradient_clipping: float | None,
        cu_seqlens: torch.Tensor | None,
        max_seqlen: int | None,
        kernel_backend: KernelBackend,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fused forward pass; returns (output, final_state).

        Intermediate per-step states are NOT saved here; ``backward``
        recomputes them (memory over compute trade-off).
        """
        # NOTE(review): the assert admits KernelBackend.cuda, but only the
        # triton kernel is invoked below — confirm cuda dispatch is intended
        # to go through the same kernel.
        assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]

        Nq, Nk, Nv, Nw, Nxf, N = _get_num_heads(q=q, k=k, v=v, W=W, xf=xf, run_check=False)

        if cu_seqlens is None:
            B = k.size(0)
        else:
            B = cu_seqlens.size(0) - 1

        K = k.size(-1)
        V = v.size(-1)

        # final state, filled by the kernel
        ht = torch.empty(B, N, K, V, device=k.device, dtype=k.dtype)

        # output has value's layout but with the full head count N
        y_shape = list(v.size())
        y_shape[-2] = N
        y = torch.empty(y_shape, device=q.device, dtype=q.dtype)

        m2rnn_forward_triton(
            q=q,
            k=k,
            v=v,
            W=W,
            xf=xf,
            h0=h0,
            h=None,
            ht=ht,
            y=y,
            cu_seqlens=cu_seqlens,
            Nq=Nq,
            Nk=Nk,
            Nv=Nv,
            Nw=Nw,
            Nxf=Nxf,
            N=N,
        )

        ctx_save_for_backward(ctx, q, k, v, W, xf, h0, cu_seqlens)

        ctx.gradient_clipping = gradient_clipping
        ctx.num_heads = Nq, Nk, Nv, Nw, Nxf, N

        y = y.type_as(v)

        return y, ht

    @staticmethod
    def backward(ctx, dy: torch.Tensor, dht: torch.Tensor) -> tuple[torch.Tensor | None]:
        """Backward pass: recompute all per-step states, then run the
        backward kernel to fill input gradients.

        Returns one gradient per ``forward`` input; the trailing four Nones
        cover gradient_clipping, cu_seqlens, max_seqlen, kernel_backend.
        """
        q, k, v, W, xf, h0, cu_seqlens = ctx.saved_tensors
        Nq, Nk, Nv, Nw, Nxf, N = ctx.num_heads

        V = v.size(-1)

        # buffer for the recomputed per-step hidden states
        if cu_seqlens is None:
            B, S, _, K = q.size()
            h = torch.empty(B, S, N, K, V, dtype=q.dtype, device=q.device)
        else:
            T, _, K = q.size()
            h = torch.empty(T, N, K, V, dtype=q.dtype, device=q.device)

        # forward replay (q=None / y=None): only materializes h
        m2rnn_forward_triton(
            q=None,
            k=k,
            v=v,
            W=W,
            xf=xf,
            h0=h0,
            h=h,
            ht=None,
            y=None,
            cu_seqlens=cu_seqlens,
            Nq=Nq,
            Nk=Nk,
            Nv=Nv,
            Nw=Nw,
            Nxf=Nxf,
            N=N,
        )

        function = partial(zeros_like_contiguous, dtype=torch.float32)

        # when an input has fewer heads than N (shared heads), its gradient
        # is accumulated across the group, so start from zeros in fp32;
        # otherwise the kernel overwrites every element and empty suffices.
        dq = (empty_like_contiguous if Nq == N else function)(q)
        dk = (empty_like_contiguous if Nk == N else function)(k)
        dv = (empty_like_contiguous if Nv == N else function)(v)
        dW = zeros_like_contiguous(W, dtype=torch.float32)
        dxf = (empty_like_contiguous if Nxf == N else function)(xf)
        dh0 = empty_like_contiguous(h0) if h0 is not None and h0.requires_grad else None

        m2rnn_backward_triton(
            q=q,
            k=k,
            v=v,
            W=W,
            xf=xf,
            h0=h0,
            dy=dy,
            h=h,
            dq=dq,
            dk=dk,
            dv=dv,
            dW=dW,
            dxf=dxf,
            dh0=dh0,
            cu_seqlens=cu_seqlens,
            gradient_clipping=ctx.gradient_clipping,
        )

        # cast accumulated fp32 gradients back to the input dtypes
        dq = dq.type_as(q)
        dk = dk.type_as(k)
        dv = dv.type_as(v)
        dW = dW.type_as(W)
        dxf = dxf.type_as(xf)

        return dq, dk, dv, dW, dxf, dh0, *[None] * 4


def m2rnn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    weight: torch.Tensor,
    forget_input: torch.Tensor,
    input_state: torch.Tensor | None = None,
    gradient_clipping: float | None = None,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | int | None = None,
    *,
    kernel_backend: KernelBackend | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    computes M2RNN recurrence

    :param query: query tensor of shape (B, S, Nq, K) where Nq is the number of query heads and K is the key head
        dimension. Should have shape (T, Nq, K) and `cu_seqlens` should be passed.
    :type query: torch.Tensor
    :param key: key tensor of shape (B, S, Nk, K) where Nk is the number of key heads and K is the key head
        dimension. Should have shape (T, Nk, K) and `cu_seqlens` should be passed.
    :type key: torch.Tensor
    :param value: value tensor of shape (B, S, Nv, V) where Nv is the number of value heads and V is the value head
        dimension. Should have shape (T, Nv, V) and `cu_seqlens` should be passed.
    :type value: torch.Tensor
    :param weight: weight tensor of shape (Nw, V, V)
    :type weight: torch.Tensor
    :param forget_input: forget input tensor of shape (B, S, Nxf) where Nxf is the number of forget heads. Should
        have shape (T, Nxf) and `cu_seqlens` should be passed.
    :type forget_input: torch.Tensor
    :param input_state: starting state of shape (B, N, K, V), where N = max{Nq, Nk, Nv, Nxf, Nw}. None means
        starting state is 0 tensor. Defaults to None.
    :type input_state: torch.Tensor | None
    :param gradient_clipping: gradient clipping for the state gradient in backward, None implies no clipping.
        Defaults to None.
    :type gradient_clipping: float | None
    :param cu_seqlens: cumulative sequence length (must contain 0 as first element). Defaults to None.
    :type cu_seqlens: torch.Tensor | None
    :param max_seqlen: max sequence length in the batch. Defaults to None.
    :type max_seqlen: torch.Tensor | int | None
    :param kernel_backend: KernelBackend
    :type kernel_backend: KernelBackend | None
    :return: output tensor of shape (B, S, N, V) if `cu_seqlens` is None else (T, N, V) and output state of shape
        (B, N, K, V).
    :rtype: tuple[Tensor, Tensor]
    """

    # padded layout (B, S, ...) vs packed varlen layout (T, ...):
    # max_seqlen is required iff cu_seqlens is passed
    if cu_seqlens is None:
        assert max_seqlen is None
        B, S, _, K = query.size()
    else:
        assert max_seqlen is not None
        assert cu_seqlens.dim() == 1
        B = cu_seqlens.size(0) - 1
        T, _, K = query.size()

    V = value.size(-1)

    Nq, Nk, Nv, Nw, Nxf, N = _get_num_heads(q=query, k=key, v=value, W=weight, xf=forget_input, run_check=True)

    # shape validation for every input in the chosen layout
    if cu_seqlens is None:
        assert query.size() == (B, S, Nq, K)
        assert key.size() == (B, S, Nk, K)
        assert value.size() == (B, S, Nv, V)
        assert forget_input.size() == (B, S, Nxf)
    else:
        assert query.size() == (T, Nq, K)
        assert key.size() == (T, Nk, K)
        assert value.size() == (T, Nv, V)
        assert forget_input.size() == (T, Nxf)

    assert weight.size() == (Nw, V, V)

    if input_state is not None:
        assert input_state.size() == (B, N, K, V)

    # clipping is symmetric; normalize a negative threshold to its magnitude
    if gradient_clipping is not None and gradient_clipping < 0:
        gradient_clipping = -gradient_clipping

    output, input_state = _M2RNN.run(
        q=query,
        k=key,
        v=value,
        W=weight,
        xf=forget_input,
        h0=input_state,
        gradient_clipping=gradient_clipping,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        kernel_backend=kernel_backend,
    )

    return output, input_state