Source code for xma.optimizers.sgd.module
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
from __future__ import annotations

from typing import Callable

import torch
from torch.optim import SGD as _TorchSGD

from ...accelerator import KernelBackend
from .op import sgd
class SGD(_TorchSGD):
    """``torch.optim.SGD`` subclass that routes the parameter update through the custom ``sgd`` op."""

    @torch.no_grad()
    def step(
        self, closure: Callable | None = None, *, kernel_backend: KernelBackend | None = None
    ) -> torch.Tensor | None:
        """Perform a single optimization step, optionally overriding the kernel backend."""
        # optionally re-evaluate the model under grad to obtain a fresh loss
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # gather only the parameters that actually received gradients
            parameters = []
            gradients = []

            for p in group["params"]:
                if p.grad is None:
                    continue

                parameters.append(p)
                gradients.append(p.grad)

            # delegate the (optionally fused) update to the custom sgd op
            sgd(
                parameters=parameters,
                gradients=gradients,
                lr=group["lr"],
                maximize=group["maximize"],  # honor the per-group maximize flag
                horizontal_fusion=group["foreach"],  # torch stores the foreach flag per group
                weight_decay=group["weight_decay"],
                momentum=group["momentum"],
                dampening=group["dampening"],
                nesterov=group["nesterov"],
                kernel_backend=kernel_backend,
            )

        return loss
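
A minimal usage sketch, not part of the module itself: the training setup below is illustrative, the import path is assumed from this page's title, and only the ``SGD`` class and its ``step`` signature come from the code above. Since the ``KernelBackend`` members are not shown here, ``kernel_backend=None`` is passed to exercise the default dispatch path.

import torch

from xma.optimizers.sgd.module import SGD  # import path assumed from the page title

model = torch.nn.Linear(16, 4)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

x = torch.randn(8, 16)
loss = model(x).square().mean()
loss.backward()

# kernel_backend is keyword-only; None selects the default dispatch path
optimizer.step(kernel_backend=None)
optimizer.zero_grad()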