# **************************************************
# Copyright (c) 2026, Mayank Mishra
# **************************************************
import torch
from ..accelerator import KernelBackend
from ..custom_op import CustomOp, ctx_save_for_backward
from ..math import divide_if_divisible
from ..utils import empty_like_contiguous, is_torch_xla_available
from .swiglu import _FUNCTIONS, swiglu
if is_torch_xla_available():
from .swiglu import _swiglu_backward_pallas, _swiglu_forward_pallas
class _SwigluPacked(CustomOp):
    """SwiGLU on a packed input holding both halves in its last dimension.

    The input is split in two along the last dimension; the first half is used
    as `up` and the second half as `gate` (see the `chunk` calls below).
    """

    @staticmethod
    def forward_backward_torch(x: torch.Tensor) -> torch.Tensor:
        """Split `x` and delegate to :func:`swiglu` with the torch backend.

        :param x: packed input; first half of the last dim is `up`, second half is `gate`
        :type x: torch.Tensor
        :return: result of `swiglu(gate=gate, up=up, kernel_backend=KernelBackend.torch)`
        :rtype: torch.Tensor
        """
        up, gate = x.chunk(2, dim=-1)
        return swiglu(gate=gate, up=up, kernel_backend=KernelBackend.torch)

    @staticmethod
    def forward(ctx, x: torch.Tensor, kernel_backend: KernelBackend) -> torch.Tensor:
        """Run the forward kernel for `kernel_backend` and stash state for backward.

        :param ctx: autograd context; receives `kernel_backend`, the saved `x` and,
            for non-pallas backends, `backward_function`
        :param x: packed input (`up` then `gate` along the last dim)
        :type x: torch.Tensor
        :param kernel_backend: backend used to pick the kernel pair from `_FUNCTIONS`
        :type kernel_backend: KernelBackend
        :return: output whose last dim is half of `x`'s last dim
        :rtype: torch.Tensor
        """
        ctx.kernel_backend = kernel_backend

        # cuda and pallas get a contiguous packed input; the contiguous tensor is
        # also the one that is saved, so backward sees the same layout
        if kernel_backend in [KernelBackend.cuda, KernelBackend.pallas]:
            x = x.contiguous()

        ctx_save_for_backward(ctx, x)
        u, g = x.chunk(2, dim=-1)

        # pallas returns its own output tensor, no preallocation needed
        if kernel_backend == KernelBackend.pallas:
            return _swiglu_forward_pallas(g=g, u=u)

        y = torch.empty(*x.size()[:-1], divide_if_divisible(x.size(-1), 2), device=x.device, dtype=x.dtype)

        forward_function, backward_function = _FUNCTIONS[kernel_backend]
        # the forward kernel is handed `y` as its output buffer
        forward_function(g=g, u=u, y=y)

        # remember the matching backward kernel for `backward`
        ctx.backward_function = backward_function

        return y

    @staticmethod
    def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor, None]:
        """Compute the gradient with respect to the packed input.

        :param ctx: autograd context populated by `forward`
        :param dy: gradient of the output
        :type dy: torch.Tensor
        :return: `(dx, None)`: `dx` is packed like `x` (`du` then `dg` along the last
            dim); `None` is the gradient slot for the `kernel_backend` argument
        :rtype: tuple[torch.Tensor, None]
        """
        kernel_backend = ctx.kernel_backend
        x = ctx.saved_tensors[0]

        if kernel_backend in [KernelBackend.cuda, KernelBackend.pallas]:
            dy = dy.contiguous()

        u, g = x.chunk(2, dim=-1)

        if kernel_backend == KernelBackend.pallas:
            dg, du = _swiglu_backward_pallas(g=g, u=u, dy=dy)
            # repack in the same order as the input: up first, gate second
            dx = torch.cat([du, dg], dim=-1)
        elif kernel_backend == KernelBackend.nki:
            # separate contiguous gradient buffers, concatenated afterwards
            # NOTE(review): presumably the NKI kernel cannot write into strided views - confirm
            du = empty_like_contiguous(u)
            dg = empty_like_contiguous(g)
            ctx.backward_function(g=g, u=u, dy=dy, dg=dg, du=du)
            dx = torch.cat([du, dg], dim=-1)
        else:
            # du / dg are views into dx, so the kernel's outputs land directly in
            # the packed gradient and no concatenation is needed
            dx = empty_like_contiguous(x)
            du, dg = dx.chunk(2, dim=-1)
            ctx.backward_function(g=g, u=u, dy=dy, dg=dg, du=du)

        return dx, None
# [docs]  (Sphinx "viewcode" link text left over from HTML extraction; not part of the module)
def swiglu_packed(x: torch.Tensor, *, kernel_backend: KernelBackend | None = None) -> torch.Tensor:
    """
    applies the swiglu activation to a packed tensor `x` whose last dimension is split into
    2 equal parts (the up and gate activations)

    :param x: packed input activation
    :type x: torch.Tensor
    :param kernel_backend: backend forwarded to the underlying op
    :type kernel_backend: KernelBackend | None
    :return: tensor with the leading dimensions of `x` and half of its last dimension
    :rtype: Tensor
    """
    input_shape = x.size()

    # collapse all leading dimensions so the op sees a 2D tensor
    rows = x.flatten(0, -2)
    half_width = divide_if_divisible(input_shape[-1], 2)

    out = _SwigluPacked.run(x=rows, kernel_backend=kernel_backend)

    # restore the leading dimensions of the input on the output
    return out.view(*input_shape[:-1], half_width)