# Source code for xma.cute_dsl_utils.math

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

import cutlass.cute as cute
from cutlass import Float32, Numeric, const_expr, range_constexpr
from cutlass._mlir.dialects import llvm
from cutlass.cute import TensorSSA
from cutlass.cutlass_dsl import T, dsl_user_op


@dsl_user_op
def _tanh(x: Float32 | float, *, loc=None, ip=None) -> Float32:
    """Return an approximate ``tanh(x)`` via the PTX ``tanh.approx.f32`` instruction.

    Args:
        x: Scalar input; it is wrapped in ``Float32`` before being lowered to an IR value.
        loc: Optional source location forwarded to ``ir_value``.
        ip: Optional insertion point forwarded to ``ir_value``.

    Returns:
        The inline-asm result wrapped in ``Float32``.
    """
    # Build the result type first, then the operand, mirroring the order in which
    # the pieces of the inline-asm call are constructed.
    f32_type = T.f32()
    operand = Float32(x).ir_value(loc=loc, ip=ip)

    # "=f,f": one f32 output register ($0) and one f32 input register ($1).
    asm_result = llvm.inline_asm(
        res=f32_type,
        operands_=[operand],
        asm_string="tanh.approx.f32 $0, $1;",
        constraints="=f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )

    return Float32(asm_result)


@cute.jit
def tanh(x: Numeric | TensorSSA, output_dtype: Numeric | None = None) -> Numeric | TensorSSA:
    """Compute ``tanh(x)`` for a scalar or a ``TensorSSA`` using the approximate ``_tanh``.

    The input is converted to ``Float32``, ``_tanh`` is applied (element by element
    for a ``TensorSSA``), and the result is converted to ``output_dtype``.

    Args:
        x: Scalar ``Numeric`` or ``TensorSSA`` input.
        output_dtype: Dtype of the returned value. Defaults to ``x.dtype`` when ``None``.

    Returns:
        ``tanh(x)`` as ``output_dtype``; a ``TensorSSA`` when ``x`` is one, otherwise a scalar.
    """
    if const_expr(output_dtype is None):
        output_dtype = x.dtype

    if const_expr(isinstance(x, TensorSSA)):
        # _tanh works on scalars, so stage the values in a Float32 fragment that can
        # be indexed and updated one element at a time.
        y = cute.make_fragment(x.shape, Float32)
        y.store(x.to(Float32))

        for i in range_constexpr(cute.size(y.shape)):
            y[i] = _tanh(y[i])

        y = y.load()
    else:
        y = _tanh(x.to(Float32))

    # Cast on BOTH paths. Previously only the scalar branch applied output_dtype, so
    # TensorSSA inputs always came back as Float32 regardless of the request.
    y = y.to(output_dtype)

    return y


def sigmoid(x: Numeric | TensorSSA, output_dtype: Numeric | None = None) -> Numeric | TensorSSA:
    """Compute the logistic sigmoid of ``x`` through ``tanh``.

    Uses the identity ``sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5``. The arithmetic is
    carried out in ``Float32`` and the result is converted to ``output_dtype``.

    Args:
        x: Scalar ``Numeric`` or ``TensorSSA`` input.
        output_dtype: Dtype of the returned value. Defaults to ``x.dtype`` when ``None``.

    Returns:
        ``sigmoid(x)`` converted to ``output_dtype``.
    """
    # Capture the caller-visible dtype before x is rebound to its Float32 version.
    if const_expr(output_dtype is None):
        output_dtype = x.dtype

    x = x.to(Float32)
    # Keep the intermediate tanh result in Float32; the single cast happens below.
    x = 0.5 * tanh(0.5 * x, output_dtype=Float32) + 0.5
    x = x.to(output_dtype)

    return x