From 1233662987f360822d8b488e19cd9965907e91c7 Mon Sep 17 00:00:00 2001
From: Jiacheng Liu <liujch1998@gmail.com>
Date: Thu, 21 Nov 2024 00:58:40 +0000
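Subject: [PATCH] Update single-step prediction

Simplify combined_fit and grad_combined_fit to a 7-parameter form by
dropping L0 and k, so the sigmoid is applied directly to L(N, D).
Update single_step.py to match: load data with get_step1_data_by_name
(y_metric="rc_acc"), replace --num_to_avg with --moving_avg, resolve
task sets inside parse_args, and pass the task name to fit_step12 so the
initial guess for p is derived from the task minimum.

Also swap the order of arc_easy and arc_challenge in core_small_names,
remove the block in flops_cmd that built and printed every evaluator and
then called exit() before the FLOPs estimate, and update the
peteish-moreeval launch commands in ladder_peteish.sh (drop the
commented-out 600M runs, add --device_eval_batch_size).
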
---
 .../scaling/scaling_laws/fitting_functions.py | 12 +++--
 olmo/scaling/scaling_laws/utils.py            |  2 +-
 scripts/ladder_peteish.py                     | 12 -----
 scripts/ladder_peteish.sh                     | 44 +++++++++----------
 scripts/scaling/single_step.py                | 31 ++++++-------
 5 files changed, 41 insertions(+), 60 deletions(-)

diff --git a/olmo/scaling/scaling_laws/fitting_functions.py b/olmo/scaling/scaling_laws/fitting_functions.py
index 8c4b09c70..88b47ab6d 100644
--- a/olmo/scaling/scaling_laws/fitting_functions.py
+++ b/olmo/scaling/scaling_laws/fitting_functions.py
@@ -92,27 +92,25 @@ def grad_chinchilla_n_d_fit(x, p):
 
 # x[0] = n, x[1] = d
 # p[0] = a = log(A), p[1] = b = log(B), p[2] = alpha, p[3] = beta, p[4] = E
-# p[5] = p, p[6] = L0, p[7] = k, p[8] = q
+# p[5] = p, p[6] = q
 def combined_fit(x, p):
     step1 = np.exp(p[0]) / x[0] ** p[2] + np.exp(p[1]) / x[1] ** p[3] + p[4]
-    step2 = p[5] / (1 + np.exp(-p[7] * (step1 - p[6]))) + p[8]
+    step2 = p[5] / (1 + np.exp(-step1)) + p[6]
     step2 = max(1e-6, step2)
     return step2
 
 
 def grad_combined_fit(x, p):
     step1 = np.exp(p[0]) / x[0] ** p[2] + np.exp(p[1]) / x[1] ** p[3] + p[4]
-    grad_p = 1 / (1 + np.exp(-p[7] * (step1 - p[6])))
-    grad_k = p[5] * grad_p * (1 - grad_p) * (step1 - p[6])
-    grad_L0 = p[5] * grad_p * (1 - grad_p) * (-p[7])
-    grad_step1 = p[5] * grad_p * (1 - grad_p) * p[7]
+    grad_p = 1 / (1 + np.exp(-step1))
+    grad_step1 = p[5] * grad_p * (1 - grad_p)
     grad_q = 1
     grad_a = grad_step1 * np.exp(p[0]) / x[0] ** p[2]
     grad_b = grad_step1 * np.exp(p[1]) / x[1] ** p[3]
     grad_alpha = grad_step1 * np.exp(p[0]) * (-np.log(x[0])) / x[0] ** p[2]
     grad_beta = grad_step1 * np.exp(p[1]) * (-np.log(x[1])) / x[1] ** p[3]
     grad_E = grad_step1 * 1
-    return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_p, grad_k, grad_L0, grad_q]
+    return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_p, grad_q]
 
 
 # x[0] = n, x[1] = d
diff --git a/olmo/scaling/scaling_laws/utils.py b/olmo/scaling/scaling_laws/utils.py
index 4e2ba5c0f..8b43fefdc 100644
--- a/olmo/scaling/scaling_laws/utils.py
+++ b/olmo/scaling/scaling_laws/utils.py
@@ -146,7 +146,7 @@ def get_mc_accuracy_keys(self):
     "socialiqa",
     "winogrande",
 ]
-core_small_names = ["hellaswag", "arc_easy", "arc_challenge", "piqa", "csqa", "socialiqa", "openbookqa"]
+core_small_names = ["hellaswag", "arc_challenge", "arc_easy", "piqa", "csqa", "socialiqa", "openbookqa"]
 mmlu_names = ["mmlu_stem", "mmlu_humanities", "mmlu_social_sciences", "mmlu_other"]
 
 core_5shot_tasks: Dict[str, DownstreamTaskPrediction] = {
diff --git a/scripts/ladder_peteish.py b/scripts/ladder_peteish.py
index ea6a181db..6df6e622c 100644
--- a/scripts/ladder_peteish.py
+++ b/scripts/ladder_peteish.py
@@ -452,18 +452,6 @@ def flops_for_model(model_config: Union[ModelConfig, str]) -> int:
 def flops_cmd(args: argparse.Namespace):
     cfg = config_from_args(args)
 
-    from tqdm import tqdm
-
-    from olmo.eval import build_evaluator
-    from olmo.tokenizer import Tokenizer
-
-    device = torch.device("cpu")
-    tokenizer = Tokenizer.from_train_config(cfg)
-    for eval_cfg in tqdm(cfg.evaluators):
-        evaluator = build_evaluator(cfg, eval_cfg, tokenizer, device)
-        print(evaluator)
-    exit()
-
     flops = flops_for_model(cfg.model)
     length_in_tokens = parse_length(args.length, parse_size(args.model))
     print("Expected model flops: ", round(flops * length_in_tokens / 1e18, 3), "x 10^9 GFlops")
diff --git a/scripts/ladder_peteish.sh b/scripts/ladder_peteish.sh
index 655ed492d..1c59acc12 100644
--- a/scripts/ladder_peteish.sh
+++ b/scripts/ladder_peteish.sh
@@ -22,32 +22,30 @@
 ./scripts/beaker/ladder_peteish-launch.sh 8 --model 760M --data olmoe-mix-0924 --length 10xC --name peteish --save_overwrite --device_batch_size 2 --batch_size_divisor 128
 ./scripts/beaker/ladder_peteish-launch.sh 16 --model 1B --data olmoe-mix-0924 --length 10xC --name peteish --save_overwrite --device_batch_size 1 --batch_size_divisor 128
 
+
 ./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 4 --batch_size_divisor 128 --alpha_f 1.0
 ./scripts/beaker/ladder_peteish-launch.sh 8 --model 370M --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 2 --batch_size_divisor 128 --alpha_f 1.0
 ./scripts/beaker/ladder_peteish-launch.sh 8 --model 600M --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 2 --batch_size_divisor 128 --alpha_f 1.0
 ./scripts/beaker/ladder_peteish-launch.sh 8 --model 760M --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 2 --batch_size_divisor 128 --alpha_f 1.0
 ./scripts/beaker/ladder_peteish-launch.sh 16 --model 1B --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 1 --batch_size_divisor 128 --alpha_f 1.0
 
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 370M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-# ./scripts/beaker/ladder_peteish-launch.sh 4 --model 600M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64
-
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 370M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-# ./scripts/beaker/ladder_peteish-launch.sh 4 --model 600M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64
-
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 370M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-# ./scripts/beaker/ladder_peteish-launch.sh 4 --model 600M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64
-
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 370M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-# ./scripts/beaker/ladder_peteish-launch.sh 4 --model 600M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
-./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64
+
+./scripts/beaker/ladder_peteish-launch.sh 2 --model 190M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
+./scripts/beaker/ladder_peteish-launch.sh 2 --model 370M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
+./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64 --device_eval_batch_size 8
+./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64 --device_eval_batch_size 4
+
+./scripts/beaker/ladder_peteish-launch.sh 2 --model 190M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
+./scripts/beaker/ladder_peteish-launch.sh 2 --model 370M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
+./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64 --device_eval_batch_size 8
+./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64 --device_eval_batch_size 4
+
+./scripts/beaker/ladder_peteish-launch.sh 2 --model 190M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
+./scripts/beaker/ladder_peteish-launch.sh 2 --model 370M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
+./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64 --device_eval_batch_size 8
+./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64 --device_eval_batch_size 4
+
+./scripts/beaker/ladder_peteish-launch.sh 2 --model 190M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
+./scripts/beaker/ladder_peteish-launch.sh 2 --model 370M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
+./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64 --device_eval_batch_size 8
+./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64 --device_eval_batch_size 4
diff --git a/scripts/scaling/single_step.py b/scripts/scaling/single_step.py
index 6eff54d9a..9802dd2a3 100644
--- a/scripts/scaling/single_step.py
+++ b/scripts/scaling/single_step.py
@@ -11,7 +11,7 @@
 )
 from olmo.scaling.scaling_laws.utils import (
     get_final_configs,
-    get_final_data_by_name,
+    get_step1_data_by_name,
     get_task_sets,
     prettify,
     tasks,
@@ -23,25 +23,25 @@
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("-k", "--keys", nargs="+", default=[], help="Key(s) for tasks")
-    parser.add_argument(
-        "--num_to_avg", type=int, default=1, help="Number of final ckpts to average (for final loss fitting)"
-    )
+    parser.add_argument("--moving_avg", type=int, default=1, help="Moving average for bpb loss")
     parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file")
     parser.add_argument("-o", "--output-path", type=str, required=True, help="Path to write output figure")
     args = parser.parse_args()
 
+    args.keys = get_task_sets(args.keys)
+
     return args
 
 
-def fit_step12(data_by_name):
+def fit_step12(data_by_name, task_name):
     train_nds, train_ys = [], []
     for name, data in data_by_name.items():
         if data["mode"] == "train":
             train_nds += [[n, d] for n, d in zip(data["ns"], data["ds"])]
             train_ys += data["ys"]
 
-    p0 = [3.0, 5.0, 0.2, 0.3, 1.0, -0.5, 1.0, 3.0, 1.0]
-    bounds = [(0, None), (0, None), (0, 1), (0, 1), (0, None), (-0.9999, 0), (0, None), (0, None), (0, 1)]
+    p0 = [3.0, 5.0, 0.2, 0.3, 0.0, tasks[task_name].task_minimum - 1.0, 1.0]
+    bounds = [(0, 10), (0, 10), (0, 1), (0, 1), (-10, 10), (-0.9999, 0), (0, 1)]
     coefficients = get_coefficients_huber(
         train_nds,
         train_ys,
@@ -84,9 +84,11 @@ def predict_step12(data_by_name, coefficients):
 
 
 def str_combined_fit(coefficients):
-    a, b, alpha, beta, E, p, L0, k, q = coefficients
+    a, b, alpha, beta, E, p, q = coefficients
     A, B = np.exp(a), np.exp(b)
-    return f"L(N, D) = {A:.2f} / N^{alpha:.2f} + {B:.2f} / D^{beta:.2f} + {E:.2f}\nAcc(L) = {p:.2f} / (1 + e^(-{k:.2f} (L - {L0:.2f}))) + {q:.2f}"
+    return (
+        f"Acc(N, D) = {p:.2f} / (1 + e^-({A:.2f} / N^{alpha:.2f} \n + {B:.2f} / D^{beta:.2f} + {E:.2f})) + {q:.2f}"
+    )
 
 
 def plot_step12(
@@ -143,7 +145,7 @@ def plot_step12(
         )
 
     ax.set_xscale("log")
-    ax.legend(loc="lower right", ncols=1, fontsize=8)
+    ax.legend(ncols=1, fontsize=7)
     ax.set_xlabel("Tokens (D)")
     ax.set_ylabel("Task accuracy")
     ax.set_title(
@@ -154,11 +156,8 @@ def plot_step12(
 
 def main():
     args = parse_args()
-
     configs = get_final_configs(args.config_path)
 
-    args.keys = get_task_sets(args.keys)
-
     sns.set_style("whitegrid")
     num_tasks = len(args.keys)
     num_cols = min(4, num_tasks)
@@ -168,12 +167,10 @@ def main():
     results = "Task Name | Actual Value | Predicted Value | Relative Error"
 
     for i, task_name in enumerate(args.keys):
-        task = tasks[task_name]
-        keys = task.get_accuracy_keys()
-        data_by_name = get_final_data_by_name(configs, keys, num_to_avg=args.num_to_avg)
+        data_by_name = get_step1_data_by_name(configs, task_name, y_metric="rc_acc", moving_avg=args.moving_avg)
 
         # fit the parameters
-        coefficients = fit_step12(data_by_name)
+        coefficients = fit_step12(data_by_name, task_name)
 
         # make predictions
         predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_step12(