# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
import torch
from ...accelerator import KernelBackend
from ...custom_op import CustomOp, ctx_needs_gradients, ctx_save_for_backward
from ...torch_utils import clip_gradients, sigmoid, tanh
from ...utils import empty_like_contiguous, is_triton_available, zeros_like_contiguous
from ..rnn import _get_backward_tensor
from .utils import _get_num_heads
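
# the fused Triton kernels are only imported when Triton is available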
if is_triton_available():
from .triton_implementation import gru_backward_triton, gru_forward_triton


class _GRU(CustomOp):
@staticmethod
def forward_backward_torch(
x: torch.Tensor,
W: torch.Tensor,
xf: torch.Tensor,
Wf: torch.Tensor,
xr: torch.Tensor,
Wr: torch.Tensor,
h0: torch.Tensor | None,
gradient_clipping: float | None,
cu_seqlens: torch.Tensor | None,
max_seqlen: int | None,
) -> torch.Tensor:
Nx, Nxf, Nxr, Nw, Nwf, Nwr, N = _get_num_heads(x=x, W=W, xf=xf, Wf=Wf, xr=xr, Wr=Wr, run_check=False)
y_shape = list(x.size())
y_shape[-2] = N
y = torch.empty(y_shape, device=x.device, dtype=x.dtype)
if cu_seqlens is None:
B, S, _, H = x.size()
else:
B = cu_seqlens.size(0) - 1
S = max_seqlen.item() if isinstance(max_seqlen, torch.Tensor) else max_seqlen
H = x.size(-1)
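        # broadcast grouped input and weight heads up to the N output heads via
        # repeat_interleave so the recurrence below can treat every head independently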
Gx = N // Nx
Gxf = N // Nxf
Gxr = N // Nxr
Gw = N // Nw
Gwf = N // Nwf
Gwr = N // Nwr
x = x.repeat_interleave(Gx, dim=-2)
xf = xf.repeat_interleave(Gxf, dim=-2)
xr = xr.repeat_interleave(Gxr, dim=-2)
W = W.repeat_interleave(Gw, dim=0)[None, ...]
Wf = Wf.repeat_interleave(Gwf, dim=0)[None, ...]
Wr = Wr.repeat_interleave(Gwr, dim=0)[None, ...]
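        # default initial state is zeros; for varlen batches the state is updated
        # in place per sequence below, so work on a copy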
if h0 is None:
h0 = torch.zeros(B, N, H, device=x.device, dtype=x.dtype)
if cu_seqlens is not None:
h0 = h0.clone()
start = cu_seqlens[:-1]
end = cu_seqlens[1:]
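        # reference (unfused) GRU recurrence:
        #   f_t = sigmoid(h_{t-1} @ Wf + xf_t)        forget gate
        #   r_t = sigmoid(h_{t-1} @ Wr + xr_t)        reset gate
        #   z_t = tanh((h_{t-1} * r_t) @ W + x_t)     candidate state
        #   h_t = f_t * h_{t-1} + (1 - f_t) * z_t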
for s in range(S):
if cu_seqlens is None:
f = h0[..., None, :] @ Wf + xf[:, s, :, None, :]
r = h0[..., None, :] @ Wr + xr[:, s, :, None, :]
else:
offset = start + s
unfinished = offset < end
offset_unfinished = offset[unfinished]
f = h0[unfinished, :, None, :] @ Wf + xf[offset_unfinished, :, None, :]
r = h0[unfinished, :, None, :] @ Wr + xr[offset_unfinished, :, None, :]
f = sigmoid(f)
r = sigmoid(r)
if cu_seqlens is None:
z = (h0[..., None, :] * r) @ W + x[:, s, :, None, :]
else:
z = (h0[unfinished, :, None, :] * r) @ W + x[offset_unfinished, :, None, :]
z = tanh(z)
if cu_seqlens is None:
h = f * h0[..., None, :] + (1 - f) * z
else:
h = f * h0[unfinished, :, None, :] + (1 - f) * z
h = h.squeeze(-2)
h = clip_gradients(h, gradient_clipping)
if cu_seqlens is None:
y[:, s] = h
h0 = h
else:
y[offset_unfinished] = h
h0[unfinished] = h
return y

    @staticmethod
def forward(
ctx,
x: torch.Tensor,
W: torch.Tensor,
xf: torch.Tensor,
Wf: torch.Tensor,
xr: torch.Tensor,
Wr: torch.Tensor,
h0: torch.Tensor | None,
gradient_clipping: float | None,
cu_seqlens: torch.Tensor | None,
max_seqlen: int | None,
kernel_backend: KernelBackend,
) -> torch.Tensor:
assert kernel_backend in [KernelBackend.cuda, KernelBackend.triton]
Nx, Nxf, Nxr, _, _, _, N = _get_num_heads(x=x, W=W, xf=xf, Wf=Wf, xr=xr, Wr=Wr, run_check=False)
y_shape = list(x.size())
y_shape[-2] = N
needs_grad = ctx_needs_gradients(ctx)
y = torch.empty(y_shape, device=x.device, dtype=x.dtype)
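        # gate activations are cached for backward only when the corresponding head
        # count already equals N; otherwise the raw inputs are saved instead
        # (see ctx_save_for_backward below)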
f = torch.empty(y_shape, device=x.device, dtype=x.dtype) if needs_grad and Nxf == N else None
r = torch.empty(y_shape, device=x.device, dtype=x.dtype) if needs_grad and Nxr == N else None
z = torch.empty(y_shape, device=x.device, dtype=x.dtype) if needs_grad and Nx == N else None
gru_forward_triton(
x=x,
W=W,
xf=xf,
Wf=Wf,
f=f,
xr=xr,
Wr=Wr,
r=r,
z=z,
h0=h0,
y=y,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
ctx_save_for_backward(
ctx,
W,
Wf,
f,
Wr,
r,
z,
y,
h0,
cu_seqlens,
x if z is None else None,
xf if f is None else None,
xr if r is None else None,
)
ctx.max_seqlen = max_seqlen
ctx.gradient_clipping = gradient_clipping
ctx.num_heads = Nx, Nxf, Nxr
return y

    @staticmethod
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor | None, ...]:
W, Wf, f, Wr, r, z, y, h0, cu_seqlens, x, xf, xr = ctx.saved_tensors
Nx, Nxf, Nxr = ctx.num_heads
dx = _get_backward_tensor(y=y, Nx=Nx, N=y.size(-2))
dxf = _get_backward_tensor(y=y, Nx=Nxf, N=y.size(-2))
dxr = _get_backward_tensor(y=y, Nx=Nxr, N=y.size(-2))
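        # weight gradients are accumulated in float32 for numerical stability and
        # cast back to the parameter dtype at the end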
dW = zeros_like_contiguous(W, dtype=torch.float32)
dWf = zeros_like_contiguous(Wf, dtype=torch.float32)
dWr = zeros_like_contiguous(Wr, dtype=torch.float32)
dh0 = empty_like_contiguous(h0) if h0 is not None and h0.requires_grad else None
gru_backward_triton(
x=x,
W=W,
y=y,
xf=xf,
Wf=Wf,
f=f,
dxf=dxf,
dWf=dWf,
xr=xr,
Wr=Wr,
r=r,
dxr=dxr,
dWr=dWr,
z=z,
h0=h0,
dy=dy,
dx=dx,
dW=dW,
dh0=dh0,
cu_seqlens=cu_seqlens,
max_seqlen=ctx.max_seqlen,
gradient_clipping=ctx.gradient_clipping,
)
dx = dx.type_as(y)
dxf = dxf.type_as(y)
dxr = dxr.type_as(y)
dW = dW.type_as(W)
dWf = dWf.type_as(Wf)
dWr = dWr.type_as(Wr)
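        # no gradients for gradient_clipping, cu_seqlens, max_seqlen and kernel_backend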
return dx, dW, dxf, dWf, dxr, dWr, dh0, *[None] * 4


def gru(
input: torch.Tensor,
weight: torch.Tensor,
forget_input: torch.Tensor,
forget_weight: torch.Tensor,
reset_input: torch.Tensor,
reset_weight: torch.Tensor,
input_state: torch.Tensor | None = None,
gradient_clipping: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: int | None = None,
*,
kernel_backend: KernelBackend | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
    computes a multihead GRU:
    `f = sigmoid(input_state @ forget_weight + forget_input)`,
    `r = sigmoid(input_state @ reset_weight + reset_input)`,
    `z = tanh((input_state * r) @ weight + input)` and
    `output_state = f * input_state + (1 - f) * z`
    :param input: input tensor of shape (B, S, Nx, H) where Nx is the number of input heads and H is the head
        dimension. If `cu_seqlens` is passed, the shape should be (T, Nx, H) instead, where T is the total
        number of tokens in the batch.
:type input: torch.Tensor
:param weight: weight tensor of shape (Nw, H, H)
:type weight: torch.Tensor
    :param forget_input: forget input tensor of shape (B, S, Nxf, H) where Nxf is the number of forget input
        heads and H is the head dimension. If `cu_seqlens` is passed, the shape should be (T, Nxf, H) instead.
:type forget_input: torch.Tensor
    :param forget_weight: forget weight tensor of shape (Nwf, H, H)
:type forget_weight: torch.Tensor
    :param reset_input: reset input tensor of shape (B, S, Nxr, H) where Nxr is the number of reset input heads
        and H is the head dimension. If `cu_seqlens` is passed, the shape should be (T, Nxr, H) instead.
:type reset_input: torch.Tensor
:param reset_weight: reset weight tensor of shape (Nwr, H, H)
:type reset_weight: torch.Tensor
    :param input_state: starting state of shape (B, N, H), where N is the number of output heads. None means
        the starting state is a zero tensor. Defaults to None.
:type input_state: torch.Tensor | None
    :param gradient_clipping: gradient clipping for the state gradient in the backward pass; None implies no
        clipping. Defaults to None.
:type gradient_clipping: float | None
    :param cu_seqlens: cumulative sequence lengths (must contain 0 as the first element). Defaults to None.
:type cu_seqlens: torch.Tensor | None
:param max_seqlen: max sequence length in the batch. Defaults to None.
:type max_seqlen: int | None
    :param kernel_backend: kernel backend to use for the computation. Defaults to None.
    :type kernel_backend: KernelBackend | None
:return: output tensor of shape (B, S, N, H) if `cu_seqlens` is None else (T, N, H) and output state of
shape (B, N, H).
:rtype: tuple[Tensor, Tensor]
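
    Example (a minimal usage sketch, not taken from the library's docs; it assumes a CUDA device, the
    Triton backend, and that `gru` and `KernelBackend` are imported from this library; all shapes are
    illustrative)::

        >>> import torch
        >>> B, S, N, H = 2, 16, 4, 64
        >>> x = torch.randn(B, S, N, H, device="cuda", requires_grad=True)
        >>> xf = torch.randn(B, S, N, H, device="cuda", requires_grad=True)
        >>> xr = torch.randn(B, S, N, H, device="cuda", requires_grad=True)
        >>> W = torch.randn(N, H, H, device="cuda", requires_grad=True)
        >>> Wf = torch.randn(N, H, H, device="cuda", requires_grad=True)
        >>> Wr = torch.randn(N, H, H, device="cuda", requires_grad=True)
        >>> y, state = gru(x, W, xf, Wf, xr, Wr, kernel_backend=KernelBackend.triton)
        >>> y.shape, state.shape
        (torch.Size([2, 16, 4, 64]), torch.Size([2, 4, 64]))
        >>> y.sum().backward()  # gradients flow to all inputs and weights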
"""
expected_dim = 3 + (cu_seqlens is None)
assert input.dim() == expected_dim
assert forget_input.dim() == expected_dim
assert reset_input.dim() == expected_dim
if cu_seqlens is None:
assert max_seqlen is None
B, _, _, H = input.size()
else:
assert max_seqlen is not None
assert cu_seqlens.dim() == 1
B = cu_seqlens.size(0) - 1
H = input.size(-1)
_, _, _, Nw, Nwf, Nwr, N = _get_num_heads(
x=input, W=weight, xf=forget_input, Wf=forget_weight, xr=reset_input, Wr=reset_weight, run_check=True
)
assert weight.size() == (Nw, H, H)
assert forget_weight.size() == (Nwf, H, H)
assert reset_weight.size() == (Nwr, H, H)
if input_state is not None:
assert input_state.size() == (B, N, H)
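    # the clipping threshold is treated as a magnitude, so a negative value is negated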
if gradient_clipping is not None and gradient_clipping < 0:
gradient_clipping = -gradient_clipping
input = _GRU.run(
x=input,
W=weight,
xf=forget_input,
Wf=forget_weight,
xr=reset_input,
Wr=reset_weight,
h0=input_state,
gradient_clipping=gradient_clipping,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
kernel_backend=kernel_backend,
)
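    # the output state is the hidden state at the last timestep (or at the last valid
    # token of each sequence when `cu_seqlens` is used)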
input_state = input[:, -1] if cu_seqlens is None else input[cu_seqlens[1:] - 1]
return input, input_state