xma.functional.gru

gru(input: Tensor, weight: Tensor, forget_input: Tensor, forget_weight: Tensor, reset_input: Tensor, reset_weight: Tensor, input_state: Tensor | None = None, gradient_clipping: float | None = None, cu_seqlens: Tensor | None = None, max_seqlen: Tensor | int | None = None, *, kernel_backend: KernelBackend | None = None) tuple[Tensor, Tensor][source]

computes a multihead GRU. The candidate state is tanh(input_state @ weight + input), and it is gated by forget and reset gates computed from the corresponding (forget_input, forget_weight) and (reset_input, reset_weight) pairs.
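
The recurrence can be sketched per time step in plain PyTorch. This is a hedged reference, not the library's implementation: the exact gating of `xma.functional.gru` is not spelled out above, so the sigmoid forget/reset gates and the `f * h + (1 - f) * c` state update below are assumptions following the standard GRU formulation, and `gru_reference` is a hypothetical helper name (shown for the simple case Nx == Nw == N, no cu_seqlens):

```python
import torch


def gru_reference(x, w, fx, fw, rx, rw, h0=None):
    """Assumed per-step GRU recurrence (reference sketch, not the kernel).

    x, fx, rx: (B, S, N, H) input / forget-input / reset-input tensors
    w, fw, rw: (N, H, H) weight / forget-weight / reset-weight tensors
    h0:        (B, N, H) starting state, or None for a zero starting state
    """
    B, S, N, H = x.shape
    h = torch.zeros(B, N, H, dtype=x.dtype) if h0 is None else h0
    outputs = []
    for t in range(S):
        # per-head matmul of the state with each (H, H) weight
        f = torch.sigmoid(torch.einsum("bnh,nhk->bnk", h, fw) + fx[:, t])  # forget gate
        r = torch.sigmoid(torch.einsum("bnh,nhk->bnk", h, rw) + rx[:, t])  # reset gate
        c = torch.tanh(torch.einsum("bnh,nhk->bnk", r * h, w) + x[:, t])   # candidate state
        h = f * h + (1 - f) * c  # assumed convex-combination state update
        outputs.append(h)
    return torch.stack(outputs, dim=1), h  # (B, S, N, H) outputs, (B, N, H) final state
```

A fused kernel would avoid materializing the Python-level loop; the sketch is only meant to pin down the tensor shapes involved.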

Parameters:
  • input (torch.Tensor) – input tensor of shape (B, S, Nx, H), where Nx is the number of input heads and H is the head dimension. Can instead have shape (T, Nx, H) for packed variable-length sequences, in which case cu_seqlens must be passed.

  • weight (torch.Tensor) – weight tensor of shape (Nw, H, H), where Nw is the number of weight heads

  • forget_input (torch.Tensor) – forget input tensor of shape (B, S, Nxf, H), where Nxf is the number of forget input heads and H is the head dimension. Can instead have shape (T, Nxf, H) for packed variable-length sequences, in which case cu_seqlens must be passed.

  • forget_weight (torch.Tensor) – forget weight tensor of shape (Nwf, H, H)

  • reset_input (torch.Tensor) – reset input tensor of shape (B, S, Nxr, H), where Nxr is the number of reset input heads and H is the head dimension. Can instead have shape (T, Nxr, H) for packed variable-length sequences, in which case cu_seqlens must be passed.

  • reset_weight (torch.Tensor) – reset weight tensor of shape (Nwr, H, H)

  • input_state (torch.Tensor | None) – starting state of shape (B, N, H), where N = max{Nx, Nw}. None means the starting state is a zero tensor. Defaults to None.

  • gradient_clipping (float | None) – clipping threshold for the state gradient in the backward pass; None means no clipping. Defaults to None.

  • cu_seqlens (torch.Tensor | None) – cumulative sequence lengths for packed variable-length input (must contain 0 as the first element). Defaults to None.

  • max_seqlen (torch.Tensor | int | None) – max sequence length in the batch; used together with cu_seqlens. Defaults to None.

  • kernel_backend (KernelBackend | None) – kernel backend to use for the computation. Defaults to None.
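
For the packed (T, …, H) layout, cu_seqlens can be built from the per-sequence lengths. A minimal sketch, where the seq_lens values are illustrative:

```python
import torch

# Three sequences of lengths 3, 5, and 2 packed along the first dimension (T = 10).
seq_lens = torch.tensor([3, 5, 2])

# cu_seqlens must start with 0, followed by the running total of lengths.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(dim=0)])
max_seqlen = int(seq_lens.max())

# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
```

Here cu_seqlens is tensor([0, 3, 8, 10]) and max_seqlen is 5.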

Returns:

output tensor of shape (B, S, N, H) (or (T, N, H) when cu_seqlens is passed) and output state of shape (B, N, H).

Return type:

tuple[Tensor, Tensor]