Source code for xma.utils.settings
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
import torch
[docs]
def get_triton_num_warps(BLOCK_SIZE: int) -> int:
# reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32 if torch.version.hip is None else 16
elif BLOCK_SIZE >= 8192:
num_warps = 16
elif BLOCK_SIZE >= 2048:
num_warps = 8
return num_warps