150 lines
5.7 KiB
Python
150 lines
5.7 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torch.utils.data import DataLoader
|
||
from torchvision import datasets, transforms, models
|
||
import matplotlib.pyplot as plt
|
||
from dmadam import DMAdam
|
||
import os
|
||
|
||
# ResNet-34 for CIFAR-10 (adapted from torchvision)
def get_resnet34():
    """Build a torchvision ResNet-34 adapted for 32x32 CIFAR-10 inputs."""
    net = models.resnet34(weights=None)
    # CIFAR images are tiny: replace the 7x7/stride-2 ImageNet stem with a
    # 3x3/stride-1 conv and drop the initial max-pool so early layers keep
    # spatial resolution.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()
    # Re-head the classifier for the 10 CIFAR-10 classes.
    net.fc = nn.Linear(512, 10)
    return net
# VGG-13 for CIFAR-10 (adapted from torchvision)
def get_vgg13():
    """Build a torchvision VGG-13 whose final layer outputs 10 CIFAR-10 logits."""
    net = models.vgg13(weights=None)
    # Swap only the last classifier layer; the rest of the head is unchanged.
    net.classifier[6] = nn.Linear(4096, 10)
    return net
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over `loader`.

    Performs the usual zero_grad / forward / backward / step cycle per batch
    and tallies top-1 accuracy on the training data as it goes.

    Returns:
        Training accuracy for the epoch, as a percentage.
    """
    model.train()
    hits, seen = 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        criterion(logits, labels).backward()
        optimizer.step()
        # Accuracy is measured on the pre-step forward pass of each batch.
        hits += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return 100. * hits / seen
def test(model, loader, device):
    """Evaluate `model` on `loader` and return top-1 accuracy in percent."""
    model.eval()
    hits, seen = 0, 0
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)
            hits += (preds == labels).sum().item()
            seen += labels.size(0)
    return 100. * hits / seen
def create_optimizer(name, params, config):
    """Instantiate the optimizer `name` over `params`.

    Args:
        name: one of 'SGD', 'SGDM', 'Adagrad', 'Adam', 'AdamW', 'DMAdam'.
        params: iterable of parameters to optimize (e.g. model.parameters()).
        config: dict of hyper-parameters; the required keys depend on `name`
            ('lr' always; 'momentum' for SGDM; 'betas' for Adam/AdamW;
            'eta0' for DMAdam; 'weight_decay'/'eps' are optional where .get()
            is used below).

    Returns:
        A constructed optimizer instance.

    Raises:
        ValueError: if `name` is not a recognized optimizer.
    """
    if name == 'SGD':
        return torch.optim.SGD(params, lr=config['lr'], weight_decay=config.get('weight_decay', 0))
    elif name == 'SGDM':
        return torch.optim.SGD(params, lr=config['lr'], momentum=config['momentum'], weight_decay=config.get('weight_decay', 0))
    elif name == 'Adagrad':
        return torch.optim.Adagrad(params, lr=config['lr'], weight_decay=config.get('weight_decay', 0))
    elif name == 'Adam':
        return torch.optim.Adam(params, lr=config['lr'], betas=config['betas'], weight_decay=config.get('weight_decay', 0))
    elif name == 'AdamW':
        return torch.optim.AdamW(params, lr=config['lr'], betas=config['betas'], weight_decay=config['weight_decay'])
    elif name == 'DMAdam':
        return DMAdam(params, lr=config['lr'], eta0=config['eta0'], eps=config.get('eps', 1e-8))
    # Previously an unrecognized name silently returned None, which surfaced
    # later as a confusing AttributeError on optimizer.zero_grad(); fail fast.
    raise ValueError(f"Unknown optimizer: {name!r}")
def run_cifar10(model_name='ResNet-34'):
    """Train CIFAR-10 with each benchmarked optimizer and plot test accuracy.

    For every optimizer in the comparison set, a fresh model is built and
    trained for 200 epochs with a MultiStepLR schedule; per-epoch test
    accuracies are collected and saved as a PNG under results/.

    Args:
        model_name: 'ResNet-34' or 'VGG-13'.

    Returns:
        dict mapping optimizer name -> list of 200 per-epoch test accuracies (%).

    Raises:
        ValueError: if `model_name` is not a supported architecture.
    """
    # Previously an unknown model name left `model` unbound (no `else` on the
    # per-optimizer if/elif) and crashed with a NameError deep inside the
    # training loop; validate up front instead.
    if model_name not in ('ResNet-34', 'VGG-13'):
        raise ValueError(f"Unsupported model_name: {model_name!r}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Standard CIFAR-10 augmentation + per-channel normalization constants.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    # DMAdam parameters from Table 2
    # VGG-13: η_k=0.9, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
    # ResNet-34: η_k=1.4, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
    dmadam_eta0 = 0.9 if model_name == 'VGG-13' else 1.4

    optimizers_config = {
        'SGD': {'lr': 0.1, 'weight_decay': 0.0005},
        'SGDM': {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0005},
        'Adagrad': {'lr': 0.01, 'weight_decay': 0.0005},
        'Adam': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
        'AdamW': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
        'DMAdam': {'lr': 0.001, 'eta0': dmadam_eta0, 'eps': 1e-8}
    }

    results = {}
    criterion = nn.CrossEntropyLoss()

    for opt_name, config in optimizers_config.items():
        print(f"\nTraining CIFAR-10 ({model_name}) with {opt_name}")

        # Fresh, independently initialized model per optimizer run.
        if model_name == 'ResNet-34':
            model = get_resnet34().to(device)
        else:
            model = get_vgg13().to(device)

        optimizer = create_optimizer(opt_name, model.parameters(), config)
        # Decay LR by 10x at epochs 100 and 150.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

        test_accs = []

        for epoch in range(200):
            train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
            test_acc = test(model, test_loader, device)
            test_accs.append(test_acc)
            scheduler.step()

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/200: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

        results[opt_name] = test_accs
        print(f"{opt_name} Final Test Acc: {test_accs[-1]:.2f}%")

    # Accuracy-vs-epoch comparison plot for all optimizers.
    plt.figure(figsize=(10, 5))
    for name, accs in results.items():
        plt.plot(range(1, 201), accs, label=name)
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy (%)')
    plt.title(f'CIFAR-10 Test Accuracy ({model_name})')
    plt.legend()
    plt.grid(True)
    os.makedirs('results', exist_ok=True)
    plt.savefig(f'results/cifar10_{model_name.lower().replace("-", "")}_accuracy.png', dpi=140, bbox_inches='tight')
    plt.close()

    return results
if __name__ == '__main__':
    # Run both architectures back to back, separated by a divider line.
    for run_index, arch in enumerate(('ResNet-34', 'VGG-13')):
        if run_index:
            print("\n" + "=" * 50)
        print(f"Running CIFAR-10 experiment with {arch}...")
        run_cifar10(arch)