xma.layers.m2rnn.op

m2rnn(query: Tensor, key: Tensor, value: Tensor, weight: Tensor, forget_input: Tensor, input_state: Tensor | None = None, gradient_clipping: float | None = None, cu_seqlens: Tensor | None = None, max_seqlen: Tensor | int | None = None, *, kernel_backend: KernelBackend | None = None) → tuple[Tensor, Tensor]

Computes the M2RNN recurrence.

Parameters:
  • query (torch.Tensor) – query tensor of shape (B, S, Nq, K), where Nq is the number of query heads and K is the key head dimension. For packed variable-length input, use shape (T, Nq, K) and pass cu_seqlens.

  • key (torch.Tensor) – key tensor of shape (B, S, Nk, K), where Nk is the number of key heads and K is the key head dimension. For packed variable-length input, use shape (T, Nk, K) and pass cu_seqlens.

  • value (torch.Tensor) – value tensor of shape (B, S, Nv, V), where Nv is the number of value heads and V is the value head dimension. For packed variable-length input, use shape (T, Nv, V) and pass cu_seqlens.

  • weight (torch.Tensor) – weight tensor of shape (Nw, V, V)

  • forget_input (torch.Tensor) – forget input tensor of shape (B, S, Nxf), where Nxf is the number of forget heads. For packed variable-length input, use shape (T, Nxf) and pass cu_seqlens.

  • input_state (torch.Tensor | None) – starting state of shape (B, N, K, V), where N = max{Nq, Nk, Nv, Nxf, Nw}. None means the starting state is a zero tensor. Defaults to None.

  • gradient_clipping (float | None) – clipping threshold applied to the state gradient in the backward pass; None means no clipping. Defaults to None.

  • cu_seqlens (torch.Tensor | None) – cumulative sequence lengths for packed variable-length input (the first element must be 0). Defaults to None.

  • max_seqlen (torch.Tensor | int | None) – maximum sequence length in the batch. Defaults to None.

  • kernel_backend (KernelBackend | None) – kernel backend to use; None selects the default backend. Defaults to None.

Returns:

output tensor of shape (B, S, N, V) if cu_seqlens is None, else (T, N, V), together with the output state of shape (B, N, K, V).

Return type:

tuple[Tensor, Tensor]
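
For packed variable-length input, cu_seqlens is the running total of the per-sequence lengths with a leading 0, so its last element is T, the first dimension of the packed (T, …) tensors. A minimal sketch of constructing these values (the helper name build_cu_seqlens is illustrative, not part of the API):

```python
from itertools import accumulate

def build_cu_seqlens(seqlens: list[int]) -> list[int]:
    # Cumulative sequence lengths with a leading 0, as expected
    # by the cu_seqlens argument (wrap in a tensor before passing).
    return [0, *accumulate(seqlens)]

seqlens = [3, 5, 2]                # three sequences in the packed batch
cu_seqlens = build_cu_seqlens(seqlens)
total_tokens = cu_seqlens[-1]      # T: first dim of the packed tensors
max_seqlen = max(seqlens)          # value to pass as max_seqlen

print(cu_seqlens, total_tokens, max_seqlen)  # [0, 3, 8, 10] 10 5
```

With these values, query/key/value/forget_input would be passed in their packed (T, …) layouts alongside cu_seqlens and max_seqlen.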