From 0b4800a9bdf9c607b12da19f43d16e0f213d6127 Mon Sep 17 00:00:00 2001
From: Jiacheng Liu
Date: Sat, 16 Nov 2024 01:59:28 +0000
Subject: [PATCH] Support predicting MC

* utils.py: add "boolq" and "winogrande" (0.5) to the dict that precedes
  maximums_rc; key the core 5-shot task dicts by "{key}_5shot" instead of
  "{key}_rc_5shot"; add a "main_mc" task set (mmlu_avg_var,
  arc_challenge_5shot); fill KEYS_BY_KEY with per-task loss keys right
  after KEYS_BY_KEY is defined; make get_length() return "" when the file
  name has no "-<length>" part.
* predict.py: add an optional --step2-config-path; step 2 falls back to
  the main config when it is not given.
* Add scripts/scaling/eval_bpb_mc.py (bpb-vs-accuracy scatter plots read
  from wandb/eval_bpb_mc/*.json) and scripts/scaling/step2_mc.json.

---
 olmo/scaling/scaling_laws/utils.py |  18 +++--
 scripts/scaling/eval_bpb_mc.py     | 118 +++++++++++++++++++++++++++++
 scripts/scaling/predict.py         |  10 ++-
 scripts/scaling/step2_mc.json      |  20 +++++
 4 files changed, 159 insertions(+), 7 deletions(-)
 create mode 100644 scripts/scaling/eval_bpb_mc.py
 create mode 100644 scripts/scaling/step2_mc.json

diff --git a/olmo/scaling/scaling_laws/utils.py b/olmo/scaling/scaling_laws/utils.py
index d24e568f3..ef19bf063 100644
--- a/olmo/scaling/scaling_laws/utils.py
+++ b/olmo/scaling/scaling_laws/utils.py
@@ -124,6 +124,8 @@ def get_mc_accuracy_keys(self):
     "piqa": 0.5,
     "socialiqa": 1 / 3,
     "csqa": 0.2,
+    "boolq": 0.5,
+    "winogrande": 0.5,
 }
 
 maximums_rc: Dict[str, float] = {}  # {"mmlu_stem": 0.9, "arc_easy": 0.85}
@@ -144,7 +146,7 @@ def get_mc_accuracy_keys(self):
 mmlu_names = ["mmlu_stem", "mmlu_humanities", "mmlu_social_sciences", "mmlu_other"]
 
 core_5shot_tasks: Dict[str, DownstreamTaskPrediction] = {
-    f"{key}_rc_5shot": DownstreamTaskPrediction(
+    f"{key}_5shot": DownstreamTaskPrediction(
         task_loss_key=f"eval/downstream_bpb/{key}_rc_5shot_bpb_bpb",
         task_accuracy_key=f"eval/downstream/{key}_rc_5shot_len_norm"
         if key not in ["arc_easy", "winogrande", "boolq"]
@@ -158,7 +160,7 @@ def get_mc_accuracy_keys(self):
 }
 
 core_small_5shot_tasks: Dict[str, DownstreamTaskPrediction] = {
-    f"{key}_rc_5shot": DownstreamTaskPrediction(
+    f"{key}_5shot": DownstreamTaskPrediction(
         task_loss_key=f"eval/downstream_bpb/{key}_rc_5shot_bpb_bpb",
         task_accuracy_key=f"eval/downstream/{key}_rc_5shot_len_norm"
         if key not in ["arc_easy", "winogrande", "boolq"]
@@ -205,6 +207,8 @@ def get_task_sets(keys):
             keys = list(mmlu_subset_var_tasks.keys())
         elif keys[0] == "main":
             keys = list(mmlu_var_tasks.keys()) + list(core_small_5shot_tasks.keys())
+        elif keys[0] == "main_mc":
+            keys = ["mmlu_avg_var", "arc_challenge_5shot"]
         elif keys[0] == "all":
             keys = list(mmlu_var_tasks.keys()) + list(core_5shot_tasks.keys())
     return keys
@@ -328,6 +332,8 @@ def get_accuracy_keys(tasks: Dict[str, DownstreamTaskPrediction]) -> List[str]:
     "all-bpb": downstream_bpb,
     "c4": ["eval/c4_en-validation/CrossEntropyLoss"],
 }
+for task_name, task in tasks.items():
+    KEYS_BY_KEY[task_name] = task.task_loss_key if isinstance(task.task_loss_key, list) else [task.task_loss_key]
 
 WEIGHT_BY_KEY = {
     "eval/downstream_bpb/mmlu_stem_var_bpb_bpb": 0.215,
@@ -358,9 +364,6 @@ def get_accuracy_keys(tasks: Dict[str, DownstreamTaskPrediction]) -> List[str]:
     "13b": 91335915520,
 }
 
-for task_name, task in tasks.items():
-    KEYS_BY_KEY[task_name] = task.task_loss_key if isinstance(task.task_loss_key, list) else [task.task_loss_key]
-
 
 def prettify(rel_error, is_percentage=True):
     if is_percentage:
@@ -508,7 +511,10 @@ def moving_average(arr, n):
 
 
 def get_length(path):
-    return path.split("/")[-1].split(".csv")[0].split("-")[1]
+    try:
+        return path.split("/")[-1].split(".csv")[0].split("-")[1]
+    except IndexError:  # file name has no "-<length>" part
+        return ""
 
 
 def get_step2_data_by_name(configs, task_name, y_metric="rc_acc", moving_avg=1, skip_perc=0.0, last_n_points=-1):
diff --git a/scripts/scaling/eval_bpb_mc.py b/scripts/scaling/eval_bpb_mc.py
new file mode 100644
index 000000000..43a9fe3bb
--- /dev/null
+++ b/scripts/scaling/eval_bpb_mc.py
@@ -0,0 +1,118 @@
+import json
+import matplotlib.pyplot as plt
+import numpy as np
+
+MODELS = [
+    'allenai/OLMo-7B-0724-hf',
+    # 'allenai/OLMo-1B-0724-hf',
+    # 'allenai/OLMo-7B-0424-hf',
+    'allenai/OLMo-7B-hf',
+    'allenai/OLMo-1B-hf',
+    'meta-llama/Llama-3.2-3B',
+    'meta-llama/Llama-3.2-1B',
+    # 'meta-llama/Llama-3.1-70B',
+    'meta-llama/Llama-3.1-8B',
+    # 'meta-llama/Meta-Llama-3-70B',
+    'meta-llama/Meta-Llama-3-8B',
+    # 'meta-llama/Llama-2-70b-hf',
+    # 'meta-llama/Llama-2-13b-hf',
+    # 'meta-llama/Llama-2-7b-hf',
+    # 'google/gemma-2-27b',
+    # 'google/gemma-2-9b',
+    # 'google/gemma-2-2b',
+    # 'google/gemma-7b',
+    # 'google/gemma-2b',
+    # 'Qwen/Qwen2.5-72B',
+    # 'Qwen/Qwen2.5-32B',
+    'Qwen/Qwen2.5-14B',
+    'Qwen/Qwen2.5-7B',
+    'Qwen/Qwen2.5-3B',
+    'Qwen/Qwen2.5-1.5B',
+    # 'Qwen/Qwen2-72B',
+    'Qwen/Qwen2-7B',
+    'Qwen/Qwen2-1.5B',
+    'mistralai/Mistral-Nemo-Base-2407',
+    'mistralai/Mistral-7B-v0.3',
+    'mistralai/Mistral-7B-v0.1',
+]
+
+COLOR_BY_MODEL_PREFIX = {
+    'allenai': 'hotpink',
+    'meta-llama/Llama-3.2': 'darkblue',
+    'meta-llama/Llama-3.1': 'mediumblue',
+    'meta-llama/Meta-Llama-3': 'royalblue',
+    'meta-llama/Llama-2': 'cornflowerblue',
+    'google/gemma-2-': 'darkgreen',
+    'google/gemma-': 'forestgreen',
+    'Qwen/Qwen2.5': 'darkviolet',
+    'Qwen/Qwen2': 'violet',
+    'mistralai': 'darkorange',
+}
+def get_color(model):
+    for prefix, color in COLOR_BY_MODEL_PREFIX.items():
+        if model.startswith(prefix):
+            return color
+    return 'black'
+
+METRICS_BY_TASK = {
+    'rc_rc_mmlu': [
+        ('mmlu_stem_var_bpb', 'mmlu_stem_var_len_norm', 0.215),
+        ('mmlu_humanities_var_bpb', 'mmlu_humanities_var_len_norm', 0.335),
+        ('mmlu_social_sciences_var_bpb', 'mmlu_social_sciences_var_len_norm', 0.219),
+        ('mmlu_other_var_bpb', 'mmlu_other_var_len_norm', 0.231),
+    ],
+    'rc_rc_hellaswag': [('hellaswag_rc_5shot_bpb', 'hellaswag_rc_5shot_len_norm', 1.0)],
+    'rc_rc_arc-c': [('arc_challenge_rc_5shot_bpb', 'arc_challenge_rc_5shot_len_norm', 1.0)],
+    'rc_rc_piqa': [('piqa_rc_5shot_bpb', 'piqa_rc_5shot_len_norm', 1.0)],
+    'rc_rc_csqa': [('csqa_rc_5shot_bpb', 'csqa_rc_5shot_len_norm', 1.0)],
+    'rc_rc_socialiqa': [('socialiqa_rc_5shot_bpb', 'socialiqa_rc_5shot_len_norm', 1.0)],
+    'rc_mc_mmlu': [
+        ('mmlu_stem_var_bpb', 'mmlu_stem_mc_5shot_len_norm', 0.215),
+        ('mmlu_humanities_var_bpb', 'mmlu_humanities_mc_5shot_len_norm', 0.335),
+        ('mmlu_social_sciences_var_bpb', 'mmlu_social_sciences_mc_5shot_len_norm', 0.219),
+        ('mmlu_other_var_bpb', 'mmlu_other_mc_5shot_len_norm', 0.231),
+    ],
+    'rc_mc_hellaswag': [('hellaswag_rc_5shot_bpb', 'hellaswag_mc_5shot_acc', 1.0)],
+    'rc_mc_arc-c': [('arc_challenge_rc_5shot_bpb', 'arc_challenge_mc_5shot_acc', 1.0)],
+    'rc_mc_piqa': [('piqa_rc_5shot_bpb', 'piqa_mc_5shot_acc', 1.0)],
+    'rc_mc_csqa': [('csqa_rc_5shot_bpb', 'csqa_mc_5shot_acc', 1.0)],
+    'rc_mc_socialiqa': [('socialiqa_rc_5shot_bpb', 'socialiqa_mc_5shot_acc', 1.0)],
+    'mc_mc_mmlu': [
+        ('mmlu_stem_mc_5shot_bpb', 'mmlu_stem_mc_5shot_len_norm', 0.215),
+        ('mmlu_humanities_mc_5shot_bpb', 'mmlu_humanities_mc_5shot_len_norm', 0.335),
+        ('mmlu_social_sciences_mc_5shot_bpb', 'mmlu_social_sciences_mc_5shot_len_norm', 0.219),
+        ('mmlu_other_mc_5shot_bpb', 'mmlu_other_mc_5shot_len_norm', 0.231),
+    ],
+    'mc_mc_hellaswag': [('hellaswag_mc_5shot_bpb', 'hellaswag_mc_5shot_acc', 1.0)],
+    'mc_mc_arc-c': [('arc_challenge_mc_5shot_bpb', 'arc_challenge_mc_5shot_acc', 1.0)],
+    'mc_mc_piqa': [('piqa_mc_5shot_bpb', 'piqa_mc_5shot_acc', 1.0)],
+    'mc_mc_csqa': [('csqa_mc_5shot_bpb', 'csqa_mc_5shot_acc', 1.0)],
+    'mc_mc_socialiqa': [('socialiqa_mc_5shot_bpb', 'socialiqa_mc_5shot_acc', 1.0)],
+}
+
+fig, axs = plt.subplots(6, 3, figsize=(3 * 6, 6 * 4.5))
+
+for i, (task, metrics) in enumerate(METRICS_BY_TASK.items()):
+    ax = axs[i % 6, i // 6]
+    for model in MODELS:
+        with open(f'wandb/eval_bpb_mc/{model.replace("/", "_")}.json') as f:
+            data = json.load(f)
+        try:
+            rc_bpb = np.average([data[f'eval/downstream_bpb/{metric[0]}_bpb'] for metric in metrics], weights=[metric[2] for metric in metrics])
+            acc = np.average([data[f'eval/downstream/{metric[1]}'] for metric in metrics], weights=[metric[2] for metric in metrics])
+        except KeyError:
+            continue
+        color = get_color(model)
+        ax.scatter([rc_bpb], [acc], color=color, s=100)
+        ax.annotate(
+            text=model.split('/')[1],
+            xy=(float(rc_bpb), float(acc)),
+            xytext=(8, -3),
+            textcoords='offset points',
+            fontsize=8,
+        )
+    ax.set_xlabel(f'{task.split("_")[0]} bpb')
+    ax.set_ylabel(f'{task.split("_")[1]} acc')
+    ax.set_title(task)
+
+plt.savefig('wandb/eval_bpb_mc/all.png', dpi=300, bbox_inches='tight')
diff --git a/scripts/scaling/predict.py b/scripts/scaling/predict.py
index f9fda3812..9ee3772dd 100644
--- a/scripts/scaling/predict.py
+++ b/scripts/scaling/predict.py
@@ -1,3 +1,6 @@
+# python scripts/scaling/predict.py -k main -c scripts/scaling/final.json -n 6887575552 -d 3945065873408 -t 7b
+# python scripts/scaling/predict.py -k main_mc -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 6887575552 -d 3945065873408 -t 7b-4T-final
+
 import argparse
 
 import numpy as np
@@ -34,6 +37,7 @@ def parse_args():
         help="Percentage of intermediate ckpts to skip from the beginning (for loss to accuracy fitting)",
     )
     parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file")
+    parser.add_argument("--step2-config-path", type=str, default=None, help="Path to config file for step2")
     parser.add_argument("-n", "--n", type=int, required=True, help="Model size of the target model")
     parser.add_argument("-d", "--d", type=int, required=True, help="Data size of the target model")
     parser.add_argument("-t", "--target-name", type=str, default=None, help="Path to the csv file of the target model")
@@ -47,6 +51,10 @@ def parse_args():
 def main():
     args = parse_args()
     configs = get_final_configs(args.config_path)
+    if args.step2_config_path:
+        step2_configs = get_final_configs(args.step2_config_path)
+    else:
+        step2_configs = configs
 
     results = "Task Name | Prediction | Actual | Rel Error"
 
@@ -59,7 +67,7 @@ def main():
 
         # Step 2
         step2_data_by_name = get_step2_data_by_name(
-            configs, task_name, y_metric=args.y_metric, moving_avg=args.moving_avg, skip_perc=args.skip_perc
+            step2_configs, task_name, y_metric=args.y_metric, moving_avg=args.moving_avg, skip_perc=args.skip_perc
         )
         step2_coefficients, _ = fit_step2(step2_data_by_name, task_name, args.y_metric)
 
diff --git a/scripts/scaling/step2_mc.json b/scripts/scaling/step2_mc.json
new file mode 100644
index 000000000..aec1b09e0
--- /dev/null
+++ b/scripts/scaling/step2_mc.json
@@ -0,0 +1,20 @@
+{
+    "7b-4T-70k-300k": {
+        "paths": [
+            "wandb/peteish7_eval_full_70k-300k.csv"
+        ],
+        "mode": "train",
+        "n": 6887575552,
+        "label": "7b-4T-70k-300k",
+        "color": "darkviolet"
+    },
+    "7b-4T-final": {
+        "paths": [
+            "wandb/peteish7_eval_anneal.csv"
+        ],
+        "mode": "eval",
+        "n": 6887575552,
+        "label": "7b-4T-final",
+        "color": "magenta"
+    }
+}
\ No newline at end of file