Files
benchmarkexp/dl-main.py

150 lines
5.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from dmadam import DMAdam
import os
# ResNet-34 for CIFAR-10 (adapted from torchvision)
def get_resnet34():
    """Build an untrained ResNet-34 adapted for 32x32 CIFAR-10 images.

    The ImageNet 7x7 stride-2 stem is swapped for a 3x3 stride-1 conv and
    the stem max-pool is removed, so early layers keep full 32x32 resolution.
    The final fully-connected layer is resized to 10 classes.
    """
    net = models.resnet34(weights=None)
    # 3x3 stride-1 stem: avoids throwing away resolution on tiny inputs.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # Drop the stem max-pool entirely.
    net.maxpool = nn.Identity()
    # 10-way classification head for CIFAR-10.
    net.fc = nn.Linear(512, 10)
    return net
# VGG-13 for CIFAR-10 (adapted from torchvision)
def get_vgg13():
    """Build an untrained VGG-13 whose final classifier layer outputs 10 classes."""
    net = models.vgg13(weights=None)
    # Replace only the last classifier layer (1000 -> 10 classes).
    net.classifier[6] = nn.Linear(4096, 10)
    return net
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one full optimization pass over *loader*.

    Performs forward/backward/step per batch and tracks top-1 accuracy of the
    (pre-step) predictions. Returns the training accuracy as a percentage.
    """
    model.train()
    n_correct, n_seen = 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        criterion(logits, labels).backward()
        optimizer.step()
        # Accuracy is measured on the logits produced before this step.
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)
    return 100. * n_correct / n_seen
def test(model, loader, device):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in loader:
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
return 100. * correct / total
def create_optimizer(name, params, config):
    """Instantiate the optimizer called *name* over *params*.

    Args:
        name: one of 'SGD', 'SGDM', 'Adagrad', 'Adam', 'AdamW', 'DMAdam'.
        params: iterable of model parameters to optimize.
        config: dict of hyper-parameters; 'weight_decay' and 'eps' fall back
            to defaults when absent (except AdamW, which requires
            'weight_decay' explicitly, matching the benchmark configs).

    Returns:
        A configured optimizer instance.

    Raises:
        ValueError: if *name* is not a recognized optimizer name.
        KeyError: if a required hyper-parameter is missing from *config*.
    """
    wd = config.get('weight_decay', 0)
    if name == 'SGD':
        return torch.optim.SGD(params, lr=config['lr'], weight_decay=wd)
    elif name == 'SGDM':
        return torch.optim.SGD(params, lr=config['lr'], momentum=config['momentum'], weight_decay=wd)
    elif name == 'Adagrad':
        return torch.optim.Adagrad(params, lr=config['lr'], weight_decay=wd)
    elif name == 'Adam':
        return torch.optim.Adam(params, lr=config['lr'], betas=config['betas'], weight_decay=wd)
    elif name == 'AdamW':
        return torch.optim.AdamW(params, lr=config['lr'], betas=config['betas'], weight_decay=config['weight_decay'])
    elif name == 'DMAdam':
        return DMAdam(params, lr=config['lr'], eta0=config['eta0'], eps=config.get('eps', 1e-8))
    # Previously an unknown name silently returned None, which only surfaced
    # later as an opaque AttributeError on optimizer.step(); fail loudly here.
    raise ValueError(f"Unknown optimizer name: {name!r}")
def run_cifar10(model_name='ResNet-34'):
    """Benchmark every configured optimizer on CIFAR-10 with *model_name*.

    Trains a fresh model for 200 epochs per optimizer, records per-epoch test
    accuracy, and saves a comparison plot under ./results/.

    Args:
        model_name: 'ResNet-34' (default) or 'VGG-13'.

    Returns:
        dict mapping optimizer name -> list of 200 per-epoch test accuracies (%).

    Raises:
        ValueError: if *model_name* is not a supported architecture.

    Side effects: downloads CIFAR-10 into ./data and writes a PNG to ./results/.
    """
    # Validate up front: previously an unknown model_name left `model` unbound
    # and crashed with a NameError only after the dataset download.
    if model_name not in ('ResNet-34', 'VGG-13'):
        raise ValueError(f"Unsupported model_name: {model_name!r}")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Standard CIFAR-10 augmentation + channel statistics.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
    # DMAdam parameters from Table 2
    # VGG-13: η_k=0.9, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
    # ResNet-34: η_k=1.4, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
    dmadam_eta0 = 0.9 if model_name == 'VGG-13' else 1.4
    optimizers_config = {
        'SGD': {'lr': 0.001, 'weight_decay': 0.0005},
        'SGDM': {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.0005},
        'Adagrad': {'lr': 0.001, 'weight_decay': 0.0005},
        'Adam': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
        'AdamW': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
        'DMAdam': {'lr': 0.001, 'eta0': dmadam_eta0, 'eps': 1e-8}
    }
    results = {}
    criterion = nn.CrossEntropyLoss()
    for opt_name, config in optimizers_config.items():
        print(f"\nTraining CIFAR-10 ({model_name}) with {opt_name}")
        # Fresh model per optimizer so runs start from independent initializations.
        if model_name == 'ResNet-34':
            model = get_resnet34().to(device)
        else:
            model = get_vgg13().to(device)
        optimizer = create_optimizer(opt_name, model.parameters(), config)
        # Step LR down by 10x at epochs 100 and 150.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        test_accs = []
        for epoch in range(200):
            train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
            test_acc = test(model, test_loader, device)
            test_accs.append(test_acc)
            scheduler.step()
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/200: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")
        results[opt_name] = test_accs
        print(f"{opt_name} Final Test Acc: {test_accs[-1]:.2f}%")
    # One accuracy-vs-epoch curve per optimizer.
    plt.figure(figsize=(10, 5))
    for name, accs in results.items():
        plt.plot(range(1, 201), accs, label=name)
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy (%)')
    plt.title(f'CIFAR-10 Test Accuracy ({model_name})')
    plt.legend()
    plt.grid(True)
    os.makedirs('results', exist_ok=True)
    plt.savefig(f'results/cifar10_{model_name.lower().replace("-", "")}_accuracy.png', dpi=140, bbox_inches='tight')
    plt.close()
    return results
if __name__ == '__main__':
    # Run both architecture benchmarks back-to-back, separated by a rule.
    for idx, arch in enumerate(('ResNet-34', 'VGG-13')):
        if idx:
            print("\n" + "=" * 50)
        print(f"Running CIFAR-10 experiment with {arch}...")
        run_cifar10(arch)