xma.functional.bmm

bmm(A: Tensor, B: Tensor, C: Tensor | None, is_A_transposed: bool = False, is_B_transposed: bool = False, alpha: float = 1, beta: float = 1, *, kernel_backend: KernelBackend | None = None) → Tensor

Computes alpha * (A @ B) + beta * C.
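A pure-PyTorch reference of the documented semantics can help when checking outputs in tests. The sketch below is an illustration, not xma's implementation. The name bmm_reference is hypothetical, and the sketch assumes that the matrix dimensions are the last two axes of each tensor, with any leading axes treated as a batch. It follows the parameter descriptions literally, so alpha is not applied when C is None.

```python
from __future__ import annotations

import torch


def bmm_reference(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor | None,
    is_A_transposed: bool = False,
    is_B_transposed: bool = False,
    alpha: float = 1,
    beta: float = 1,
) -> torch.Tensor:
    # Hypothetical reference, not part of xma.
    # Assumption: the matrix dims are the last two; leading dims are batch dims.
    if is_A_transposed:
        # A is given as K x M; swap back to M x K.
        A = A.transpose(-2, -1)
    if is_B_transposed:
        # B is given as N x K; swap back to K x N.
        B = B.transpose(-2, -1)

    product = A @ B
    if C is None:
        # As documented: without C the result is A @ B.
        return product
    return alpha * product + beta * C


torch.manual_seed(0)
A = torch.randn(2, 3, 4)  # batch of 2, M=3, K=4
B = torch.randn(2, 4, 5)  # K=4, N=5
C = torch.randn(2, 3, 5)  # M=3, N=5
out = bmm_reference(A, B, C, alpha=0.5, beta=2.0)
print(out.shape)  # torch.Size([2, 3, 5])
```

Passing A and B in their transposed layouts together with is_A_transposed=True and is_B_transposed=True should give the same result as the untransposed call.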

Parameters:
  • A (torch.Tensor) – Left operand, of shape M x K (K x M if is_A_transposed is True).

  • B (torch.Tensor) – Right operand, of shape K x N (N x K if is_B_transposed is True).

  • C (torch.Tensor | None) – Addend, expected to match the shape of A @ B (M x N). If C is None, the function returns A @ B.

  • is_A_transposed (bool) – Whether A is given transposed, that is, with shape K x M instead of M x K. Defaults to False.

  • is_B_transposed (bool) – Whether B is given transposed, that is, with shape N x K instead of K x N. Defaults to False.

  • alpha (float) – Scale factor applied to the product A @ B. Defaults to 1.

  • beta (float) – Scale factor applied to C. Defaults to 1.

  • kernel_backend (KernelBackend | None) – Kernel backend used to run the computation. Keyword-only. Defaults to None.
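The transposition flags only determine which axis of each operand is read as M, K, or N. The hypothetical helper below is not part of xma. It derives the output shape from the operand shapes and flags. It assumes the last two axes are the matrix axes and that any leading batch axes must match exactly, with no broadcasting.

```python
from __future__ import annotations


def bmm_output_shape(
    a_shape: tuple[int, ...],
    b_shape: tuple[int, ...],
    is_A_transposed: bool = False,
    is_B_transposed: bool = False,
) -> tuple[int, ...]:
    # Hypothetical helper; assumes the matrix dims are the last two axes.
    if is_A_transposed:
        K_a, M = a_shape[-2], a_shape[-1]  # A given as K x M
    else:
        M, K_a = a_shape[-2], a_shape[-1]  # A given as M x K
    if is_B_transposed:
        N, K_b = b_shape[-2], b_shape[-1]  # B given as N x K
    else:
        K_b, N = b_shape[-2], b_shape[-1]  # B given as K x N

    # The contraction dimension K must agree between A and B.
    if K_a != K_b:
        raise ValueError(f"inner dimensions differ: {K_a} vs {K_b}")
    # Assumption: batch dims must be identical (no broadcasting).
    if tuple(a_shape[:-2]) != tuple(b_shape[:-2]):
        raise ValueError(f"batch dimensions differ: {a_shape[:-2]} vs {b_shape[:-2]}")
    return (*a_shape[:-2], M, N)


# A given as K x M (K=64, M=32), B given as N x K (N=16, K=64).
print(bmm_output_shape((8, 64, 32), (8, 16, 64), is_A_transposed=True, is_B_transposed=True))
# (8, 32, 16)
```

Under these assumptions, the result is M x N for every combination of flags. Only the interpretation of the input shapes changes.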

Returns:

Output tensor of shape M x N holding alpha * (A @ B) + beta * C, or A @ B if C is None.

Return type:

Tensor