"""Compare optimizers (SGD, SGDM, Adagrad, Adam, AdamW, DMAdam) on CIFAR-10.

Trains ResNet-34 and VGG-13 variants adapted for 32x32 inputs with each
optimizer for 200 epochs and saves a plot of per-epoch test accuracy.
"""

import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: F401 (kept from original imports)
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

from dmadam import DMAdam


def get_resnet34():
    """Return a torchvision ResNet-34 adapted for CIFAR-10 (32x32, 10 classes).

    The 7x7 stride-2 stem and the max-pool are replaced so early layers do not
    downsample the small 32x32 inputs too aggressively.
    """
    model = models.resnet34(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(512, 10)
    return model


def get_vgg13():
    """Return a torchvision VGG-13 with its final classifier resized to 10 classes."""
    model = models.vgg13(weights=None)
    model.classifier[6] = nn.Linear(4096, 10)
    return model


def train_epoch(model, loader, optimizer, criterion, device):
    """Run one training epoch and return the training accuracy in percent."""
    model.train()
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
    return 100. * correct / total


def test(model, loader, device):
    """Evaluate `model` on `loader` and return the accuracy in percent."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
    return 100. * correct / total


def create_optimizer(name, params, config):
    """Build the optimizer `name` over `params` using hyperparameters in `config`.

    Raises:
        ValueError: if `name` is not a recognized optimizer.
    """
    if name == 'SGD':
        return torch.optim.SGD(params, lr=config['lr'],
                               weight_decay=config.get('weight_decay', 0))
    elif name == 'SGDM':
        return torch.optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                               weight_decay=config.get('weight_decay', 0))
    elif name == 'Adagrad':
        return torch.optim.Adagrad(params, lr=config['lr'],
                                   weight_decay=config.get('weight_decay', 0))
    elif name == 'Adam':
        return torch.optim.Adam(params, lr=config['lr'], betas=config['betas'],
                                weight_decay=config.get('weight_decay', 0))
    elif name == 'AdamW':
        return torch.optim.AdamW(params, lr=config['lr'], betas=config['betas'],
                                 weight_decay=config['weight_decay'])
    elif name == 'DMAdam':
        return DMAdam(params, lr=config['lr'], eta0=config['eta0'],
                      eps=config.get('eps', 1e-8))
    # Fail loudly instead of returning None (which would surface later as a
    # confusing AttributeError on optimizer.step()).
    raise ValueError(f"Unknown optimizer: {name}")


def run_cifar10(model_name='ResNet-34'):
    """Train `model_name` on CIFAR-10 with each configured optimizer.

    Args:
        model_name: 'ResNet-34' or 'VGG-13'.

    Returns:
        dict mapping optimizer name -> list of per-epoch test accuracies.

    Raises:
        ValueError: if `model_name` is not supported.

    Side effects: downloads CIFAR-10 into ./data and writes an accuracy plot
    under ./results.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    train_dataset = datasets.CIFAR10('./data', train=True, download=True,
                                     transform=transform_train)
    test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                              num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False,
                             num_workers=2)

    # DMAdam parameters from Table 2
    # VGG-13:    η_k=0.9, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
    # ResNet-34: η_k=1.4, α_k=0.001, β_k=0.999, ε=10^-8, weight_decay=0.0005
    dmadam_eta0 = 0.9 if model_name == 'VGG-13' else 1.4

    optimizers_config = {
        'SGD': {'lr': 0.001, 'weight_decay': 0.0005},
        'SGDM': {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.0005},
        'Adagrad': {'lr': 0.001, 'weight_decay': 0.0005},
        'Adam': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
        'AdamW': {'lr': 0.001, 'betas': (0.9, 0.999), 'weight_decay': 0.0005},
        'DMAdam': {'lr': 0.001, 'eta0': dmadam_eta0, 'eps': 1e-8}
    }

    results = {}
    criterion = nn.CrossEntropyLoss()

    for opt_name, config in optimizers_config.items():
        print(f"\nTraining CIFAR-10 ({model_name}) with {opt_name}")

        # A fresh model per optimizer so runs are independent.
        if model_name == 'ResNet-34':
            model = get_resnet34().to(device)
        elif model_name == 'VGG-13':
            model = get_vgg13().to(device)
        else:
            # Without this guard, `model` would be unbound and fail later
            # with an unhelpful NameError.
            raise ValueError(f"Unknown model: {model_name}")

        optimizer = create_optimizer(opt_name, model.parameters(), config)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], gamma=0.1)

        test_accs = []
        for epoch in range(200):
            train_acc = train_epoch(model, train_loader, optimizer, criterion,
                                    device)
            test_acc = test(model, test_loader, device)
            test_accs.append(test_acc)
            scheduler.step()
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/200: Train Acc={train_acc:.2f}%, "
                      f"Test Acc={test_acc:.2f}%")

        results[opt_name] = test_accs
        print(f"{opt_name} Final Test Acc: {test_accs[-1]:.2f}%")

    plt.figure(figsize=(10, 5))
    for name, accs in results.items():
        plt.plot(range(1, 201), accs, label=name)
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy (%)')
    plt.title(f'CIFAR-10 Test Accuracy ({model_name})')
    plt.legend()
    plt.grid(True)
    os.makedirs('results', exist_ok=True)
    plt.savefig(f'results/cifar10_{model_name.lower().replace("-", "")}_accuracy.png',
                dpi=140, bbox_inches='tight')
    plt.close()
    return results


if __name__ == '__main__':
    print("Running CIFAR-10 experiment with ResNet-34...")
    run_cifar10('ResNet-34')
    print("\n" + "="*50)
    print("Running CIFAR-10 experiment with VGG-13...")
    run_cifar10('VGG-13')