import torch
from torch.optim import Optimizer
import math


class DMAdam(Optimizer):
    """
    Implements the DMAdam algorithm.

    References:
        - Algorithm 1 in "DMAdam: Dual averaging enhanced adaptive gradient
          method for deep neural networks"
        - Theoretical constraints from Theorems 1 & 3 (dynamic eta and beta).
    """

    def __init__(self, params, lr=1e-3, eta0=1.0, eps=1e-8):
        """
        Args:
            params: Parameters to optimize.
            lr (float): Corresponds to alpha_k in the paper (learning rate).
            eta0 (float): Initial scaling factor; in theory mode, eta_k = eta0 / sqrt(k).
            eps (float): Term added to the denominator for numerical stability.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eta0:
            raise ValueError("Invalid eta0 value: {}".format(eta0))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))

        defaults = dict(lr=lr, eta0=eta0, eps=eps)
        super(DMAdam, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            alpha = group['lr']
            eta0 = group['eta0']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['v'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                m = state['m']
                v = state['v']

                state['step'] += 1
                k = state['step']

                # 1. Dynamic beta_k = 1 - 1/sqrt(k)
                beta_k = 1.0 - 1.0 / math.sqrt(k)

                # 2. Dynamic eta_k = eta0 / sqrt(k)
                eta_k = eta0 / math.sqrt(k)

                # lambda_k = alpha_k * sqrt(k + 1)
                lambda_k = alpha * math.sqrt(k + 1)

                # Line 3: update m^k
                # m^k = (m^{k-1} + lambda_k * g^k) / sqrt(k + 1)
                m.add_(grad, alpha=lambda_k)
                m.div_(math.sqrt(k + 1))

                # Line 4: update v^k
                # epsilon_0 = eps / (1 - beta_k)
                eps_0 = eps / (1.0 - beta_k)
                # v^k = beta_k * v^{k-1} + (1 - beta_k) * ((g^k)^2 + epsilon_0)
                term = grad.pow(2).add_(eps_0)
                v.mul_(beta_k).add_(term, alpha=(1.0 - beta_k))

                # Line 5: update x^{k+1}
                # x^{k+1} = x^k - eta_k * m^k / (sqrt(v^k) + eps)
                denominator = v.sqrt().add(eps)
                p.addcdiv_(m, denominator, value=-eta_k)

        return loss
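

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The toy linear-regression model,
# data, and hyperparameter values below are assumptions for demonstration and
# are not part of the DMAdam reference implementation above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # A small model and synthetic data to sanity-check that the optimizer runs
    # and the loss decreases.
    model = torch.nn.Linear(10, 1)
    criterion = torch.nn.MSELoss()
    optimizer = DMAdam(model.parameters(), lr=1e-3, eta0=1.0, eps=1e-8)

    x = torch.randn(64, 10)
    y = torch.randn(64, 1)

    for step in range(100):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

    print("final loss:", loss.item())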