update: add CIFAR-10 optimizer-comparison experiments (ResNet-34 and VGG-13)
This commit is contained in:
149
dl-main.py
Normal file
149
dl-main.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets, transforms, models
|
||||
import matplotlib.pyplot as plt
|
||||
from dmadam import DMAdam
|
||||
import os
|
||||
|
||||
# ResNet-34 for CIFAR-10 (adapted from torchvision)
|
||||
def get_resnet34():
    """Build a ResNet-34 adapted for 32x32 CIFAR-10 inputs.

    The stock torchvision stem (7x7 stride-2 conv followed by max-pool)
    discards too much spatial resolution on 32x32 images, so it is swapped
    for a 3x3 stride-1 conv and the pooling stage is disabled.
    """
    net = models.resnet34(weights=None)
    # CIFAR-sized stem: keep the full 32x32 resolution entering layer1.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()
    # 10 CIFAR-10 classes instead of the 1000 ImageNet classes.
    net.fc = nn.Linear(512, 10)
    return net
|
||||
|
||||
# VGG-13 for CIFAR-10 (adapted from torchvision)
|
||||
def get_vgg13():
    """Build a VGG-13 whose final classifier layer is resized for CIFAR-10."""
    net = models.vgg13(weights=None)
    # Replace the 1000-way ImageNet head with a 10-way CIFAR-10 head.
    net.classifier[6] = nn.Linear(4096, 10)
    return net
|
||||
|
||||
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over `loader`; return train accuracy in %.

    For every mini-batch: move data to `device`, compute the loss with
    `criterion`, backpropagate, and take one `optimizer` step. Accuracy is
    measured on the pre-step forward outputs of each batch.
    """
    model.train()
    n_correct, n_seen = 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        criterion(logits, labels).backward()
        optimizer.step()
        # Top-1 accuracy bookkeeping on this batch's (pre-update) logits.
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)
    return 100. * n_correct / n_seen
|
||||
|
||||
def test(model, loader, device):
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for data, target in loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
pred = output.argmax(dim=1)
|
||||
correct += pred.eq(target).sum().item()
|
||||
total += target.size(0)
|
||||
return 100. * correct / total
|
||||
|
||||
def create_optimizer(name, params, config):
    """Instantiate the optimizer `name` for `params` from a config dict.

    Args:
        name: one of 'SGD', 'SGDM', 'Adagrad', 'Adam', 'AdamW', 'DMAdam'.
        params: iterable of parameters (e.g. model.parameters()).
        config: hyper-parameter dict; required keys depend on `name`
            ('lr' always; 'momentum' for SGDM; 'betas' for Adam/AdamW;
            'weight_decay' for AdamW; 'eta0' for DMAdam).

    Returns:
        A configured torch optimizer (or DMAdam) instance.

    Raises:
        ValueError: if `name` is not a recognized optimizer name.
    """
    # Most optimizers treat weight_decay as optional with a default of 0.
    wd = config.get('weight_decay', 0)
    if name == 'SGD':
        return torch.optim.SGD(params, lr=config['lr'], weight_decay=wd)
    if name == 'SGDM':
        return torch.optim.SGD(params, lr=config['lr'], momentum=config['momentum'], weight_decay=wd)
    if name == 'Adagrad':
        return torch.optim.Adagrad(params, lr=config['lr'], weight_decay=wd)
    if name == 'Adam':
        return torch.optim.Adam(params, lr=config['lr'], betas=config['betas'], weight_decay=wd)
    if name == 'AdamW':
        # AdamW keeps its weight_decay mandatory in the config on purpose:
        # decoupled decay is the point of AdamW.
        return torch.optim.AdamW(params, lr=config['lr'], betas=config['betas'], weight_decay=config['weight_decay'])
    if name == 'DMAdam':
        return DMAdam(params, lr=config['lr'], eta0=config['eta0'], eps=config.get('eps', 1e-8))
    # Previously an unknown name silently returned None, which only surfaced
    # later as AttributeError on optimizer.zero_grad(); fail fast instead.
    raise ValueError(f"Unknown optimizer: {name!r}")
|
||||
|
||||
def run_cifar10(model_name='ResNet-34'):
    """Train CIFAR-10 with several optimizers and plot test-accuracy curves.

    For each optimizer in the comparison (SGD, SGDM, Adagrad, Adam, AdamW,
    DMAdam) a fresh `model_name` network is trained for 200 epochs with a
    step LR schedule (x0.1 at epochs 100 and 150). A test-accuracy plot is
    saved under results/.

    Args:
        model_name: 'ResNet-34' or 'VGG-13'.

    Returns:
        dict mapping optimizer name -> list of per-epoch test accuracies (%).

    Raises:
        ValueError: if `model_name` is not a supported architecture.
    """
    # Previously an unrecognized model_name only failed much later with a
    # NameError on `model` (after the datasets were downloaded); validate
    # up front instead.
    if model_name not in ('ResNet-34', 'VGG-13'):
        raise ValueError(f"Unknown model_name: {model_name!r}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Standard CIFAR-10 augmentation plus the usual per-channel
    # normalization statistics.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

    # DMAdam parameters from Table 2
    # VGG-13: η_k=0.9, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
    # ResNet-34: η_k=1.4, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
    dmadam_eta0 = 0.9 if model_name == 'VGG-13' else 1.4

    optimizers_config = {
        'SGD': {'lr': 0.1, 'weight_decay': 0.0005},
        'SGDM': {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0005},
        'Adagrad': {'lr': 0.01, 'weight_decay': 0.0005},
        'Adam': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
        'AdamW': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
        'DMAdam': {'lr': 0.001, 'eta0': dmadam_eta0, 'eps': 1e-8}
    }

    results = {}
    criterion = nn.CrossEntropyLoss()

    for opt_name, config in optimizers_config.items():
        print(f"\nTraining CIFAR-10 ({model_name}) with {opt_name}")

        # Fresh, independently initialized network for each optimizer so
        # the comparison is not contaminated by earlier training.
        if model_name == 'ResNet-34':
            model = get_resnet34().to(device)
        else:
            model = get_vgg13().to(device)

        optimizer = create_optimizer(opt_name, model.parameters(), config)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

        test_accs = []

        for epoch in range(200):
            train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
            test_acc = test(model, test_loader, device)
            test_accs.append(test_acc)
            scheduler.step()

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/200: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

        results[opt_name] = test_accs
        print(f"{opt_name} Final Test Acc: {test_accs[-1]:.2f}%")

    plt.figure(figsize=(10, 5))
    for name, accs in results.items():
        # Derive the x-axis from the recorded history instead of assuming
        # exactly 200 points (the old hard-coded range(1, 201) would break
        # if the epoch count ever changed).
        plt.plot(range(1, len(accs) + 1), accs, label=name)
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy (%)')
    plt.title(f'CIFAR-10 Test Accuracy ({model_name})')
    plt.legend()
    plt.grid(True)
    os.makedirs('results', exist_ok=True)
    plt.savefig(f'results/cifar10_{model_name.lower().replace("-", "")}_accuracy.png', dpi=140, bbox_inches='tight')
    plt.close()

    return results
|
||||
|
||||
if __name__ == '__main__':
    # Run the full optimizer comparison on both architectures in sequence.
    # NOTE: each call trains 6 optimizers for 200 epochs each and downloads
    # CIFAR-10 on first use — this is a long-running job.
    print("Running CIFAR-10 experiment with ResNet-34...")
    run_cifar10('ResNet-34')

    print("\n" + "="*50)
    print("Running CIFAR-10 experiment with VGG-13...")
    run_cifar10('VGG-13')
|
||||
Reference in New Issue
Block a user