update: experiments on CIFAR10

This commit is contained in:
2026-02-05 15:01:27 +08:00
parent 9966697a7d
commit 06cc14c47a
10 changed files with 317 additions and 233 deletions

149
dl-main.py Normal file
View File

@@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from dmadam import DMAdam
import os
# ResNet-34 for CIFAR-10 (adapted from torchvision)
def get_resnet34():
model = models.resnet34(weights=None)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
model.fc = nn.Linear(512, 10)
return model
# VGG-13 for CIFAR-10 (adapted from torchvision)
def get_vgg13():
model = models.vgg13(weights=None)
model.classifier[6] = nn.Linear(4096, 10)
return model
def train_epoch(model, loader, optimizer, criterion, device):
model.train()
correct = 0
total = 0
for data, target in loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
return 100. * correct / total
def test(model, loader, device):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in loader:
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
return 100. * correct / total
def create_optimizer(name, params, config):
if name == 'SGD':
return torch.optim.SGD(params, lr=config['lr'], weight_decay=config.get('weight_decay', 0))
elif name == 'SGDM':
return torch.optim.SGD(params, lr=config['lr'], momentum=config['momentum'], weight_decay=config.get('weight_decay', 0))
elif name == 'Adagrad':
return torch.optim.Adagrad(params, lr=config['lr'], weight_decay=config.get('weight_decay', 0))
elif name == 'Adam':
return torch.optim.Adam(params, lr=config['lr'], betas=config['betas'], weight_decay=config.get('weight_decay', 0))
elif name == 'AdamW':
return torch.optim.AdamW(params, lr=config['lr'], betas=config['betas'], weight_decay=config['weight_decay'])
elif name == 'DMAdam':
return DMAdam(params, lr=config['lr'], eta0=config['eta0'], eps=config.get('eps', 1e-8))
def run_cifar10(model_name='ResNet-34'):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
# DMAdam parameters from Table 2
# VGG-13: η_k=0.9, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
# ResNet-34: η_k=1.4, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
dmadam_eta0 = 0.9 if model_name == 'VGG-13' else 1.4
optimizers_config = {
'SGD': {'lr': 0.1, 'weight_decay': 0.0005},
'SGDM': {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0005},
'Adagrad': {'lr': 0.01, 'weight_decay': 0.0005},
'Adam': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
'AdamW': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
'DMAdam': {'lr': 0.001, 'eta0': dmadam_eta0, 'eps': 1e-8}
}
results = {}
criterion = nn.CrossEntropyLoss()
for opt_name, config in optimizers_config.items():
print(f"\nTraining CIFAR-10 ({model_name}) with {opt_name}")
if model_name == 'ResNet-34':
model = get_resnet34().to(device)
elif model_name == 'VGG-13':
model = get_vgg13().to(device)
optimizer = create_optimizer(opt_name, model.parameters(), config)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
test_accs = []
for epoch in range(200):
train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
test_acc = test(model, test_loader, device)
test_accs.append(test_acc)
scheduler.step()
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/200: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")
results[opt_name] = test_accs
print(f"{opt_name} Final Test Acc: {test_accs[-1]:.2f}%")
plt.figure(figsize=(10, 5))
for name, accs in results.items():
plt.plot(range(1, 201), accs, label=name)
plt.xlabel('Epoch')
plt.ylabel('Test Accuracy (%)')
plt.title(f'CIFAR-10 Test Accuracy ({model_name})')
plt.legend()
plt.grid(True)
os.makedirs('results', exist_ok=True)
plt.savefig(f'results/cifar10_{model_name.lower().replace("-", "")}_accuracy.png', dpi=140, bbox_inches='tight')
plt.close()
return results
if __name__ == '__main__':
print("Running CIFAR-10 experiment with ResNet-34...")
run_cifar10('ResNet-34')
print("\n" + "="*50)
print("Running CIFAR-10 experiment with VGG-13...")
run_cifar10('VGG-13')