xma.layers.m2rnn
- class M2RNN(input_size: int, key_head_dim: int, value_head_dim: int, output_size: int, num_query_heads: int, num_key_heads: int, num_value_heads: int, num_forget_input_heads: int, num_weight_heads: int, add_bias: bool, gradient_clipping: float | None)
Bases: Module
- forward(input: Tensor, input_state: Tensor | None = None, cu_seqlens: Tensor | None = None, max_seqlen: int | None = None, *, kernel_backend: KernelBackend | None = None) -> tuple[Tensor, Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
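The Note above is standard torch.nn.Module behaviour. A minimal illustration, using a hypothetical toy module `Doubler` that is not part of xma: a forward hook fires when the module instance is called, but not when forward() is invoked directly.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    # Hypothetical toy module, used only to illustrate hook behaviour.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x


calls = []
module = Doubler()
# A forward hook receives (module, inputs, output) after forward() has run.
module.register_forward_hook(lambda mod, inputs, output: calls.append(output.shape))

x = torch.ones(3)
module(x)          # goes through nn.Module.__call__, so the hook fires
module.forward(x)  # bypasses __call__, so the hook is silently skipped
print(len(calls))  # 1
```

The same applies to M2RNN: call the layer instance rather than its forward method so that any registered hooks run.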
- m2rnn(query: Tensor, key: Tensor, value: Tensor, weight: Tensor, forget_input: Tensor, input_state: Tensor | None = None, gradient_clipping: float | None = None, cu_seqlens: Tensor | None = None, max_seqlen: Tensor | int | None = None, *, kernel_backend: KernelBackend | None = None) -> tuple[Tensor, Tensor]
Computes the M2RNN recurrence.
- Parameters:
query (torch.Tensor) – query tensor of shape (B, S, Nq, K), where B is the batch size, S is the sequence length, Nq is the number of query heads and K is the key head dimension. When the sequences are packed into a single token dimension of length T, it should instead have shape (T, Nq, K) and cu_seqlens should be passed.
key (torch.Tensor) – key tensor of shape (B, S, Nk, K), where Nk is the number of key heads and K is the key head dimension. In the packed layout, it should instead have shape (T, Nk, K) and cu_seqlens should be passed.
value (torch.Tensor) – value tensor of shape (B, S, Nv, V), where Nv is the number of value heads and V is the value head dimension. In the packed layout, it should instead have shape (T, Nv, V) and cu_seqlens should be passed.
weight (torch.Tensor) – weight tensor of shape (Nw, V, V), where Nw is the number of weight heads.
forget_input (torch.Tensor) – forget input tensor of shape (B, S, Nxf), where Nxf is the number of forget input heads. In the packed layout, it should instead have shape (T, Nxf) and cu_seqlens should be passed.
input_state (torch.Tensor | None) – starting state of shape (B, N, K, V), where N = max{Nq, Nk, Nv, Nxf, Nw}. None means the starting state is a zero tensor. Defaults to None.
gradient_clipping (float | None) – clipping value applied to the state gradient during the backward pass; None disables clipping. Defaults to None.
cu_seqlens (torch.Tensor | None) – cumulative sequence lengths of the packed sequences; the first element must be 0. Defaults to None.
max_seqlen (torch.Tensor | int | None) – maximum sequence length in the batch. Defaults to None.
kernel_backend (KernelBackend | None) – kernel backend used to compute the recurrence. Defaults to None.
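The packed layout expected when cu_seqlens is passed can be built by concatenating per-sequence tensors along the token dimension. The sketch below uses a hypothetical helper, `pack_sequences`, that is not part of xma, and assumes int32 offsets for cu_seqlens; the documentation only requires a tensor whose first element is 0. The same cu_seqlens and max_seqlen would accompany the packed query, key, value and forget_input tensors.

```python
import torch


def pack_sequences(seqs: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Hypothetical helper: pack tensors of shape (S_i, ...) into one (T, ...) tensor.

    Returns the packed tensor, cu_seqlens (a leading 0 followed by running totals
    of the sequence lengths) and max_seqlen.
    """
    lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.int32)
    # Assumed int32 offsets; cu_seqlens[i]:cu_seqlens[i + 1] indexes sequence i.
    cu_seqlens = torch.zeros(len(seqs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    packed = torch.cat(seqs, dim=0)  # T = sum of all sequence lengths
    max_seqlen = int(lengths.max())
    return packed, cu_seqlens, max_seqlen


# Three query sequences of lengths 3, 5 and 2, with Nq = 2 heads and K = 4.
query_seqs = [torch.randn(length, 2, 4) for length in (3, 5, 2)]
packed, cu_seqlens, max_seqlen = pack_sequences(query_seqs)
print(packed.shape)         # torch.Size([10, 2, 4])
print(cu_seqlens.tolist())  # [0, 3, 8, 10]
print(max_seqlen)           # 5
```

Packing avoids padding short sequences up to S; the offsets in cu_seqlens are what let a kernel recover each sequence's boundaries.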
- Returns:
A tuple of the output tensor, of shape (B, S, N, V) if cu_seqlens is None and (T, N, V) otherwise, and the output state, of shape (B, N, K, V).
- Return type:
tuple[Tensor, Tensor]
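The shape rules above can be checked with a small pure-Python helper. `expected_m2rnn_shapes` is a hypothetical name, not part of xma, and it assumes that in the packed layout B is the number of sequences, len(cu_seqlens) - 1, which the documentation does not state explicitly.

```python
from typing import Optional, Sequence


def expected_m2rnn_shapes(
    B: int, S: int, K: int, V: int,
    Nq: int, Nk: int, Nv: int, Nxf: int, Nw: int,
    cu_seqlens: Optional[Sequence[int]] = None,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Hypothetical helper: return (output_shape, state_shape) for m2rnn."""
    # The output and state use N heads, the largest of all head counts.
    N = max(Nq, Nk, Nv, Nxf, Nw)
    if cu_seqlens is None:
        # Padded layout: B sequences of S timesteps each.
        return (B, S, N, V), (B, N, K, V)
    # Packed layout: T tokens in total; B and S are ignored here.
    # Assumption: one output state per packed sequence.
    T = cu_seqlens[-1]
    num_seqs = len(cu_seqlens) - 1
    return (T, N, V), (num_seqs, N, K, V)


print(expected_m2rnn_shapes(2, 16, 64, 64, Nq=4, Nk=4, Nv=8, Nxf=8, Nw=8))
# ((2, 16, 8, 64), (2, 8, 64, 64))
print(expected_m2rnn_shapes(2, 16, 64, 64, Nq=4, Nk=4, Nv=8, Nxf=8, Nw=8, cu_seqlens=[0, 3, 8, 10]))
# ((10, 8, 64), (3, 8, 64, 64))
```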