xma.functional.rnn
- rnn(input: Tensor, weight: Tensor, input_state: Tensor | None = None, gradient_clipping: float | None = None, cu_seqlens: Tensor | None = None, max_seqlen: Tensor | int | None = None, *, kernel_backend: KernelBackend | None = None) → tuple[Tensor, Tensor]
Computes the multihead RNN recurrent update over the sequence length: tanh(input_state @ weight + input).
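A minimal, unoptimized PyTorch sketch of this recurrence, assuming the padded layout and an equal number of input and weight heads (Nx == Nw == N); this is only an illustration of the semantics, not the library's kernel:

```python
import torch

def rnn_reference(input, weight, input_state=None):
    # input: (B, S, N, H), weight: (N, H, H), input_state: (B, N, H) or None
    B, S, N, H = input.shape
    state = input.new_zeros(B, N, H) if input_state is None else input_state
    outputs = []
    for s in range(S):
        # per-head recurrent update: tanh(state @ weight + input[:, s])
        state = torch.tanh(torch.einsum("bnh,nhk->bnk", state, weight) + input[:, s])
        outputs.append(state)
    return torch.stack(outputs, dim=1), state  # (B, S, N, H), (B, N, H)
```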
- Parameters:
input (torch.Tensor) – input tensor of shape (B, S, Nx, H), where Nx is the number of input heads and H is the head dimension. For unpadded (variable-length) sequences, the input should instead have shape (T, Nx, H) and cu_seqlens must be passed.
weight (torch.Tensor) – weight tensor of shape (Nw, H, H), where Nw is the number of weight heads.
input_state (torch.Tensor | None) – starting state of shape (B, N, H), where N = max{Nx, Nw}. None means the starting state is a zero tensor. Defaults to None.
gradient_clipping (float | None) – gradient clipping value applied to the state gradient in the backward pass; None means no clipping. Defaults to None.
cu_seqlens (torch.Tensor | None) – cumulative sequence lengths for unpadded input (the first element must be 0). Defaults to None.
max_seqlen (torch.Tensor | int | None) – max sequence length in the batch. Defaults to None.
kernel_backend (KernelBackend | None) – kernel backend to use for the computation. Defaults to None.
- Returns:
output tensor of shape (B, S, N, H) if cu_seqlens is None, otherwise (T, N, H), and the final output state of shape (B, N, H).
- Return type:
tuple[Tensor, Tensor]
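A usage sketch based on the signature above. The concrete shapes, the CUDA device placement, and the way cu_seqlens is built for the variable-length call are assumptions for illustration; the documentation only requires that cu_seqlens starts with 0.

```python
import torch
from xma.functional import rnn

B, S, N, H = 2, 16, 4, 64
x = torch.randn(B, S, N, H, device="cuda")
w = torch.randn(N, H, H, device="cuda")

# Padded (batched) call: output is (B, S, N, H), final state is (B, N, H).
out, state = rnn(x, w)

# Unpadded (variable-length) call: concatenate the sequences along the first
# dimension and describe their boundaries with cu_seqlens (first element 0).
x_varlen = torch.cat([x[0, :10], x[1, :16]], dim=0)      # (T, N, H) with T = 26
cu_seqlens = torch.tensor([0, 10, 26], device="cuda")    # cumulative sequence lengths
out_varlen, state_varlen = rnn(x_varlen, w, cu_seqlens=cu_seqlens, max_seqlen=16)
```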