# Source code for xma.optimizers.sgd.op

# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

import torch
from torch.optim.sgd import _multi_tensor_sgd, _single_tensor_sgd

from ...accelerator import Accelerator, KernelBackend
from ...constants import LOG_WARP_SIZE
from ...utils import is_triton_available


# The Triton kernels are optional: only import them (and build the dtype
# translation table) when Triton is actually installed, so the torch-only
# backend keeps working without it.
if is_triton_available():
    import triton.language as tl

    from .triton_implementation import multi_tensor_sgd_triton_kernel, single_tensor_sgd_triton

    # Maps a torch dtype to the matching Triton dtype constant, passed as the
    # DTYPE compile-time argument when launching the fused multi-tensor kernel.
    _TORCH_TO_TRITON_DTYPE = {torch.float32: tl.float32, torch.float16: tl.float16, torch.bfloat16: tl.bfloat16}


@torch.no_grad()
def sgd(
    parameters: list[torch.Tensor],
    gradients: list[torch.Tensor],
    lr: float,
    maximize: bool,
    horizontal_fusion: bool,
    weight_decay: float,
    momentum: float,
    dampening: bool,
    nesterov: bool,
    *,
    kernel_backend: KernelBackend | None = None,
) -> None:
    """Apply a single SGD update to ``parameters`` in place, using ``gradients``.

    Args:
        parameters: tensors to update.
        gradients: matching gradient tensors (same order as ``parameters``).
        lr: learning rate.
        maximize: ascend instead of descend the gradient.
        horizontal_fusion: update all tensors with one fused call/kernel launch
            instead of looping tensor by tensor.
        weight_decay / momentum / dampening / nesterov: standard SGD options;
            only supported by the torch backend — the cuda/triton path asserts
            they are disabled.
        kernel_backend: backend override; when ``None`` the accelerator's
            default backend is used.

    Raises:
        ValueError: for an unrecognized ``kernel_backend``.
        AssertionError: on unsupported option combinations or a backend that
            fails accelerator verification.
    """
    if kernel_backend is None:
        kernel_backend = Accelerator.get_kernel_backend()
    else:
        # Explicitly requested backend must be valid for the current accelerator.
        assert kernel_backend.verify_accelerator()

    if kernel_backend == KernelBackend.torch:
        # Delegate to torch's reference implementations; the fused variant
        # handles the whole tensor list in one call.
        torch_impl = _multi_tensor_sgd if horizontal_fusion else _single_tensor_sgd
        torch_impl(
            params=parameters,
            grads=gradients,
            momentum_buffer_list=[None] * len(parameters),
            grad_scale=None,
            found_inf=None,
            weight_decay=weight_decay,
            momentum=momentum,
            lr=lr,
            dampening=dampening,
            nesterov=nesterov,
            maximize=maximize,
            has_sparse_grad=False,
        )
    elif kernel_backend in (KernelBackend.cuda, KernelBackend.triton):
        # The custom kernels implement only plain SGD.
        assert weight_decay == 0
        assert momentum == 0
        assert dampening == 0
        assert not nesterov

        if horizontal_fusion:
            num_warps = 8
            device = parameters[0].device
            dtype = parameters[0].dtype

            def _as_int64(values: list[int]) -> torch.Tensor:
                # Pack host-side integers (pointers / sizes) into a device tensor
                # so the kernel can index them per program id.
                return torch.tensor(values, dtype=torch.int64, device=device)

            # Each thread handles 16 bytes worth of elements of the given dtype.
            elements_per_thread = 16 // dtype.itemsize
            block_size = (num_warps << LOG_WARP_SIZE) * elements_per_thread

            # One program instance per parameter tensor.
            multi_tensor_sgd_triton_kernel[(len(parameters),)](
                W_ptr_ptr=_as_int64([weight.data_ptr() for weight in parameters]),
                dW_ptr_ptr=_as_int64([grad.data_ptr() for grad in gradients]),
                N_ptr=_as_int64([weight.numel() for weight in parameters]),
                lr=lr,
                BLOCK_SIZE=block_size,
                MAXIMIZE=maximize,
                DTYPE=_TORCH_TO_TRITON_DTYPE[dtype],
                num_warps=num_warps,
            )
        else:
            for weight, grad in zip(parameters, gradients):
                assert weight.is_contiguous()
                single_tensor_sgd_triton(W=weight, dW=grad.contiguous(), lr=lr, maximize=maximize)
    else:
        raise ValueError(f"unexpected kernel_backend ({kernel_backend})")