Skip to content

Commit

Permalink
more plotting machinery
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixBenning committed May 22, 2024
1 parent be84ffa commit 5c0397b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 22 deletions.
54 changes: 37 additions & 17 deletions plot/plot_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from matplotlib import rc
import matplotlib.pyplot as plt
from sympy import plot


rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
Expand All @@ -17,8 +16,10 @@
)

PROBLEMS = [
# "MNIST_CNN7_b=1024",
"MNIST_CNN7_b=128",
"MNIST_CNN7_b=1024",
# "FashionMNIST_CNN5_b=128",
# "MNIST_AlgoPerf_b=128",
# "MNIST_CNN7_b=128",
# "MNIST_CNN3_b=128",
]

Expand Down Expand Up @@ -74,31 +75,42 @@ def extract_metrics(problem_dir, plot_filter):

# fmt: off
PLOT_FILTER = {
"RFD(SE)-Conservative": {
"includes": ["RFD", "SquaredExponential", "conservatism"],
"excludes": ["b_size_inv"],
},
"S-RFD(SE)-Conservative": {
"includes": ["RFD", "SquaredExponential", "b_size_inv", "conservatism"],
},
"RFD(SE)": {
"includes": ["RFD", "SquaredExponential"],
"excludes": ["b_size_inv"],
"excludes": ["b_size_inv", "conservatism"],
},
"S-RFD(SE)": {
"includes": ["RFD", "SquaredExponential", "b_size_inv"],
"excludes": ["conservatism"],
},
"RFD(RQ(beta=1))": {
"includes": ["RFD", "RationalQuadratic", "beta=1"],
},
"A-RFD": {
"includes": ["SGD", "lr=14.2"],
},
"Adam(lr=1e-2)": {
"includes": ["Adam", "lr=0.01"],
},
"SGD(lr=1)": {
"includes": ["SGD", "lr=1."],
},
"SGD(lr=0.1)": {
"includes": ["SGD", "lr=0.1"],
},
"Adam(lr=1e-2)": {
"includes": ["Adam", "lr=0.01"],
},
"Adam(lr=1e-3)": {
"includes": ["Adam", "lr=0.001"],
},
"Adam(lr=1e-4)": {
"includes": ["Adam", "lr=0.0001"],
},
}
# fmt: on
def plot_filter(wanted="all"):
Expand Down Expand Up @@ -239,17 +251,25 @@ def spread_lr(df):
return df


SummaryKeys = {
"MNIST_CNN7_b=1024": ["SGD(lr=1)", "Adam(lr=1e-2)"],
"FashionMNIST_CNN5_b=128": ["SGD(lr=0.1)", "Adam(lr=1e-4)", "Adam(lr=1e-3)"],
"MNIST_CNN7_b=128": ["SGD(lr=0.1)", "Adam(lr=1e-3)"],
"MNIST_CNN3_b=128": ["SGD(lr=0.1)", "Adam(lr=1e-3)"],
"MNIST_AlgoPerf_b=128": ["SGD(lr=1)", "Adam(lr=1e-3)"],
}

def plot_summary(problem):
problem_dir = Path(f"logs/{problem}")
wanted = ["RFD(SE)", "S-RFD(SE)", "RFD(RQ(beta=1))", "A-RFD"]
wanted = ["RFD(SE)", "S-RFD(SE)", "RFD(RQ(beta=1))", "A-RFD", "RFD(SE)-Conservative", "S-RFD(SE)-Conservative"]
metrics = extract_metrics(problem_dir, plot_filter(wanted))
(fig, axs) = plt.subplots(2, 2, figsize=(9, 7))
(fig, axs) = plt.subplots(2, 2, figsize=(9, 6))
for (idx, item) in enumerate(metrics):
plot_validation_loss(axs[0,0], item, idx=idx)
plot_initial_learning_rate(axs[1,0], item, idx=idx)
plot_step_size(axs[1,1], item, idx=idx)

metrics = extract_metrics(problem_dir, plot_filter(["Adam(lr=1e-3)", "SGD(lr=0.1)"]))
metrics = extract_metrics(problem_dir, plot_filter(SummaryKeys.get(problem, [])))
for (idx, item) in enumerate(metrics, start=len(wanted)):
plot_validation_loss(axs[0,0], item, idx=idx)
plot_step_size(axs[1,1], item, idx=idx)
Expand All @@ -260,8 +280,8 @@ def plot_summary(problem):
sgd_metrics = extract_metrics(problem_dir, {"SGD": {"includes": ["SGD"]}})
sgd_joined = pd.concat([spread_lr(x["metrics"]) for x in sgd_metrics])

plot_final_loss_over_lr(axs[0, 1], {"name": "Adam", "metrics": adam_joined}, idx=len(wanted))
plot_final_loss_over_lr(axs[0, 1], {"name": "SGD", "metrics": sgd_joined}, idx=len(wanted)+1)
plot_final_loss_over_lr(axs[0, 1], {"name": "SGD", "metrics": sgd_joined}, idx=len(wanted))
plot_final_loss_over_lr(axs[0, 1], {"name": "Adam", "metrics": adam_joined}, idx=len(wanted)+1)
axs[0,1].legend()

fig.tight_layout()
Expand All @@ -270,7 +290,7 @@ def plot_summary(problem):
def plot_performance(problem):
problem_dir = Path(f"logs/{problem}")
metrics = extract_metrics(problem_dir, PLOT_FILTER)
(fig, axs) = plt.subplots(2, 2, figsize=(9, 6))
(fig, axs) = plt.subplots(2, 2, figsize=(9, 5))
for (idx, item) in enumerate(metrics):
plot_validation_loss(axs[0, 0], item, idx=idx)
plot_train_loss(axs[1, 0], item, idx=idx)
Expand All @@ -283,8 +303,8 @@ def plot_performance(problem):
sgd_metrics = extract_metrics(problem_dir, {"SGD": {"includes": ["SGD"]}})
sgd_joined = pd.concat([spread_lr(x["metrics"]) for x in sgd_metrics])

plot_final_loss_over_lr(axs[1, 1], {"name": "Adam", "metrics": adam_joined}, idx=4)
plot_final_loss_over_lr(axs[1, 1], {"name": "SGD", "metrics": sgd_joined}, idx=5)
plot_final_loss_over_lr(axs[1, 1], {"name": "SGD", "metrics": sgd_joined}, idx=4)
plot_final_loss_over_lr(axs[1, 1], {"name": "Adam", "metrics": adam_joined}, idx=5)
axs[1,1].legend()

fig.tight_layout()
Expand Down Expand Up @@ -318,8 +338,8 @@ def plot_step_behavior(problem):
def main():
for problem in PROBLEMS:
plot_summary(problem)
plot_performance(problem)
plot_step_behavior(problem)
# plot_performance(problem)
# plot_step_behavior(problem)


if __name__ == "__main__":
Expand Down
24 changes: 19 additions & 5 deletions plot/visualize_covariance_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def visualize_fit():
fig.tight_layout()
plt.savefig("plot/MNIST_CNN7_covariance_fit.pdf")

def get_run_covariance(run):
def get_run_covariance_and_cost(run):
mnist = MNIST(batch_size=100)
mnist.prepare_data()
mnist.setup("fit")
Expand All @@ -67,18 +67,32 @@ def get_run_covariance(run):
)
cov_model.dims = sampler.dims
cov_model.fit(sampler.snapshot_as_dataframe())
return cov_model
return cov_model, sampler.sample_cost

def visualize_lr_variance():
asympt_lr = []
sample_costs = []
for run in range(20):
cov_model = get_run_covariance(run)
cov_model, sample_cost = get_run_covariance_and_cost(run)
if sample_cost > 100000:
print(f"run: {run}, sample_cost: {sample_cost}, asymptotic_lr: {cov_model.asymptotic_learning_rate()}")
asympt_lr.append(cov_model.asymptotic_learning_rate())
sample_costs.append(sample_cost)

print(asympt_lr)
plt.hist(asympt_lr, label="Asymptotic learning rate")
# print(asympt_lr)
plt.figure(figsize=(4, 3))
plt.hist(asympt_lr)
plt.xlabel("Asymptotic learning rate")
plt.tight_layout()
plt.savefig("plot/MNIST_CNN7_asymptotic_lr.pdf")

print(f"less_one_epoch: {len([cost for cost in sample_costs if cost < 60_000])}")
plt.figure(figsize=(4, 3))
plt.hist(sample_costs, label="Sample cost")
plt.xlabel("Sample cost")
plt.tight_layout()
plt.savefig("plot/MNIST_CNN7_sample_cost.pdf")


if __name__ == "__main__":
visualize_lr_variance()
Expand Down

0 comments on commit 5c0397b

Please sign in to comment.