Files
benchmarkexp/dmadam.py
2026-02-04 16:49:52 +08:00

94 lines
3.2 KiB
Python

import torch
from torch.optim import Optimizer
import math
class DMAdam(Optimizer):
    """Implements the DMAdam algorithm.

    References:
        - Algorithm 1 in "DMAdam: Dual averaging enhanced adaptive gradient
          method for deep neural networks".
        - Theoretical constraints from Theorem 1 & 3 (dynamic eta and beta).

    For each parameter, with a per-parameter step counter ``k`` starting at 1,
    one call to :meth:`step` computes::

        beta_k   = 1 - 1 / sqrt(k)
        eta_k    = eta0 / sqrt(k)
        lambda_k = lr * sqrt(k + 1)
        m^k      = (m^{k-1} + lambda_k * g^k) / sqrt(k + 1)
        v^k      = beta_k * v^{k-1} + (1 - beta_k) * ((g^k)^2 + eps / (1 - beta_k))
        x^{k+1}  = x^k - eta_k * m^k / (sqrt(v^k) + eps)

    where ``g^k`` is the gradient (plus ``weight_decay * x^k`` when
    ``weight_decay`` is non-zero).
    """

    def __init__(self, params, lr=1e-3, eta0=1.0, eps=1e-8, weight_decay=0.0):
        """
        Args:
            params: Iterable of parameters, or of dicts defining parameter
                groups, to optimize.
            lr (float): alpha_k in the paper (learning rate), used here as a
                constant alpha for every step.
            eta0 (float): Initial step scaling factor; eta_k = eta0 / sqrt(k).
            eps (float): Term added to the denominator; via epsilon_0 it is
                also added to the second-moment estimate on every step.
            weight_decay (float): L2 penalty coefficient. When non-zero,
                ``weight_decay * p`` is added to the gradient before the
                update. Defaults to 0 (disabled).

        Raises:
            ValueError: If any hyperparameter is negative (or NaN).
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eta0:
            raise ValueError(f"Invalid eta0 value: {eta0}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, eta0=eta0, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): Re-evaluates the model and returns
                the loss. It is called with gradients enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            alpha = group['lr']
            eta0 = group['eta0']
            eps = group['eps']
            # .get keeps param groups from state dicts saved before
            # weight_decay existed working.
            weight_decay = group.get('weight_decay', 0.0)

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('DMAdam does not support sparse gradients')
                if weight_decay != 0:
                    # Out-of-place so that p.grad itself is not modified.
                    grad = grad.add(p, alpha=weight_decay)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['v'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                m = state['m']
                v = state['v']

                state['step'] += 1
                k = state['step']
                sqrt_k = math.sqrt(k)
                sqrt_k1 = math.sqrt(k + 1)

                # Dynamic schedules. At k == 1, beta_k == 0, so v is seeded
                # directly from the first squared gradient.
                beta_k = 1.0 - 1.0 / sqrt_k
                eta_k = eta0 / sqrt_k
                lambda_k = alpha * sqrt_k1

                # Line 3: m^k = (m^{k-1} + lambda_k * g^k) / sqrt(k + 1)
                m.add_(grad, alpha=lambda_k)
                m.div_(sqrt_k1)

                # Line 4: v^k = beta_k * v^{k-1} + (1 - beta_k) * ((g^k)^2 + epsilon_0)
                # with epsilon_0 = eps / (1 - beta_k). Since 1 - beta_k = 1/sqrt(k) > 0
                # for every k >= 1, this never divides by zero; in exact arithmetic
                # the net effect is that eps is added to v on every step.
                eps_0 = eps / (1.0 - beta_k)
                term = grad.pow(2).add_(eps_0)
                v.mul_(beta_k).add_(term, alpha=(1.0 - beta_k))

                # Line 5: x^{k+1} = x^k - eta_k * m^k / (sqrt(v^k) + eps)
                # NOTE(review): eps is added outside the square root (Adam-style);
                # a previous comment wrote sqrt(v^k + eps) -- confirm against the paper.
                denominator = v.sqrt().add_(eps)
                p.addcdiv_(m, denominator, value=-eta_k)

        return loss