# Source code for xma.functional.sequence_packing

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from typing import Sequence

import torch

from ...accelerator import KernelBackend
from ...custom_op import CustomOp, ctx_save_for_backward
from ...utils import is_triton_available
from .cuda_implementation import pack_unpack_sequence_cuda


if is_triton_available():
    from .triton_implementation import pack_unpack_sequence_triton


class _PackSequence(CustomOp):
    """Autograd op that gathers the non-padding tokens of a padded (B, S, ...) batch
    into a single packed (total_tokens, ...) tensor, using CUDA/Triton kernels."""

    @staticmethod
    def forward_backward_torch(
        x: torch.Tensor, cu_seqlens: torch.Tensor, output_shape: tuple[int], padding_side: str
    ) -> torch.Tensor:
        """Pure-PyTorch reference implementation of packing.

        :param x: padded tensor; first two dims are (batch, sequence)
        :param cu_seqlens: cumulative sequence lengths, shape (B + 1,)
        :param output_shape: unused here; kept for a uniform signature with `forward`
        :param padding_side: "left" or "right"
        :return: packed tensor containing only the real tokens
        """
        num_batches, max_seqlen = x.size()[:2]
        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        # one batch-row index per kept token
        rows = torch.arange(num_batches, device=x.device).repeat_interleave(lengths)

        if padding_side == "left":
            # real tokens occupy the tail of each padded row
            starts = max_seqlen - lengths
            cols = torch.cat([torch.arange(start, max_seqlen, device=x.device) for start in starts])
        elif padding_side == "right":
            # real tokens occupy the head of each padded row
            cols = torch.cat([torch.arange(length, device=x.device) for length in lengths])
        else:
            raise ValueError(f"unexpected padding_side ({padding_side})")

        return x[rows, cols]

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        output_shape: tuple[int],
        padding_side: str,
        kernel_backend: KernelBackend,
    ) -> torch.Tensor:
        ctx.kernel_backend = kernel_backend

        # the CUDA kernel assumes contiguous memory layouts
        if kernel_backend == KernelBackend.cuda:
            x = x.contiguous()
            cu_seqlens = cu_seqlens.contiguous()

        ctx_save_for_backward(ctx, cu_seqlens)
        ctx.padding_side = padding_side
        ctx.x_shape = x.size()

        # every packed slot is written by the kernel, so uninitialized storage is safe
        packed = torch.empty(output_shape, device=x.device, dtype=x.dtype)

        if kernel_backend == KernelBackend.cuda:
            pack_unpack_sequence_cuda(
                x=x, output=packed, cu_seqlens=cu_seqlens, padding_side=padding_side, pack=True, BLOCK_SIZE=1024
            )
        elif kernel_backend == KernelBackend.triton:
            pack_unpack_sequence_triton(
                x=packed if False else x, output=packed, cu_seqlens=cu_seqlens, padding_side=padding_side, pack=True
            )
        else:
            raise ValueError(f"unexpected kernel_backend ({kernel_backend})")

        return packed

    @staticmethod
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None]:
        kernel_backend = ctx.kernel_backend
        cu_seqlens = ctx.saved_tensors[0]

        # zeros, not empty: padding positions never receive a gradient and must stay 0
        dx = torch.zeros(*ctx.x_shape, device=dy.device, dtype=dy.dtype)

        if kernel_backend == KernelBackend.cuda:
            dy = dy.contiguous()
            pack_unpack_sequence_cuda(
                x=dy,
                output=dx,
                cu_seqlens=cu_seqlens,
                padding_side=ctx.padding_side,
                pack=False,
                BLOCK_SIZE=1024,
            )
        elif kernel_backend == KernelBackend.triton:
            pack_unpack_sequence_triton(
                x=dy, output=dx, cu_seqlens=cu_seqlens, padding_side=ctx.padding_side, pack=False
            )
        else:
            raise ValueError(f"unexpected kernel_backend ({kernel_backend})")

        # only x is differentiable; the remaining 4 forward args get no gradient
        return dx, None, None, None, None


class _UnpackSequence(CustomOp):
    """Autograd op that scatters a packed (total_tokens, ...) tensor back into a
    zero-padded (B, S, ...) batch, using CUDA/Triton kernels."""

    @staticmethod
    def forward_backward_torch(
        x: torch.Tensor, cu_seqlens: torch.Tensor, output_shape: tuple[int], padding_side: str
    ) -> torch.Tensor:
        """Pure-PyTorch reference implementation of unpacking.

        :param x: packed tensor of shape (total_tokens, ...)
        :param cu_seqlens: cumulative sequence lengths, shape (B + 1,)
        :param output_shape: padded output shape (B, S, ...)
        :param padding_side: "left" or "right"
        :return: zero-padded tensor of shape `output_shape`
        """
        B = cu_seqlens.size(0) - 1
        S = output_shape[1]

        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        batch_indices = torch.arange(B, device=x.device).repeat_interleave(seqlens)

        if padding_side == "left":
            # real tokens occupy the tail of each padded row
            pad_tokens = S - seqlens
            seq_indices = torch.cat([torch.arange(sl, S, device=x.device) for sl in pad_tokens])
        elif padding_side == "right":
            # real tokens occupy the head of each padded row
            seq_indices = torch.cat([torch.arange(sl, device=x.device) for sl in seqlens])
        else:
            raise ValueError(f"unexpected padding_side ({padding_side})")

        # zeros so that padding positions stay 0
        padded = torch.zeros(output_shape, dtype=x.dtype, device=x.device)
        padded[batch_indices, seq_indices] = x

        return padded

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        output_shape: tuple[int],
        padding_side: str,
        kernel_backend: KernelBackend,
    ) -> torch.Tensor:
        ctx.kernel_backend = kernel_backend

        # the CUDA kernel assumes contiguous memory layouts
        if kernel_backend == KernelBackend.cuda:
            x = x.contiguous()
            cu_seqlens = cu_seqlens.contiguous()

        ctx_save_for_backward(ctx, cu_seqlens)
        ctx.padding_side = padding_side
        ctx.x_shape = x.size()

        # zeros, not empty: padding positions are never written by the kernel
        y = torch.zeros(*output_shape, device=x.device, dtype=x.dtype)

        if kernel_backend == KernelBackend.cuda:
            pack_unpack_sequence_cuda(
                x=x, output=y, cu_seqlens=cu_seqlens, padding_side=padding_side, pack=False, BLOCK_SIZE=1024
            )
        elif kernel_backend == KernelBackend.triton:
            pack_unpack_sequence_triton(x=x, output=y, cu_seqlens=cu_seqlens, padding_side=padding_side, pack=False)
        else:
            # fix: previously reported "unexpected padding_side", masking the actual
            # invalid argument; report kernel_backend like _PackSequence does
            raise ValueError(f"unexpected kernel_backend ({kernel_backend})")

        return y

    @staticmethod
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None]:
        kernel_backend = ctx.kernel_backend
        padding_side = ctx.padding_side
        cu_seqlens = ctx.saved_tensors[0]

        # empty is safe: packing writes every position of dx
        dx = torch.empty(ctx.x_shape, device=dy.device, dtype=dy.dtype)

        if kernel_backend == KernelBackend.cuda:
            dy = dy.contiguous()

            pack_unpack_sequence_cuda(
                x=dy,
                output=dx,
                cu_seqlens=cu_seqlens,
                padding_side=padding_side,
                pack=True,
                BLOCK_SIZE=1024,
            )
        elif kernel_backend == KernelBackend.triton:
            pack_unpack_sequence_triton(x=dy, output=dx, cu_seqlens=cu_seqlens, padding_side=padding_side, pack=True)
        else:
            # fix: same wrong-argument error message as in forward
            raise ValueError(f"unexpected kernel_backend ({kernel_backend})")

        # only x is differentiable; the remaining 4 forward args get no gradient
        return dx, *[None] * 4


def pack_sequence(
    inputs: Sequence[torch.Tensor],
    cu_seqlens: torch.Tensor,
    total_tokens: int,
    padding_side: str = "left",
    *,
    kernel_backend: KernelBackend | None = None,
) -> Sequence[torch.Tensor]:
    """
    pack tensors

    :param inputs: list of tensors
    :type inputs: Sequence[torch.Tensor]
    :param cu_seqlens: cumulative sequence length
    :type cu_seqlens: torch.Tensor
    :param total_tokens: total number of tokens
    :type total_tokens: int
    :param padding_side: padding side
    :type padding_side: str
    :param kernel_backend: KernelBackend
    :type kernel_backend: KernelBackend | None
    :return: list of packed tensors
    :rtype: Sequence[Tensor]
    """
    assert padding_side in ["left", "right"]
    assert isinstance(inputs, (list, tuple))

    packed_tensors = []

    for tensor in inputs:
        # need at least (batch, sequence) leading dims, matching the batch implied by cu_seqlens
        assert tensor.dim() >= 2
        assert tensor.size(0) == cu_seqlens.size(0) - 1

        packed = _PackSequence.run(
            x=tensor,
            cu_seqlens=cu_seqlens,
            output_shape=(total_tokens, *tensor.size()[2:]),
            padding_side=padding_side,
            kernel_backend=kernel_backend,
        )
        packed_tensors.append(packed)

    return packed_tensors
def unpack_sequence(
    inputs: Sequence[torch.Tensor],
    cu_seqlens: torch.Tensor,
    batch_size: int,
    sequence_length: int,
    padding_side: str = "left",
    *,
    kernel_backend: KernelBackend | None = None,
) -> Sequence[torch.Tensor]:
    """
    unpack tensors

    :param inputs: list of tensors
    :type inputs: Sequence[torch.Tensor]
    :param cu_seqlens: cumulative sequence length
    :type cu_seqlens: torch.Tensor
    :param batch_size: batch size
    :type batch_size: int
    :param sequence_length: sequence length
    :type sequence_length: int
    :param padding_side: padding side
    :type padding_side: str
    :param kernel_backend: KernelBackend
    :type kernel_backend: KernelBackend | None
    :return: list of unpacked tensors
    :rtype: Sequence[Tensor]
    """
    assert padding_side in ["left", "right"]
    assert isinstance(inputs, (list, tuple))

    unpacked_tensors = []

    for tensor in inputs:
        # packed tensors are (total_tokens, ...); cu_seqlens must match the requested batch
        assert tensor.dim() >= 2
        assert cu_seqlens.size(0) - 1 == batch_size

        unpacked = _UnpackSequence.run(
            x=tensor,
            cu_seqlens=cu_seqlens,
            output_shape=(batch_size, sequence_length, *tensor.size()[1:]),
            padding_side=padding_side,
            kernel_backend=kernel_backend,
        )
        unpacked_tensors.append(unpacked)

    return unpacked_tensors