xma.layers.m2rnn.op
- m2rnn(query: Tensor, key: Tensor, value: Tensor, weight: Tensor, forget_input: Tensor, input_state: Tensor | None = None, gradient_clipping: float | None = None, cu_seqlens: Tensor | None = None, max_seqlen: Tensor | int | None = None, *, kernel_backend: KernelBackend | None = None) → tuple[Tensor, Tensor]
Computes the M2RNN recurrence.
- Parameters:
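query (torch.Tensor) – query tensor of shape (B, S, Nq, K), where Nq is the number of query heads and K is the key head dimension. For packed variable-length sequences, it should instead have shape (T, Nq, K), where T is the total number of tokens across all sequences, and cu_seqlens must be passed.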
key (torch.Tensor) – key tensor of shape (B, S, Nk, K), where Nk is the number of key heads and K is the key head dimension. For packed variable-length sequences, it should instead have shape (T, Nk, K) and cu_seqlens must be passed.
value (torch.Tensor) – value tensor of shape (B, S, Nv, V), where Nv is the number of value heads and V is the value head dimension. For packed variable-length sequences, it should instead have shape (T, Nv, V) and cu_seqlens must be passed.
weight (torch.Tensor) – weight tensor of shape (Nw, V, V), where Nw is the number of weight heads.
forget_input (torch.Tensor) – forget input tensor of shape (B, S, Nxf), where Nxf is the number of forget heads. For packed variable-length sequences, it should instead have shape (T, Nxf) and cu_seqlens must be passed.
input_state (torch.Tensor | None) – starting state of shape (B, N, K, V), where N = max{Nq, Nk, Nv, Nxf, Nw}. None means the starting state is a zero tensor. Defaults to None.
gradient_clipping (float | None) – clipping value for the state gradient in the backward pass. None means no clipping. Defaults to None.
cu_seqlens (torch.Tensor | None) – cumulative sequence lengths (must contain 0 as the first element). Defaults to None.
max_seqlen (torch.Tensor | int | None) – maximum sequence length in the batch. Defaults to None.
kernel_backend (KernelBackend | None) – kernel backend to use for the computation. Defaults to None.
- Returns:
Output tensor of shape (B, S, N, V) if cu_seqlens is None, else (T, N, V), and output state of shape (B, N, K, V).
- Return type:
tuple[Tensor, Tensor]
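
The shape bookkeeping described above can be sketched in plain Python. This is only an illustration, not part of the library: `packed_layout` and `expected_shapes` are hypothetical helper names, and the sketch assumes that in the packed case B is the number of sequences and T is the sum of their lengths.

```python
from itertools import accumulate


def packed_layout(seq_lengths):
    """Return (cu_seqlens, max_seqlen, T) for sequences packed back to back.

    Hypothetical helper, not part of xma.
    """
    # Cumulative lengths with the required leading 0.
    cu_seqlens = [0] + list(accumulate(seq_lengths))
    max_seqlen = max(seq_lengths)
    total_tokens = cu_seqlens[-1]  # T: assumed to be the total token count
    return cu_seqlens, max_seqlen, total_tokens


def expected_shapes(B, Nq, Nk, Nv, Nxf, Nw, K, V, S=None, T=None):
    """Return the documented (output, output_state) shapes.

    Pass S for padded inputs, or T for packed inputs (cu_seqlens given).
    Hypothetical helper, not part of xma.
    """
    # N is the largest head count among all inputs.
    N = max(Nq, Nk, Nv, Nxf, Nw)
    output = (B, S, N, V) if T is None else (T, N, V)
    state = (B, N, K, V)
    return output, state


cu_seqlens, max_seqlen, T = packed_layout([3, 5, 2])
print(cu_seqlens, max_seqlen, T)  # [0, 3, 8, 10] 5 10

# Packed case: B is assumed to be the number of sequences (3 here).
print(expected_shapes(B=3, Nq=4, Nk=1, Nv=1, Nxf=4, Nw=2, K=16, V=32, T=T))
# ((10, 4, 32), (3, 4, 16, 32))
```

In practice the `cu_seqlens` list would be converted to a tensor before being passed to `m2rnn`. The dtype and device it expects are not stated here, so check them against the implementation.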