Peteish curve fitting
liujch1998 committed Nov 5, 2024
1 parent 7b71d42 commit a40ed1e
Showing 14 changed files with 731 additions and 121 deletions.
56 changes: 43 additions & 13 deletions olmo/scaling/scaling_laws/download_wandb_logs.py
@@ -52,8 +52,9 @@ def parse_args():
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("-n", "--wandb-names", type=str, nargs="+", required=True, help="Full run name or regex")
-    parser.add_argument("-x", "--x-axis", type=str, default="throughput/total_tokens", help="X axis")
+    parser.add_argument("-x", "--x-axis", type=str, default="_step", help="X axis")
     parser.add_argument("-y", "--y-axis", nargs="+", type=str, default=["train/Perplexity"], help="Y axis")
+    parser.add_argument("-e", "--eval-only", action="store_true")
     parser.add_argument(
         "-o",
         "--output-path",
@@ -96,6 +97,13 @@ def main(args):
         + [f"eval/downstream/{d}" for d in downstream_newline]
     )

+    if not args.eval_only:
+        args.y_axis += [
+            "throughput/total_tokens",
+            "throughput/total_training_Gflops",
+            "optim/learning_rate_group0",
+        ]
+
     wb_runs = get_runs(args.wandb_names)

     print("Downloading the data from the following wandb runs:\n", "\n".join([str(run) for run in wb_runs]))
@@ -104,23 +112,14 @@ def main(args):
     if dirname:
         os.makedirs(dirname, exist_ok=True)
     with open(args.output_path, "w") as file_ref:
-        writer = csv.DictWriter(
-            file_ref,
-            fieldnames=[args.x_axis]
-            + ["throughput/total_training_Gflops"]
-            + args.y_axis
-            + ["optim/learning_rate_group0", "learning_rate_peak", "batch_size_in_tokens"],
-        )
+        writer = csv.DictWriter(file_ref, fieldnames=[args.x_axis] + args.y_axis + ["learning_rate_peak", "batch_size_in_tokens"])
         writer.writeheader()

         rows = []
         for wb_run in tqdm(wb_runs):
             print(f"Processing {wb_run.name}")
             history = wb_run.scan_history(
-                keys=[args.x_axis]
-                + ["throughput/total_training_Gflops"]
-                + args.y_axis
-                + ["optim/learning_rate_group0"],
+                keys=[args.x_axis] + args.y_axis,
                 page_size=10000,
             )  # page_size cannot be too big; a larger value is faster, but wandb starts to downsample the history

@@ -130,10 +129,10 @@ def main(args):
             )

             for wb_step in history:
-                rows.append(wb_step)
                 wb_step["learning_rate_peak"] = config["optimizer"]["value"]["learning_rate"]
                 # With certain run restarts, we also update the batch size.
                 wb_step["batch_size_in_tokens"] = batch_size_in_tokens
+                rows.append(wb_step)

         row_by_key = {}
         for row in rows:
@@ -246,6 +245,37 @@ def main(args):
 # python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/amberish-rulebased-3B-2xC' -y eval/validation-and-bpb-and-downstream -o wandb/amberish-rulebased/3B-2xC.csv
 # python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/amberish-rulebased-3B-5xC' -y eval/validation-and-bpb-and-downstream -o wandb/amberish-rulebased/3B-5xC.csv

+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-medium/peteish7' -y eval/downstream/arc_easy_acc -o wandb/peteish7_train.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-medium/peteish7-eval' -y eval/validation-and-bpb-and-downstream -o wandb/peteish7_eval_final.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-medium/peteish7-eval' -y eval/validation-and-bpb-and-downstream -e -o wandb/peteish7_eval_full.csv
+
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-190M-1xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/190M-1xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-370M-1xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/370M-1xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-600M-1xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/600M-1xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-760M-1xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/760M-1xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-1B-1xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/1B-1xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-190M-2xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/190M-2xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-370M-2xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/370M-2xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-600M-2xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/600M-2xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-760M-2xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/760M-2xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-1B-2xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/1B-2xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-190M-5xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/190M-5xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-370M-5xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/370M-5xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-600M-5xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/600M-5xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-760M-5xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/760M-5xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-1B-5xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/1B-5xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-190M-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/190M-10xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-370M-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/370M-10xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-600M-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/600M-10xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-760M-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/760M-10xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-final-1B-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-final/1B-10xC.csv
+
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-const-190M-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-const/190M-10xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-const-370M-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-const/370M-10xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-const-600M-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-const/600M-10xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-const-760M-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-const/760M-10xC.csv
+# python olmo/scaling/scaling_laws/download_wandb_logs.py -n 'ai2-llm/olmo-ladder/peteish-const-1B-10xC' -y eval/validation-and-bpb-and-downstream -o wandb/peteish-const/1B-10xC.csv
+
 args = parse_args()
 print(args)
 main(args)
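For context, the download path boils down to the public wandb API. A minimal sketch, under the assumption that a run can be addressed directly by the path shown in the commands above (the script itself resolves names and regexes through get_runs, so treat the hard-coded path as illustrative):

import wandb

# Minimal sketch, not the script itself: assumes 'ai2-llm/olmo-medium/peteish7'
# resolves as a run path, which get_runs normally handles via name/regex matching.
api = wandb.Api()
run = api.run("ai2-llm/olmo-medium/peteish7")
history = run.scan_history(
    keys=["_step", "eval/downstream/arc_easy_acc"],
    page_size=10000,  # larger pages are faster, but wandb starts to downsample
)
for row in history:
    print(row["_step"], row["eval/downstream/arc_easy_acc"])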
30 changes: 30 additions & 0 deletions olmo/scaling/scaling_laws/merge_wandb_logs.py
@@ -0,0 +1,30 @@
import csv
import sys

train_path = sys.argv[1]
eval_path = sys.argv[2]

train_row_by_step = {}
with open(train_path, 'r') as f:
    reader = csv.DictReader(f)
    for row in reader:
        step = int(row['_step'])
        train_row_by_step[step] = row

rows = []
with open(eval_path, 'r') as f:
    reader = csv.DictReader(f)
    fieldnames = reader.fieldnames
    for row in reader:
        step = int(row['_step'])
        if step in train_row_by_step:
            train_row = train_row_by_step[step]
            train_row = {k: train_row[k] for k in ["throughput/total_tokens", "throughput/total_training_Gflops", "optim/learning_rate_group0"]}
            row.update(train_row)
        rows.append(row)

with open(eval_path, 'w') as f:
    writer = csv.DictWriter(f, fieldnames=rows[0].keys())
    writer.writeheader()
    for row in rows:
        writer.writerow(row)
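Usage note: the merge script takes the train-metrics CSV as the first argument and the eval CSV as the second, then rewrites the eval CSV in place, copying throughput/total_tokens, throughput/total_training_Gflops, and optim/learning_rate_group0 onto eval rows whose _step also appears in the train log. With the peteish7 commands above, a plausible (illustrative) pairing is:

# python olmo/scaling/scaling_laws/merge_wandb_logs.py wandb/peteish7_train.csv wandb/peteish7_eval_full.csv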
33 changes: 30 additions & 3 deletions olmo/scaling/scaling_laws/utils.py
@@ -290,7 +290,7 @@ class FinalConfig:
     "all-val-lm": [f"eval/{val}/CrossEntropyLoss" for val in validation],
     "all-bpb": [f"eval/downstream_bpb/{task}_bpb" for task in downstream_bpb],
     "c4": ["eval/c4_en-validation/CrossEntropyLoss"],
-    "mmlu": [
+    "mmlu-var": [
         f"eval/downstream_bpb/{task}_bpb"
         for task in [
             "mmlu_stem_var_bpb",
@@ -299,6 +299,10 @@ class FinalConfig:
             "mmlu_other_var_bpb",
         ]
     ],
+    "mmlu-stem-var": ["eval/downstream_bpb/mmlu_stem_var_bpb_bpb"],
+    "mmlu-humanities-var": ["eval/downstream_bpb/mmlu_humanities_var_bpb_bpb"],
+    "mmlu-social-sciences-var": ["eval/downstream_bpb/mmlu_social_sciences_var_bpb_bpb"],
+    "mmlu-other-var": ["eval/downstream_bpb/mmlu_other_var_bpb_bpb"],
     "hellaswag-5shot": ["eval/downstream_bpb/hellaswag_rc_5shot_bpb_bpb"],
     "arc-e-5shot": ["eval/downstream_bpb/arc_easy_rc_5shot_bpb_bpb"],
     "arc-c-5shot": ["eval/downstream_bpb/arc_challenge_rc_5shot_bpb_bpb"],
@@ -311,6 +315,12 @@ class FinalConfig:
     "csqa-5shot": ["eval/downstream_bpb/csqa_rc_5shot_bpb_bpb"],
     "socialiqa-5shot": ["eval/downstream_bpb/socialiqa_rc_5shot_bpb_bpb"],
 }
+WEIGHT_BY_KEY = {
+    "mmlu_stem_var_bpb": 0.215,
+    "mmlu_humanities_var_bpb": 0.335,
+    "mmlu_social_sciences_var_bpb": 0.219,
+    "mmlu_other_var_bpb": 0.231,
+}


def parse_args():
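The WEIGHT_BY_KEY values sum to 1.0 and look like the relative sizes of the four MMLU category splits (an assumption; the diff does not say where they come from). A self-contained sketch of the weighted average they drive in get_data_by_name and get_final_data_by_name below, with hypothetical BPB values:

import numpy as np

WEIGHT_BY_KEY = {
    "mmlu_stem_var_bpb": 0.215,
    "mmlu_humanities_var_bpb": 0.335,
    "mmlu_social_sciences_var_bpb": 0.219,
    "mmlu_other_var_bpb": 0.231,
}

# Hypothetical per-category BPB values for one checkpoint (illustrative numbers).
row = {
    "mmlu_stem_var_bpb": 1.02,
    "mmlu_humanities_var_bpb": 0.91,
    "mmlu_social_sciences_var_bpb": 0.95,
    "mmlu_other_var_bpb": 0.88,
}
keys = list(row)
# Keys without an entry in WEIGHT_BY_KEY fall back to weight 1.0, so non-MMLU
# metrics keep behaving like a plain mean.
y = np.average([float(row[k]) for k in keys], weights=[WEIGHT_BY_KEY.get(k, 1.0) for k in keys])
print(f"weighted mmlu-var bpb: {y:.4f}")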
@@ -364,7 +374,7 @@ def get_data_by_name(configs: Dict[str, ExtrapolateNConfig], keys: List[str], mi
             last_fake_lr = fake_lr
             last_d = d
             encountered_ds.add(d)
-            y = np.mean([float(row[key]) for key in keys])
+            y = np.average([float(row[key]) for key in keys], weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in keys])
             if min_step is not None and d < min_step * batch_size:
                 continue
             data_by_name[name]["ns"].append(n)
@@ -388,7 +398,7 @@ def get_final_data_by_name(configs, keys, num_to_avg=1):
         ds, ys = [], []
         for row in rows:
             d = int(float(row["throughput/total_tokens"]))
-            y = np.mean([float(row[key]) for key in keys])
+            y = np.average([float(row[key]) for key in keys], weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in keys])
             ds.append(d)
             ys.append(y)
         d = np.mean(ds)
@@ -497,6 +507,23 @@ def grad_chinchilla_n_d_lr_fit(x, p):
     return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_F]


+# x[0] = n, x[1] = d, x[2] = h
+# p[0] = a = log(A), p[1] = b = log(B), p[2] = alpha, p[3] = beta, p[4] = E, p[5] = F
+def chinchilla_n_d_lr_minus_fit(x, p):
+    # return e**a / x[0]**alpha + e**b / x[1]**beta + E - F * (1 - x[2])
+    return np.exp(p[0]) / x[0] ** p[2] + np.exp(p[1]) / x[1] ** p[3] + p[4] - p[5] * (1 - x[2])
+
+
+def grad_chinchilla_n_d_lr_minus_fit(x, p):
+    grad_a = np.exp(p[0]) / x[0] ** p[2]
+    grad_b = np.exp(p[1]) / x[1] ** p[3]
+    grad_alpha = np.exp(p[0]) * (-np.log(x[0])) / x[0] ** p[2]
+    grad_beta = np.exp(p[1]) * (-np.log(x[1])) / x[1] ** p[3]
+    grad_E = 1
+    grad_F = -(1 - x[2])
+    return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_F]
+
+
 def chinchilla_n_d_lr_log_fit(x, p):
     # return e**a / x[0]**alpha + e**b / x[1]**beta + E + F * x[2] * np.log(x[0] / e**r + s)
     return (
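Since the gradient of chinchilla_n_d_lr_minus_fit is derived by hand, a finite-difference check is cheap insurance. A self-contained sketch (functions copied from the hunk above; the test point and parameter vector are arbitrary):

import numpy as np

def chinchilla_n_d_lr_minus_fit(x, p):
    return np.exp(p[0]) / x[0] ** p[2] + np.exp(p[1]) / x[1] ** p[3] + p[4] - p[5] * (1 - x[2])

def grad_chinchilla_n_d_lr_minus_fit(x, p):
    grad_a = np.exp(p[0]) / x[0] ** p[2]
    grad_b = np.exp(p[1]) / x[1] ** p[3]
    grad_alpha = np.exp(p[0]) * (-np.log(x[0])) / x[0] ** p[2]
    grad_beta = np.exp(p[1]) * (-np.log(x[1])) / x[1] ** p[3]
    grad_E = 1
    grad_F = -(1 - x[2])
    return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_F]

x = [190e6, 4e9, 0.8]               # arbitrary (n, d, h) test point
p = [4.0, 4.0, 0.3, 0.3, 0.5, 0.1]  # arbitrary parameter vector
eps = 1e-6
analytic = grad_chinchilla_n_d_lr_minus_fit(x, p)
for i in range(len(p)):
    p_hi, p_lo = list(p), list(p)
    p_hi[i] += eps
    p_lo[i] -= eps
    # central difference approximation of the i-th partial derivative
    numeric = (chinchilla_n_d_lr_minus_fit(x, p_hi) - chinchilla_n_d_lr_minus_fit(x, p_lo)) / (2 * eps)
    assert abs(numeric - analytic[i]) < 1e-4 * max(1.0, abs(analytic[i])), f"param {i}"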
1 change: 1 addition & 0 deletions scripts/ladder_peteish.py
@@ -300,6 +300,7 @@ def config_from_args(args: argparse.Namespace) -> TrainConfig:
                 label="all-small-ppl-validation",
                 data=DataConfig(
                     drop_last=True,
+                    memmap_dtype="uint32",
                     datasets={
                         "c4_en-validation": [
                             f"{read_location}/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy"
62 changes: 62 additions & 0 deletions scripts/scaling/final_peteish.json
@@ -0,0 +1,62 @@
{
    "190m": {
        "paths": [
            "wandb/peteish-final/190M-1xC.csv",
            "wandb/peteish-final/190M-2xC.csv",
            "wandb/peteish-final/190M-5xC.csv",
            "wandb/peteish-final/190M-10xC.csv"
        ],
        "mode": "train",
        "n": 190354176,
        "label": "190m",
        "color": "darkred"
    },
    "370m": {
        "paths": [
            "wandb/peteish-final/370M-1xC.csv",
            "wandb/peteish-final/370M-2xC.csv",
            "wandb/peteish-final/370M-5xC.csv",
            "wandb/peteish-final/370M-10xC.csv"
        ],
        "mode": "train",
        "n": 371262464,
        "label": "370m",
        "color": "darkorange"
    },
    "600m": {
        "paths": [
            "wandb/peteish-final/600M-1xC.csv",
            "wandb/peteish-final/600M-2xC.csv",
            "wandb/peteish-final/600M-5xC.csv",
            "wandb/peteish-final/600M-10xC.csv"
        ],
        "mode": "train",
        "n": 597382464,
        "label": "600m",
        "color": "goldenrod"
    },
    "760m": {
        "paths": [
            "wandb/peteish-final/760M-1xC.csv",
            "wandb/peteish-final/760M-2xC.csv",
            "wandb/peteish-final/760M-5xC.csv",
            "wandb/peteish-final/760M-10xC.csv"
        ],
        "mode": "train",
        "n": 758220288,
        "label": "760m",
        "color": "darkgreen"
    },
    "1b": {
        "paths": [
            "wandb/peteish-final/1B-1xC.csv",
            "wandb/peteish-final/1B-2xC.csv",
            "wandb/peteish-final/1B-5xC.csv",
            "wandb/peteish-final/1B-10xC.csv"
        ],
        "mode": "train",
        "n": 1279395840,
        "label": "1b",
        "color": "teal"
    }
}
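These entries are consumed by the scaling scripts below; each pins the run CSVs, the exact parameter count n, and plotting metadata. A quick sanity-check sketch:

import json

with open("scripts/scaling/final_peteish.json") as f:  # path as committed
    configs = json.load(f)

for name, cfg in configs.items():
    # Every ladder size has four compute multiples (1x, 2x, 5x, 10x Chinchilla).
    assert cfg["mode"] == "train" and len(cfg["paths"]) == 4
    print(f"{name}: n={cfg['n']:,}, color={cfg['color']}")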
12 changes: 12 additions & 0 deletions scripts/scaling/final_peteish.sh
@@ -0,0 +1,12 @@
python scripts/scaling/final.py -k mmlu-var -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_mmlu-var.png
python scripts/scaling/final.py -k hellaswag-5shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_hellaswag-5shot.png
python scripts/scaling/final.py -k arc-e-5shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_arc-e-5shot.png
python scripts/scaling/final.py -k arc-c-5shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_arc-c-5shot.png
python scripts/scaling/final.py -k piqa-5shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_piqa-5shot.png
python scripts/scaling/final.py -k winogrande-5shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_winogrande-5shot.png
python scripts/scaling/final.py -k openbookqa-5shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_openbookqa-5shot.png
python scripts/scaling/final.py -k boolq-5shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_boolq-5shot.png
python scripts/scaling/final.py -k sciq-0shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_sciq-0shot.png
python scripts/scaling/final.py -k copa-0shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_copa-0shot.png
python scripts/scaling/final.py -k csqa-5shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_csqa-5shot.png
python scripts/scaling/final.py -k socialiqa-5shot -c scripts/scaling/final_peteish.json -o figure/peteish-final/final_socialiqa-5shot.png
39 changes: 17 additions & 22 deletions scripts/scaling/joint.py
@@ -22,12 +22,12 @@ def main():
         configs = json.load(f)
     configs = {name: ExtrapolateNConfig(**config) for name, config in configs.items()}

-    data_by_name = get_data_by_name(configs, args.keys, min_step=3000)
+    data_by_name = get_data_by_name(configs, args.keys, min_step=5000)

     sns.set_style("whitegrid")

     num_axs = 5
-    fig, axs = plt.subplots(1, num_axs, figsize=(num_axs * 6, 4.5))
+    fig, axs = plt.subplots(1, num_axs, figsize=(num_axs * 4, 3))

     train_ndhs, train_ys = [], []
     for name, data in data_by_name.items():
@@ -42,7 +42,7 @@ def main():
         train_ys,
         chinchilla_n_d_fit,
         grad_chinchilla_n_d_fit,
-        p0=[4.0, 15.0, 0.25, 0.7, 1.5],
+        p0=[4.0, 4.0, 0.3, 0.3, 0.5],
         bounds=[(0, None), (0, None), (0, None), (0, None), (0, None)],
     )
     a, b, alpha, beta, E = coefficients
@@ -63,7 +63,7 @@ def main():
         config = configs[name]
         ax = axs[get_ax(name)]
         ax.scatter(
-            data["ds"], data["ys"], color="white", edgecolors=config.color, label=config.label, s=5, alpha=0.4
+            data["ds"], data["ys"], color="white", edgecolors=config.color, label=config.label, s=5, alpha=0.25
         )

         # plot the fitted curve
@@ -89,35 +89,30 @@ def main():
         all_rel_errors += rel_errors
         rel_error = np.mean(rel_errors)
         ax.annotate(
-            f"err: {rel_error:.2%}",
+            f"{rel_error:.2%}",
             xy=(data["ds"][-1], pred_data["ys"][-1]),
             xycoords="data",
             xytext=(-10, 8),
             textcoords="offset points",
             fontsize=9,
             color=config.color,
         )
-    axs[3].annotate(
-        f"L(N, D) = {A:.2f} / N^{alpha:.2f} + {B:.2f} / D^{beta:.2f} + {E:.2f}\nAvg err: {np.mean(all_rel_errors):.2%}",
-        xy=(0.15, 0.55),
-        xycoords="axes fraction",
-        fontsize=9,
-    )
-    plt.text(
-        x=0.40,
-        y=0.90,
-        s=f"L(n, d) = {A:.2f} / n^{alpha:.2f} + {B:.2f} / d^{beta:.2f} + {E:.2f}",
-        fontsize=12,
-        transform=fig.transFigure,
-    )
+    # axs[3].annotate(
+    #     f"L(N, D) = {A:.2f} / N^{alpha:.2f} + {B:.2f} / D^{beta:.2f} + {E:.2f}\nAvg err: {np.mean(all_rel_errors):.2%}",
+    #     xy=(0.15, 0.55),
+    #     xycoords="axes fraction",
+    #     fontsize=7,
+    # )

     for ax in axs:
-        ax.legend(loc="upper right", ncols=2, fontsize=8)
+        ax.legend(loc="upper right", ncols=2, fontsize=7)
         ax.set_xlabel("Tokens (D)")
     axs[0].set_ylabel(f"CE loss, {args.key if args.key != '' else args.keys}")
     axs[3].set_ylabel("Loss")
-    axs[3].set_title(args.key)
-    plt.suptitle("Fitting loss curves")
+    axs[3].set_title(args.key, fontsize=10)
+    plt.suptitle(
+        f"{args.key}\nL(N, D, H) = {A:.2f} / N^{alpha:.2f} + {B:.2f} / D^{beta:.2f} + {E:.2f}",
+        fontsize=8,
+    )
     plt.savefig(args.output_path, dpi=300, bbox_inches="tight")
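The fitting call above is truncated in this view and the repo's optimizer is not shown in this hunk; a stand-in sketch with scipy and synthetic data, reusing the functional form, initialization, and non-negativity bounds from the hunk (plain squared error here, whatever loss the repo's fitter actually uses):

import numpy as np
from scipy.optimize import minimize

def chinchilla_n_d_fit(x, p):
    # L(n, d) = e**p[0] / n**p[2] + e**p[1] / d**p[3] + p[4]
    return np.exp(p[0]) / x[0] ** p[2] + np.exp(p[1]) / x[1] ** p[3] + p[4]

# Synthetic (n, d) -> loss points; the real script gathers these from data_by_name.
train_nds = [(190e6, 3.8e9), (370e6, 7.4e9), (760e6, 15.2e9), (1.3e9, 25.6e9)]
train_ys = [3.05, 2.86, 2.71, 2.60]

def objective(p):
    preds = np.array([chinchilla_n_d_fit(x, p) for x in train_nds])
    return float(np.mean((preds - np.array(train_ys)) ** 2))

res = minimize(
    objective,
    x0=[4.0, 4.0, 0.3, 0.3, 0.5],  # p0 from the hunk above
    bounds=[(0, None)] * 5,        # same non-negativity bounds
    method="L-BFGS-B",
)
a, b, alpha, beta, E = res.x
A, B = np.exp(a), np.exp(b)
print(f"L(n, d) = {A:.2f} / n^{alpha:.2f} + {B:.2f} / d^{beta:.2f} + {E:.2f}")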

