# Source code for xma.cute_dsl_utils.utils
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
from ..utils import get_alignment
# [docs]
def torch_tensor_to_cute_tensor(x: torch.Tensor, leading_dim: int) -> cute.Tensor:
    """Wrap a PyTorch tensor as a CuTe DSL tensor whose layout is marked dynamic.

    Args:
        x: Source PyTorch tensor. It is detached before export, so the result
            never participates in autograd.
        leading_dim: Index of the dimension to pass to ``mark_layout_dynamic``
            as the leading dimension. It is replaced by ``None`` when that
            dimension has stride 0 (a PyTorch broadcast/expanded dimension).

    Returns:
        The value returned by ``mark_layout_dynamic`` on the DLPack-wrapped
        tensor.
    """
    # PyTorch refuses to export a tensor that requires grad through DLPack,
    # so drop it from the autograd graph first.
    detached = x.detach()

    # The alignment hint comes from the shared ``get_alignment`` helper.
    # NOTE(review): presumably the byte alignment of the data pointer; confirm there.
    cute_tensor = from_dlpack(detached, assumed_align=get_alignment(detached))

    # A zero stride on the requested dimension means PyTorch broadcast it. That
    # dimension cannot act as the leading one, so pass None instead.
    # NOTE(review): None presumably lets CuTe choose the leading dim itself.
    # It is unclear whether there is a cleaner way to detect PyTorch broadcasting.
    dynamic_leading_dim = None if cute_tensor.stride[leading_dim] == 0 else leading_dim

    return cute_tensor.mark_layout_dynamic(leading_dim=dynamic_leading_dim)