173 lines
6.3 KiB
Python
173 lines
6.3 KiB
Python
import os
|
|
import torch
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib import cm
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
|
|
from dmadam import DMAdam
|
|
|
|
# Set the default resolution for every figure created by this script to 140 DPI.
plt.rcParams['figure.dpi'] = 140
|
|
|
|
def _sphere_value(x, y):
    """Evaluate the surface registered as "Sphere": x^2 + y^2."""
    return x**2 + y**2


def _booth_value(x, y):
    """Evaluate the surface registered as "Booth"."""
    return (x + 2*y - 7)**2 + (2*x + y - 5)**2


def _matyas_value(x, y):
    """Evaluate the surface registered as "Matyas"."""
    return 0.26*(x**2 + y**2) - 0.48*x*y


def _beale_value(x, y):
    """Evaluate the surface registered as "Beale" (sum of three squared residuals)."""
    return (
        (1.5 - x + x*y)**2
        + (2.25 - x + x*y**2)**2
        + (2.625 - x + x*y**3)**2
    )


def _goldstein_price_value(x, y):
    """Evaluate the surface registered as "Goldstein-Price" (product of two factors)."""
    first_factor = 1 + (x + y + 1)**2 * (19 - 14*x + 3*x**2 - 14*y + 6*x*y + 3*y**2)
    second_factor = 30 + (2*x - 3*y)**2 * (18 - 32*x + 12*x**2 + 48*y - 36*x*y + 27*y**2)
    return first_factor * second_factor


def _bukin_value(x, y):
    """Evaluate the surface registered as "Bukin" (square-root term plus linear term)."""
    root_part = 100 * (abs(y - 0.01 * x ** 2)) ** 0.5
    linear_part = 0.01 * abs(x + 10)
    return root_part + linear_part


# Maps each supported surface name to the routine that evaluates it.
_VALUE_EVALUATORS = {
    "Sphere": _sphere_value,
    "Booth": _booth_value,
    "Matyas": _matyas_value,
    "Beale": _beale_value,
    "Goldstein-Price": _goldstein_price_value,
    "Bukin": _bukin_value,
}


def get_function_val(name, x, y):
    """Evaluate the named 2-D test surface at ``(x, y)``.

    The formulas use only arithmetic operators and the builtin ``abs``, so
    ``x`` and ``y`` may be plain numbers or anything that supports those
    operations element-wise (the plotting code passes NumPy grids).

    Args:
        name: One of "Sphere", "Booth", "Matyas", "Beale",
            "Goldstein-Price" or "Bukin".
        x: First coordinate(s).
        y: Second coordinate(s).

    Returns:
        The surface value at ``(x, y)``, or the integer ``0`` when ``name``
        is not one of the supported surfaces.
    """
    evaluator = _VALUE_EVALUATORS.get(name)
    if evaluator is None:
        # Unknown names silently evaluate to 0 rather than raising.
        return 0
    return evaluator(x, y)
|
|
|
|
def _torch_sphere(x):
    """"Sphere" objective: sum of the squares of every entry of ``x``."""
    return torch.sum(x ** 2)


def _torch_booth(x):
    """"Booth" objective on the first two entries of ``x``."""
    return (x[0] + 2*x[1] - 7)**2 + (2*x[0] + x[1] - 5)**2


def _torch_matyas(x):
    """"Matyas" objective on the first two entries of ``x``."""
    return 0.26*(x[0]**2 + x[1]**2) - 0.48*x[0]*x[1]


def _torch_beale(x):
    """"Beale" objective on the first two entries of ``x``."""
    return (
        (1.5 - x[0] + x[0]*x[1])**2
        + (2.25 - x[0] + x[0]*x[1]**2)**2
        + (2.625 - x[0] + x[0]*x[1]**3)**2
    )


def _torch_goldstein_price(x):
    """"Goldstein-Price" objective on the first two entries of ``x``."""
    a, b = x[0], x[1]
    first_factor = 1 + (a + b + 1)**2 * (19 - 14*a + 3*a**2 - 14*b + 6*a*b + 3*b**2)
    second_factor = 30 + (2*a - 3*b)**2 * (18 - 32*a + 12*a**2 + 48*b - 36*a*b + 27*b**2)
    return first_factor * second_factor


def _torch_bukin(x):
    """"Bukin" objective on the first two entries of ``x``."""
    a, b = x[0], x[1]
    root_part = 100 * torch.sqrt(torch.abs(b - 0.01 * a ** 2))
    linear_part = 0.01 * torch.abs(a + 10)
    return root_part + linear_part


# Maps each supported objective name to its torch implementation.
_TORCH_OBJECTIVES = {
    "Sphere": _torch_sphere,
    "Booth": _torch_booth,
    "Matyas": _torch_matyas,
    "Beale": _torch_beale,
    "Goldstein-Price": _torch_goldstein_price,
    "Bukin": _torch_bukin,
}


def get_torch_func(name):
    """Look up the torch implementation of a named test objective.

    Args:
        name: One of "Sphere", "Booth", "Matyas", "Beale",
            "Goldstein-Price" or "Bukin".

    Returns:
        A callable taking a single indexable tensor ``x`` (coordinates are
        read as ``x[0]`` and ``x[1]``; "Sphere" sums over all entries) and
        returning the objective value as a tensor, or ``None`` when
        ``name`` is not supported.
    """
    return _TORCH_OBJECTIVES.get(name)
|
|
|
|
def run_trajectory(optimizer_name, func_name, start_point, iterations, lr, eta0):
    """Minimise a named test objective and record the path the optimizer takes.

    Args:
        optimizer_name: 'DMAdam', 'Adam', 'SGD', 'SGDM' (SGD with
            momentum 0.9) or 'Adagrad'.
        func_name: Objective name understood by ``get_torch_func``.
        start_point: Initial ``[x, y]`` coordinates.
        iterations: Number of optimizer steps to take.
        lr: Learning rate handed to the optimizer.
        eta0: Extra hyper-parameter forwarded to ``DMAdam`` only; the other
            optimizers ignore it.

    Returns:
        A NumPy array with ``iterations + 1`` rows of ``[x, y, f(x, y)]``:
        the start point followed by the position after every step, each
        paired with the objective value at that same position.

    Raises:
        ValueError: If ``func_name`` or ``optimizer_name`` is not recognised.
    """
    func = get_torch_func(func_name)
    if func is None:
        # get_torch_func signals an unsupported name with None; fail here
        # with a clear message instead of a "NoneType is not callable" later.
        raise ValueError(f"Unknown benchmark function: {func_name!r}")

    params = torch.tensor(start_point, requires_grad=True, dtype=torch.float32)

    if optimizer_name == 'DMAdam':
        opt = DMAdam([params], lr=lr, eta0=eta0)
    elif optimizer_name == 'Adam':
        opt = torch.optim.Adam([params], lr=lr)
    elif optimizer_name == 'SGD':
        opt = torch.optim.SGD([params], lr=lr)
    elif optimizer_name == 'SGDM':
        opt = torch.optim.SGD([params], lr=lr, momentum=0.9)
    elif optimizer_name == 'Adagrad':
        opt = torch.optim.Adagrad([params], lr=lr)
    else:
        # Previously this fell through and left ``opt`` unbound.
        raise ValueError(f"Unknown optimizer: {optimizer_name!r}")

    with torch.no_grad():
        path = [[start_point[0], start_point[1], func(params).item()]]

    for _ in range(iterations):
        opt.zero_grad()
        loss = func(params)
        loss.backward()
        opt.step()

        # ``loss`` was computed at the position *before* the step. Re-evaluate
        # at the updated position so every recorded (x, y, z) triple is
        # consistent and lies on the objective's surface.
        with torch.no_grad():
            p_val = params.detach().cpu().numpy()
            path.append([p_val[0], p_val[1], func(params).item()])

    return np.array(path)
|
|
|
|
def plot_3d_surface(ax, func_name, x_range, y_range, trajectories, global_min, view_angle):
    """Draw a test surface, the optimizer paths on it, and the global-minimum marker.

    Args:
        ax: 3-D matplotlib axes to draw on.
        func_name: Surface name understood by ``get_function_val``.
        x_range: ``(low, high)`` extent of the x axis.
        y_range: ``(low, high)`` extent of the y axis.
        trajectories: Mapping from optimizer name ('SGD', 'Adagrad', 'SGDM',
            'Adam', 'DMAdam') to an array whose columns are x, y and z.
        global_min: ``(x, y)`` location at which the star marker is drawn.
        view_angle: ``(elevation, azimuth)`` passed to ``ax.view_init``.
    """
    # Sample the surface on a 150 x 150 grid.
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_range[0], x_range[1], 150),
        np.linspace(y_range[0], y_range[1], 150),
    )
    heights = get_function_val(func_name, grid_x, grid_y)

    # Heights are clipped to [z_floor, z_ceiling]; only these two surfaces
    # use a fixed ceiling, every other one uses the sampled maximum.
    z_floor = np.min(heights)
    fixed_ceilings = {"Beale": 4000, "Booth": 800}
    if func_name in fixed_ceilings:
        z_ceiling = fixed_ceilings[func_name]
    else:
        z_ceiling = np.max(heights)

    ax.plot_surface(
        grid_x,
        grid_y,
        np.clip(heights, z_floor, z_ceiling),
        cmap=cm.viridis,
        alpha=0.7,
        linewidth=0,
        antialiased=True,
    )

    # (colour, line style) for each optimizer's path.
    line_formats = {
        'SGD': ('red', ':'),
        'Adagrad': ('orange', ':'),
        'SGDM': ('green', '--'),
        'Adam': ('blue', '--'),
        'DMAdam': ('black', '-'),
    }

    for name, path in trajectories.items():
        colour, dash = line_formats[name]
        # Clip path heights to the same band as the surface.
        clipped_z = np.clip(path[:, 2], z_floor, z_ceiling)

        ax.plot(
            path[:, 0],
            path[:, 1],
            clipped_z,
            color=colour,
            label=name,
            linewidth=2.5,
            linestyle=dash,
            zorder=10,
        )
        # Dot at the final recorded position.
        ax.scatter(
            path[-1, 0],
            path[-1, 1],
            clipped_z[-1],
            c=colour,
            marker='o',
            s=50,
            zorder=11,
        )

    # Star at the supplied minimum location, lifted to the sampled floor
    # when the surface value there is lower than any grid sample.
    min_x, min_y = global_min[0], global_min[1]
    true_z = get_function_val(func_name, np.array(min_x), np.array(min_y))
    ax.scatter(
        min_x,
        min_y,
        max(true_z, z_floor),
        c='red',
        marker='*',
        s=300,
        label='Global Min',
        zorder=20,
        edgecolors='white',
    )

    ax.set_title(func_name, fontsize=14, pad=0)

    elevation, azimuth = view_angle[0], view_angle[1]
    ax.view_init(elev=elevation, azim=azimuth)
    ax.set_xlim(x_range)
    ax.set_ylim(y_range)
    ax.set_zlim(z_floor, z_ceiling)
    ax.set_zticks([])
|
|
|
|
def main():
    """Run every optimizer on each benchmark surface and save one figure per surface.

    For each benchmark, trajectories are computed for SGD, Adagrad, SGDM,
    Adam and DMAdam, drawn on a 3-D plot of the surface, and written to
    ``results/<name>.png``.
    """
    iterations = 2000

    # Learning rates. SGD and SGDM use their own smaller rate; the remaining
    # optimizers share a rate that can be overridden per benchmark.
    lr = 0.1
    lr_sgd = 0.01
    lr_sgdm = 0.01
    lr_gp = 0.1
    lr_bukin = 0.1  # not referenced while the Bukin benchmark is disabled

    # eta0 values (forwarded to run_trajectory) per benchmark family.
    eta0_normal = 3
    eta0_gp = 8
    eta0_bukin = 15  # only referenced by the disabled Bukin entry below

    # Each entry: (name, start point, global minimum, x range, y range,
    #              (elevation, azimuth), eta0)
    benchmarks = [
        ("Sphere", [-3.0, 4.0], (0, 0), (-4, 4), (-4, 4), (60, -45), eta0_normal),
        ("Booth", [-8.0, 8.0], (1, 3), (-10, 5), (-2, 10), (50, -140), eta0_normal),
        ("Matyas", [4.0, -4.0], (0, 0), (-6, 6), (-6, 6), (40, 45), eta0_normal),
        ("Beale", [2.0, 2.0], (3, 0.5), (-4, 4), (-4, 4), (50, -145), eta0_normal),
        ("Goldstein-Price", [-0.5, 1.5], (0, -1), (-2, 2), (-2, 2), (30, -60), eta0_gp),
        # Disabled:
        # ("Bukin", [-13, 2.5], (-10, 1), (-15, -3), (-3, 6), (35, -50), eta0_bukin)
    ]

    os.makedirs('results', exist_ok=True)

    for name, start_pt, glob_min, x_rng, y_rng, view, eta0 in benchmarks:
        print(f"Running {name} (Iter: {iterations}, Eta0: {eta0})...")

        if name == "Goldstein-Price":
            cur_lr = lr_gp
        else:
            cur_lr = lr

        # Order matters: it fixes both the run order and the legend order.
        optimizer_rates = [
            ('SGD', lr_sgd),
            ('Adagrad', cur_lr),
            ('SGDM', lr_sgdm),
            ('Adam', cur_lr),
            ('DMAdam', cur_lr),
        ]
        trajectories = {
            opt_name: run_trajectory(opt_name, name, start_pt, iterations, opt_lr, eta0)
            for opt_name, opt_lr in optimizer_rates
        }

        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')
        plot_3d_surface(ax, name, x_rng, y_rng, trajectories, glob_min, view)
        ax.legend(loc='upper right', prop={'size': 9})

        plt.tight_layout()
        plt.savefig(f'results/{name}.png', dpi=140, bbox_inches='tight')
        # Close the figure so figures do not accumulate across benchmarks.
        plt.close(fig)
|
|
|
|
|
|
# Run the benchmark sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|