Source code for xma.functional.swiglu.pallas_implementation.forward
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
import torch
from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas
from ....custom_op import xma_op
from ....math import ceil_divide
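
# let torch_xla acquire the TPU before JAX is imported; importing JAX first
# can hang subsequent torch_xla TPU operations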
jax_import_guard()
import jax
import jax.experimental.pallas as pl
import jax.experimental.pallas.tpu as pltpu
import jax.numpy as jnp
from jax.nn import sigmoid
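

# Kernel body: each grid step loads one (BLOCK_SIZE_B, BLOCK_SIZE_H) tile of the
# gate projection g and the up projection u from VMEM and writes
# u * g * sigmoid(g) to the matching output tile.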
def swiglu_forward_pallas_kernel(g_ref, u_ref, y_ref):
g = g_ref[...]
u = u_ref[...]
dtype = g.dtype
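    # compute the activation in float32 regardless of the input dtype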
g = g.astype(jnp.float32)
y = u * g * sigmoid(g)
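    # cast back to the original dtype before writing the output tile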
y_ref[...] = y.astype(dtype)
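

# JIT-compiled launcher: chooses tile sizes and maps the kernel over a
# 2D grid of (B, H) tiles with pl.pallas_call.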
@jax.jit
def swiglu_forward_pallas_jit(g: jax.Array, u: jax.Array) -> jax.Array:
B, H = g.shape
    # round H up to a multiple of 128 (the TPU lane width), capped at 1024 columns per tile
    BLOCK_SIZE_H = min(ceil_divide(H, 128) * 128, 1024)
    # largest multiple of 8 rows such that the three (g, u, y) tiles fit in a
    # 32 MiB budget, and never fewer than 8 rows per tile
    BLOCK_SIZE_B = max(1, 32 * 1024 * 1024 // (3 * BLOCK_SIZE_H * g.dtype.itemsize * 8)) << 3
kernel = pl.pallas_call(
swiglu_forward_pallas_kernel,
out_shape=jax.ShapeDtypeStruct(shape=g.shape, dtype=g.dtype),
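        # launch one kernel instance per (row tile, column tile) of the inputs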
grid=(ceil_divide(B, BLOCK_SIZE_B), ceil_divide(H, BLOCK_SIZE_H)),
in_specs=[
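            # identity index map: program (x, y) works on block (x, y) of each operand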
pl.BlockSpec(block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_H), index_map=lambda x, y: (x, y)),
pl.BlockSpec(block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_H), index_map=lambda x, y: (x, y)),
],
out_specs=pl.BlockSpec(block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_H), index_map=lambda x, y: (x, y)),
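        # both grid dimensions are independent, so mark them as parallel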
compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "parallel")),
)
return kernel(g, u)
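

# Fake (meta) implementation registered with the custom op below: it only
# allocates an output of the right shape and dtype so the op can be traced
# without launching the kernel.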
def _fake_function(g: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
assert g.is_contiguous()
assert u.is_contiguous()
return torch.empty_like(g)
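

# torch-facing entry point: exposes the Pallas kernel as a custom op. The
# torch-callable wrapper is built with make_kernel_from_pallas on the first
# call and cached on the function object.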
@xma_op(mutates_args={}, fake_func=_fake_function)
def swiglu_forward_pallas(g: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
assert g.is_contiguous()
assert u.is_contiguous()
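    # build the torch-callable wrapper around the jitted kernel once and reuse it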
if swiglu_forward_pallas.cache is None:
swiglu_forward_pallas.cache = make_kernel_from_pallas(
swiglu_forward_pallas_jit, lambda g, u: [(g.shape, g.dtype)]
)
return swiglu_forward_pallas.cache(g, u)
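

# module-level cache slot for the lazily built kernel wrapper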
swiglu_forward_pallas.cache = None