import math

import torch
from torch.optim import Optimizer


class DMAdam(Optimizer):
    """Implements the DMAdam algorithm.

    References:
        - Algorithm 1 in "DMAdam: Dual averaging enhanced adaptive gradient
          method for deep neural networks".
        - Theoretical constraints from Theorems 1 and 3 (dynamic eta_k and
          beta_k schedules).
    """

    def __init__(self, params, lr=1e-3, eta0=1.0, eps=1e-8):
        """
        Args:
            params: Iterable of parameters to optimize, or dicts defining
                parameter groups.
            lr (float): Base step size; corresponds to alpha_k in the paper.
            eta0 (float): Initial scaling factor. The effective step size
                decays as eta_k = eta0 / sqrt(k).
            eps (float): Term added to the denominator for numerical
                stability.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eta0:
            raise ValueError("Invalid eta0 value: {}".format(eta0))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))

        defaults = dict(lr=lr, eta0=eta0, eps=eps)
        super(DMAdam, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the
                model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            alpha = group['lr']
            eta0 = group['eta0']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("DMAdam does not support sparse gradients")

                state = self.state[p]

                # State initialization: both moment buffers start at zero.
                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['v'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                m = state['m']
                v = state['v']
                state['step'] += 1
                k = state['step']

                # 1. Dynamic beta_k = 1 - 1/sqrt(k)
                beta_k = 1.0 - 1.0 / math.sqrt(k)

                # 2. Dynamic eta_k = eta0 / sqrt(k)
                eta_k = eta0 / math.sqrt(k)

                # lambda_k = alpha_k * sqrt(k + 1)
                lambda_k = alpha * math.sqrt(k + 1)

                # Line 3: Update m^k
                # m^k = (m^{k-1} + lambda_k * g^k) / sqrt(k + 1)
                m.add_(grad, alpha=lambda_k)
                m.div_(math.sqrt(k + 1))
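                # Note: since lambda_k = alpha * sqrt(k + 1), this pair of
                # in-place ops is algebraically equivalent to
                #     m^k = m^{k-1} / sqrt(k + 1) + alpha * g^k,
                # i.e. the dual average is shrunk by 1/sqrt(k + 1) before the
                # fresh gradient is added at the base step size alpha.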

                # Line 4: Update v^k
                # epsilon_0 = eps / (1 - beta_k)
                eps_0 = eps / (1.0 - beta_k)

                # v^k = beta_k * v^{k-1} + (1 - beta_k) * ((g^k)^2 + epsilon_0)
                term = grad.pow(2).add_(eps_0)
                v.mul_(beta_k).add_(term, alpha=(1.0 - beta_k))
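                # Note: (1 - beta_k) * eps_0 == eps, so each step folds a
                # constant eps into v, keeping the preconditioner bounded
                # away from zero even when gradients vanish.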

                # Line 5: Update x^{k+1}
                # x^{k+1} = x^k - eta_k * m^k / (sqrt(v^k) + eps)
                denominator = v.sqrt().add(eps)

                # p = p - eta_k * (m / denominator)
                p.addcdiv_(m, denominator, value=-eta_k)

        return loss
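

# A minimal usage sketch (not part of the optimizer): it fits a small
# least-squares model to check that the update runs and the loss decreases.
# The model, data, and hyperparameter values below are illustrative choices,
# not settings taken from the DMAdam paper.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Synthetic regression target: y = X @ w_true + noise.
    X = torch.randn(256, 10)
    w_true = torch.randn(10, 1)
    y = X @ w_true + 0.01 * torch.randn(256, 1)

    model = torch.nn.Linear(10, 1, bias=False)
    optimizer = DMAdam(model.parameters(), lr=1e-2, eta0=0.5)

    for _ in range(200):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(X), y)
        loss.backward()
        optimizer.step()

    print("final loss:", loss.item())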