update
dmadam.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import torch
from torch.optim import Optimizer
import math


class DMAdam(Optimizer):
    """
    Implements the DMAdam algorithm.

    References:
        - Algorithm 1 in "DMAdam: Dual averaging enhanced adaptive gradient
          method for deep neural networks"
        - Theoretical constraints from Theorems 1 & 3 (dynamic eta_k and beta_k).
    """

    def __init__(self, params, lr=1e-3, eta0=1.0, eps=1e-8):
        """
        Args:
            params: Parameters to optimize.
            lr (float): Corresponds to alpha_k in the paper (learning rate).
            eta0 (float): Initial scaling factor, decayed dynamically as
                eta_k = eta0 / sqrt(k).
            eps (float): Term added to the denominator for numerical stability.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eta0:
            raise ValueError("Invalid eta0 value: {}".format(eta0))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))

        defaults = dict(lr=lr, eta0=eta0, eps=eps)
        super(DMAdam, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            alpha = group['lr']
            eta0 = group['eta0']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['m'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['v'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                m = state['m']
                v = state['v']
                state['step'] += 1
                k = state['step']

                # 1. Dynamic beta_k = 1 - 1/sqrt(k)
                beta_k = 1.0 - 1.0 / math.sqrt(k)

                # 2. Dynamic eta_k = eta0 / sqrt(k)
                eta_k = eta0 / math.sqrt(k)

                # lambda_k = alpha_k * sqrt(k + 1)
                lambda_k = alpha * math.sqrt(k + 1)

                # Line 3: Update m^k
                # m^k = (m^{k-1} + lambda_k * g^k) / sqrt(k + 1)
                # (with lambda_k = alpha * sqrt(k + 1), this simplifies to
                # m^{k-1} / sqrt(k + 1) + alpha * g^k)
                m.add_(grad, alpha=lambda_k)
                m.div_(math.sqrt(k + 1))

                # Line 4: Update v^k
                # epsilon_0 = eps / (1 - beta_k)
                eps_0 = eps / (1.0 - beta_k)

                # v^k = beta_k * v^{k-1} + (1 - beta_k) * ((g^k)^2 + epsilon_0)
                term = grad.pow(2).add_(eps_0)
                v.mul_(beta_k).add_(term, alpha=(1.0 - beta_k))

                # Line 5: Update x^{k+1}
                # x^{k+1} = x^k - eta_k * m^k / (sqrt(v^k) + eps)
                denominator = v.sqrt().add(eps)

                # p = p - eta_k * (m / denominator)
                p.addcdiv_(m, denominator, value=-eta_k)

        return loss
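
A minimal usage sketch, not part of the commit: it assumes dmadam.py above is importable from the working directory, and the toy linear model, data, and hyperparameters are illustrative only. It exercises the closure form of step() that the implementation supports.

import torch
import torch.nn.functional as F

from dmadam import DMAdam

model = torch.nn.Linear(10, 1)
optimizer = DMAdam(model.parameters(), lr=1e-3, eta0=1.0, eps=1e-8)

x = torch.randn(32, 10)
y = torch.randn(32, 1)

for _ in range(5):
    def closure():
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
        return loss

    loss = optimizer.step(closure)
    print(loss.item())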
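
And a quick hand-check of the very first step (k = 1), where beta_1 = 1 - 1/sqrt(1) = 0 and eta_1 = eta0, so one update should move the parameter from p0 to p0 - eta0 * (alpha * g) / (sqrt(g^2 + eps) + eps). A sketch under the same import assumption as above, with made-up scalar values:

import math
import torch

from dmadam import DMAdam

p = torch.nn.Parameter(torch.tensor([1.0]))
opt = DMAdam([p], lr=0.1, eta0=1.0, eps=1e-8)

loss = (p ** 2).sum()  # gradient at p = 1 is g = 2
loss.backward()
opt.step()

g, alpha, eta0, eps = 2.0, 0.1, 1.0, 1e-8
expected = 1.0 - eta0 * (alpha * g) / (math.sqrt(g * g + eps) + eps)
print(p.item(), expected)  # both come out to roughly 0.9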