Source code for xma.functional.swiglu.nki_implementation.forward

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import neuronxcc.nki.language as nl
import torch
from torch_neuronx import TorchNeuronNKIKernel

from ....custom_op import xma_op
from ....math import ceil_divide


def swiglu_forward_nki_kernel(g_ptr, u_ptr, y_ptr):
    BLOCK_SIZE_B = 128
    BLOCK_SIZE_H = 512

    B, H = g_ptr.shape

    # index of this program instance along the batch and hidden dimensions
    BLOCK_ID_B = nl.program_id(0)
    BLOCK_ID_H = nl.program_id(1)

    # row/column indices covered by this tile
    BLOCK_B = BLOCK_ID_B * BLOCK_SIZE_B + nl.arange(BLOCK_SIZE_B)[:, None]
    BLOCK_H = BLOCK_ID_H * BLOCK_SIZE_H + nl.arange(BLOCK_SIZE_H)[None, :]

    # mask out-of-bounds elements in partial tiles at the tensor edges
    MASK_B = BLOCK_B < B
    MASK_H = BLOCK_H < H
    MASK = MASK_B & MASK_H

    g = nl.load(g_ptr[BLOCK_B, BLOCK_H], mask=MASK)
    u = nl.load(u_ptr[BLOCK_B, BLOCK_H], mask=MASK)

    # upcast the gate to tfloat32 before the sigmoid, then compute u * g * sigmoid(g)
    g = nl.copy(g, dtype=nl.tfloat32)
    y = u * g * nl.sigmoid(g)

    nl.store(y_ptr[BLOCK_B, BLOCK_H], y, mask=MASK)
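

# The kernel above computes the elementwise SwiGLU product y = u * g * sigmoid(g)
# (i.e. u * silu(g)). The eager-mode PyTorch sketch below shows the same math for
# reference; the name `swiglu_forward_reference` is illustrative, is not part of
# this module, and ignores the kernel's tfloat32 upcast of the gate.
def swiglu_forward_reference(g: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    return u * g * torch.sigmoid(g)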
@xma_op(mutates_args={"y"}) def swiglu_forward_nki(g: torch.Tensor, u: torch.Tensor, y: torch.Tensor) -> None: BLOCK_SIZE_B = 128 BLOCK_SIZE_H = 512 B, H = g.size() compile_key = (B, H, g.dtype) kernel = swiglu_forward_nki.cache.get(compile_key, None) if kernel is None: kernel = TorchNeuronNKIKernel( func=swiglu_forward_nki_kernel, grid=(ceil_divide(B, BLOCK_SIZE_B), ceil_divide(H, BLOCK_SIZE_H)), kernel_return=False, return_tensor_overrides=(y,), ) swiglu_forward_nki.cache[compile_key] = kernel kernel(g, u, y) swiglu_forward_nki.cache = {}