Source code for xma.optimizers.sgd.module

# **************************************************
# Copyright (c) 2026, Mayank Mishra
# **************************************************

from __future__ import annotations

from typing import Callable

import torch
from torch.optim import SGD as _TorchSGD

from ...accelerator import KernelBackend
from .op import sgd


[docs] class SGD(_TorchSGD):
[docs] @torch.no_grad() def step(self, closure: Callable | None = None, *, kernel_backend: KernelBackend | None = None) -> None: loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: params = [] grads = [] momentum_buffer_list = [] has_sparse_grad = self._init_group( group=group, params=params, grads=grads, momentum_buffer_list=momentum_buffer_list ) assert not has_sparse_grad sgd( params=params, grads=grads, momentum_buffer_list=momentum_buffer_list, lr=group["lr"], weight_decay=group["weight_decay"], momentum=group["momentum"], dampening=group["dampening"], nesterov=group["nesterov"], maximize=group["maximize"], foreach=group["foreach"], kernel_backend=kernel_backend, ) if group["momentum"] != 0: for p, m in zip(params, momentum_buffer_list, strict=True): self.state[p]["momentum_buffer"] = m return loss