Source code for xma.functional.bmm

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch

from ...accelerator import Accelerator, KernelBackend
from ...utils import is_triton_available


if is_triton_available():
    from .triton_implementation import bmm_triton


def bmm(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor | None,
    is_A_transposed: bool = False,
    is_B_transposed: bool = False,
    alpha: float = 1,
    beta: float = 1,
    *,
    kernel_backend: KernelBackend | None = None,
) -> torch.Tensor:
    """
    computes `alpha * (A @ B) + beta * C`

    :param A: `A` matrix
    :type A: torch.Tensor
    :param B: `B` matrix
    :type B: torch.Tensor
    :param C: `C` matrix, function returns `alpha * (A @ B)` if `C` is None
    :type C: torch.Tensor | None
    :param is_A_transposed: whether `A` has shape K x M. Defaults to False.
    :type is_A_transposed: bool
    :param is_B_transposed: whether `B` has shape N x K. Defaults to False.
    :type is_B_transposed: bool
    :param alpha: alpha. Defaults to 1.
    :type alpha: float
    :param beta: beta. Defaults to 1.
    :type beta: float
    :param kernel_backend: KernelBackend
    :type kernel_backend: KernelBackend | None
    :return: output tensor
    :rtype: torch.Tensor
    """

    assert A.dim() == 3
    assert B.dim() == 3

    # infer problem sizes from A; swap M and K if A is stored transposed (K x M)
    L, M, K = A.size()
    if is_A_transposed:
        M, K = K, M

    assert B.size(2 if is_B_transposed else 1) == K
    N = B.size(1 if is_B_transposed else 2)

    # C is only used when beta != 0
    if beta == 0:
        assert C is None
    else:
        assert C is not None
        assert C.size() == (L, M, N)

    if kernel_backend is None:
        kernel_backend = Accelerator.get_kernel_backend()
    else:
        assert kernel_backend.verify_accelerator()

    if kernel_backend == KernelBackend.torch:
        # PyTorch fallback path
        if is_A_transposed:
            A = A.transpose(1, 2)
        if is_B_transposed:
            B = B.transpose(1, 2)

        if beta == 0:
            D = torch.bmm(A, B)
            if alpha != 1:
                D = alpha * D
        else:
            D = torch.baddbmm(C, A, B, alpha=alpha, beta=beta)
    elif kernel_backend in [KernelBackend.cuda, KernelBackend.triton]:
        # Triton kernel path: allocate output and launch the kernel
        D = torch.empty(L, M, N, dtype=A.dtype, device=A.device)

        bmm_triton(
            A=A,
            B=B,
            C=C,
            D=D,
            is_A_transposed=is_A_transposed,
            is_B_transposed=is_B_transposed,
            alpha=alpha,
            beta=beta,
        )
    else:
        raise ValueError(f"unexpected kernel_backend ({kernel_backend})")

    return D
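
A minimal usage sketch, assuming `bmm` is importable from `xma.functional` and a CUDA device is available; the shapes and scalar values below are illustrative only:

import torch

from xma.functional import bmm  # assumed import path for this function

# illustrative shapes: batch L=4, M=128, K=64, N=256
A = torch.randn(4, 128, 64, device="cuda")
B = torch.randn(4, 64, 256, device="cuda")
C = torch.randn(4, 128, 256, device="cuda")

# D = 2.0 * (A @ B) + 0.5 * C, dispatched to the configured kernel backend
D = bmm(A, B, C, alpha=2.0, beta=0.5)

# pure matmul path: beta == 0 requires C to be None
D_no_C = bmm(A, B, None, beta=0)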