Skip to content

Commit

Permalink
Massage for plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
liujch1998 committed Nov 13, 2024
1 parent dca97a7 commit a21992b
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 130 deletions.
4 changes: 2 additions & 2 deletions olmo/scaling/scaling_laws/fitting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ def grad_tissue_fit(x, p):
return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_F, grad_r]


def sigmoid(x, L, x0, k, b):
o = L / (1 + np.exp(-k * (x - x0))) + b
def sigmoid(x, a, x0, k, b):
o = a / (1 + np.exp(-k * (x - x0))) + b
return o


Expand Down
14 changes: 7 additions & 7 deletions olmo/scaling/scaling_laws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def get_accuracy_keys(self):
"socialiqa",
"winogrande",
]
core_small_names = ["hellaswag", "arc_challenge", "csqa", "openbookqa", "piqa", "socialiqa"]
core_small_names = ["hellaswag", "arc_challenge", "piqa", "csqa", "socialiqa"]
mmlu_names = ["mmlu_stem", "mmlu_humanities", "mmlu_social_sciences", "mmlu_other"]

core_5shot_tasks: Dict[str, DownstreamTaskPrediction] = {
Expand Down Expand Up @@ -186,7 +186,7 @@ def get_task_sets(keys):
elif keys[0] == "mmlu":
keys = list(mmlu_var_tasks.keys()) + list(mmlu_subset_var_tasks.keys())
elif keys[0] == "main":
keys = list(core_small_5shot_tasks.keys()) + list(mmlu_var_tasks.keys())
keys = list(mmlu_var_tasks.keys()) + list(core_small_5shot_tasks.keys())
return keys


Expand Down Expand Up @@ -310,10 +310,6 @@ def get_accuracy_keys(tasks: Dict[str, DownstreamTaskPrediction]) -> List[str]:
}

WEIGHT_BY_KEY = {
"mmlu_stem_var_bpb": 0.215,
"mmlu_humanities_var_bpb": 0.335,
"mmlu_social_sciences_var_bpb": 0.219,
"mmlu_other_var_bpb": 0.231,
"eval/downstream_bpb/mmlu_stem_var_bpb_bpb": 0.215,
"eval/downstream_bpb/mmlu_humanities_var_bpb_bpb": 0.335,
"eval/downstream_bpb/mmlu_social_sciences_var_bpb_bpb": 0.219,
Expand All @@ -322,6 +318,10 @@ def get_accuracy_keys(tasks: Dict[str, DownstreamTaskPrediction]) -> List[str]:
"eval/downstream/mmlu_humanities_var_len_norm": 0.335,
"eval/downstream/mmlu_social_sciences_var_len_norm": 0.219,
"eval/downstream/mmlu_other_var_len_norm": 0.231,
"eval/downstream/mmlu_stem_mc_5shot_len_norm": 0.215,
"eval/downstream/mmlu_humanities_mc_5shot_len_norm": 0.335,
"eval/downstream/mmlu_social_sciences_mc_5shot_len_norm": 0.219,
"eval/downstream/mmlu_other_mc_5shot_len_norm": 0.231,
}


Expand All @@ -344,7 +344,7 @@ def get_accuracy_keys(tasks: Dict[str, DownstreamTaskPrediction]) -> List[str]:

def prettify(rel_error, is_percentage=True):
if is_percentage:
return f"{rel_error * 100:+.2f}%"
return f"{rel_error * 100:+.1f}%"
else:
return f"{rel_error:.2f}"

Expand Down
9 changes: 9 additions & 0 deletions scripts/scaling/final_peteish.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,14 @@
"n": 1279395840,
"label": "1b",
"color": "teal"
},
"7b": {
"paths": [
"scripts/scaling/data/peteish-final-new/7B-28xC-anneal-new.csv"
],
"mode": "eval",
"n": 6887575552,
"label": "7b",
"color": "darkviolet"
}
}
12 changes: 0 additions & 12 deletions scripts/scaling/final_peteish.sh

This file was deleted.

102 changes: 102 additions & 0 deletions scripts/scaling/joint_lr_minus_chained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import json

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from olmo.scaling.scaling_laws.utils import (
ExtrapolateNConfig,
chinchilla_n_d_lr_minus_fit,
get_ax,
get_data_by_name,
parse_args,
)
from olmo.scaling.scaling_laws.stacked_predictions import sigmoid


def main():
args = parse_args()

with open(args.config_path) as f:
configs = json.load(f)
configs = {name: ExtrapolateNConfig(**config) for name, config in configs.items()}

data_by_name = get_data_by_name(configs, args.keys, min_step=5000)

sns.set_style("whitegrid")

num_axs = 5
fig, axs = plt.subplots(1, num_axs, figsize=(num_axs * 4, 3))

train_ndhs, train_ys = [], []
for name, data in data_by_name.items():
config = configs[name]
if config.mode == "train":
train_ndhs += [[n, d, h] for n, d, h in zip(data["ns"], data["ds"], data["hs"])]
train_ys += data["ys"]

coefficients = [3.5051796, 4.52225812, 0.25991131, 0.28089689, 0.57286154, 0.02209304]
sigmoid_coeffs = [-0.77899618, 0.75179073, 12.64004912, 1.03518459]

# make predictions
predicted_data_by_name = {}
for name, data in data_by_name.items():
config = configs[name]
predicted_data_by_name[name] = {
"ns": data["ns"],
"ds": data["ds"],
"ys": [
sigmoid(chinchilla_n_d_lr_minus_fit([n, d, h], coefficients), *sigmoid_coeffs)
for n, d, h in zip(data["ns"], data["ds"], data["hs"])
],
}

# plot the actual data
for name, data in data_by_name.items():
config = configs[name]
ax = axs[get_ax(name)]
ax.scatter(data["ds"], data["ys"], color="white", edgecolors=config.color, label=config.label, s=10, alpha=0.25)

# plot the fitted curve
for name, data in predicted_data_by_name.items():
config = configs[name]
ax = axs[get_ax(name)]
ax.plot(
data["ds"],
data["ys"],
color=config.color,
linestyle="--",
linewidth=1.5,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)

# annotate the error
for name, data in data_by_name.items():
config = configs[name]
ax = axs[get_ax(name)]
pred_data = predicted_data_by_name[name]
rel_errors = [np.abs((pred_y - y) / y) for y, pred_y in zip(data["ys"], pred_data["ys"])]
rel_error = np.mean(rel_errors)
ax.annotate(
f"{rel_error:.2%}",
xy=(data["ds"][-1], pred_data["ys"][-1]),
xycoords="data",
xytext=(-4, -12),
textcoords="offset points",
fontsize=9,
color=config.color,
)

for ax in axs:
ax.legend(loc="lower right", ncols=1, fontsize=7)
ax.set_xlabel("Tokens (D)")
axs[0].set_ylabel("Accuracy")
plt.suptitle(
f"{args.key.replace('-acc', '')}",
fontsize=10,
)
plt.savefig(args.output_path, dpi=300, bbox_inches="tight")


if __name__ == "__main__":
main()
87 changes: 42 additions & 45 deletions scripts/scaling/final.py → scripts/scaling/step1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
tasks,
)

MARKERS = ["s", "P", "p", "*"]

MARKERS = ["s", "P", "p", "*", "o"]


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-k", "--keys", nargs="+", default=[], help="For avg metrics. Use one of [all-val-lm, all-bpb]"
"-k", "--keys", nargs="+", default=[], help="Key(s) for tasks"
)
parser.add_argument(
"--num_to_avg", type=int, default=1, help="Number of final ckpts to average (for final loss fitting)"
Expand All @@ -47,16 +48,12 @@ def fit_step1(data_by_name, is_accuracy: bool = False):
train_nds += [[n, d] for n, d in zip(data["ns"], data["ds"])]
train_ys += data["ys"]

# fit the parameters

if is_accuracy:
p0 = [1.0, 1.0, -0.01, -0.5, 0.1]
coefficients = get_coefficients(train_nds, train_ys, chinchilla_n_d_fit_e, p0=p0, disp=False)

else:
p0 = [3.0, 6.0, 0.1, 0.2, 1.0]
bounds = [(0, None), (0, None), (0, None), (None, None), (None, None)]

coefficients = get_coefficients_huber(
train_nds,
train_ys,
Expand All @@ -75,12 +72,15 @@ def predict_step1(data_by_name, coefficients):
predicted_data_by_name = {}
plotted_predicted_data_by_name = {}

dmin = 0.8 * min([min(data["ds"]) for data in data_by_name.values()])
dmax = 1.2 * max([max(data["ds"]) for data in data_by_name.values()])

for name, data in data_by_name.items():
predicted_data_by_name[name] = {
"ds": data["ds"],
"ys": [chinchilla_n_d_fit([n, d], coefficients) for n, d in zip(data["ns"], data["ds"])],
}
ds = np.linspace(min(data["ds"]), max(data["ds"]), 100)
ds = np.exp(np.linspace(np.log(dmin), np.log(dmax), 100))
ns = [data["ns"][0]] * len(ds)
plotted_predicted_data_by_name[name] = {
"ds": ds,
Expand Down Expand Up @@ -111,24 +111,37 @@ def plot_step1(
is_accuracy=False,
ax=plt.gca(),
):
# plot the actual data
# plot the actual and predicted data
unsigned_rel_errors = []
for name, data in data_by_name.items():
config = configs[name]
predicted_data = predicted_data_by_name[name]

for i, (d, y) in enumerate(zip(data["ds"], data["ys"])):
ax.scatter(d, y, color=config.color, marker=MARKERS[i], s=50)
ax.scatter(
d,
y,
color=config.color,
marker=MARKERS[i] if config.mode == "train" else "o",
s=50,
)

predicted_data = predicted_data_by_name[name]
for d, y, y_pred in zip(data["ds"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y
ax.annotate(
f"{prettify(rel_error)}",
(d, y),
textcoords="offset points",
xytext=(6, 6),
ha="center",
fontsize=8,
color=config.color,
)
if config.mode == "train":
unsigned_rel_errors.append(np.abs(rel_error))
else:
ax.annotate(
f"{prettify(rel_error)}",
(d, y),
textcoords="offset points",
xytext=(3, 3),
ha="left",
va="bottom",
fontsize=8,
color=config.color,
)
avg_unsigned_rel_error = np.mean(unsigned_rel_errors)

# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
Expand All @@ -138,21 +151,15 @@ def plot_step1(
data["ys"],
color=config.color,
linestyle="--",
linewidth=2.0,
linewidth=1.5,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)
ax.text(
x=0.20,
y=0.25,
s=fit_str,
fontsize=10,
transform=ax.transAxes,
)

ax.legend(loc="upper right", ncols=1, fontsize=10)
ax.set_xscale("log")
ax.legend(loc="upper right", ncols=1, fontsize=8)
ax.set_xlabel("Tokens (D)")
ax.set_ylabel("Accuracy" if is_accuracy else "Loss")
ax.set_title(task_name)
ax.set_ylabel("Task accuracy" if is_accuracy else "Task loss")
ax.set_title(f'{task_name}\n{fit_str}\navg unsigned rel error on fitting = {avg_unsigned_rel_error * 100:.2f}%', fontsize=9)


def main():
Expand All @@ -163,9 +170,10 @@ def main():
args.keys = get_task_sets(args.keys)

sns.set_style("whitegrid")

num_tasks = len(args.keys)
fig, axes = plt.subplots(num_tasks, 1, figsize=(6, 4.5 * num_tasks), squeeze=False)
num_cols = 3
num_rows = (num_tasks + num_cols - 1) // num_cols
fig, axes = plt.subplots(num_rows, num_cols, figsize=(3.75 * num_cols, 3.25 * num_rows), squeeze=False)

results = "Task Name | Actual Value | Predicted Value | Relative Error"

Expand All @@ -174,14 +182,13 @@ def main():
keys = task.get_accuracy_keys() if args.accuracy else task.get_loss_keys()
data_by_name = get_final_data_by_name(configs, keys, num_to_avg=args.num_to_avg)

# fit the parameters
coefficients = fit_step1(data_by_name, args.accuracy)

# make predictions

predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_step1(
data_by_name, coefficients
)

results += f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} | {prettify(rel_error)}"

plot_step1(
Expand All @@ -192,24 +199,14 @@ def main():
task_name,
str_chinchilla_n_d_fit(coefficients),
args.accuracy,
axes[i][0],
axes[i // num_cols][i % num_cols],
)

fig.tight_layout()
fig.subplots_adjust(top=0.95)
fig.savefig(args.output_path, dpi=300)

print(results)

# y_1b_3T = chinchilla_n_d_fit([1176832000, 3e12], coefficients)
# print(f"Predicted final loss for 1b-3T: {y_1b_3T:.3f}")
# y_7b_2T = chinchilla_n_d_fit([6682316800, 2e12], coefficients)
# print(f"Predicted final loss for 7b-2T: {y_7b_2T:.3f}")
# y_7b_3T = chinchilla_n_d_fit([6682316800, 3e12], coefficients)
# print(f"Predicted final loss for 7b-3T: {y_7b_3T:.3f}")
# y_13b_5T = chinchilla_n_d_fit([13e9, 5e12], coefficients)
# print(f"Predicted final loss for 13b-5T: {y_13b_5T:.3f}")


if __name__ == "__main__":
main()
Loading

0 comments on commit a21992b

Please sign in to comment.