Source code for xma.layers.linear_attention.module

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch
import torch.nn as nn

from ...accelerator import KernelBackend
from ...math import divide_if_divisible
from .op import linear_attention


class LinearAttention(nn.Module):
    def __init__(
        self,
        input_size: int,
        key_head_dim: int,
        value_head_dim: int,
        output_size: int,
        num_query_heads: int,
        num_key_heads: int,
        num_value_heads: int,
        add_bias: bool,
    ) -> None:
        super().__init__()

        self.key_head_dim = key_head_dim
        self.value_head_dim = value_head_dim
        self.num_query_heads = num_query_heads
        self.num_key_heads = num_key_heads
        self.num_value_heads = num_value_heads

        # the effective number of heads is the largest of the three counts;
        # each of the query/key/value head counts must divide it evenly
        self.num_heads = max(num_query_heads, num_key_heads, num_value_heads)

        divide_if_divisible(self.num_heads, self.num_query_heads)
        divide_if_divisible(self.num_heads, self.num_key_heads)
        divide_if_divisible(self.num_heads, self.num_value_heads)

        self.query_size = self.num_query_heads * self.key_head_dim
        self.key_size = self.num_key_heads * self.key_head_dim
        self.value_size = self.num_value_heads * self.value_head_dim
        self.state_size = self.num_heads * self.key_head_dim * self.value_head_dim

        # fused QKV projection and the projection back to output_size
        self.input_projection = nn.Linear(input_size, self.query_size + self.key_size + self.value_size, bias=add_bias)
        self.output_projection = nn.Linear(self.num_heads * self.value_head_dim, output_size, bias=add_bias)
    def forward(
        self,
        input: torch.Tensor,
        input_state: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: int | None = None,
        *,
        kernel_backend: KernelBackend | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # fused projection followed by a split into query, key and value
        input = self.input_projection(input)
        query, key, value = input.split((self.query_size, self.key_size, self.value_size), dim=-1)

        # reshape the last dimension into (num_heads, head_dim)
        query = query.view(*query.size()[:-1], self.num_query_heads, self.key_head_dim)
        key = key.view(*key.size()[:-1], self.num_key_heads, self.key_head_dim)
        value = value.view(*value.size()[:-1], self.num_value_heads, self.value_head_dim)

        if input_state is not None:
            # unflatten the recurrent state into (num_heads, key_head_dim, value_head_dim)
            input_state = input_state.view(-1, self.num_heads, self.key_head_dim, self.value_head_dim)

        input, input_state = linear_attention(
            query=query,
            key=key,
            value=value,
            input_state=input_state,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            kernel_backend=kernel_backend,
        )

        # merge the head and head-dim axes before the output projection
        input = input.flatten(-2, -1)
        input_state = input_state.flatten(-2, -1)

        input = self.output_projection(input)

        return input, input_state
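

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module). The sizes below are arbitrary examples; it also assumes equal
# query/key/value head counts and that the default kernel backend chosen by
# linear_attention is usable on the current device.
if __name__ == "__main__":
    attention = LinearAttention(
        input_size=256,
        key_head_dim=32,
        value_head_dim=32,
        output_size=256,
        num_query_heads=8,
        num_key_heads=8,
        num_value_heads=8,
        add_bias=False,
    )

    # (batch, sequence_length, input_size)
    x = torch.randn(2, 16, 256)

    output, state = attention(x)
    # output has shape (2, 16, 256); state holds the recurrent linear-attention
    # state, which can presumably be passed back as input_state when processing
    # the next chunk of the same sequences.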