Skip to content

Commit

Permalink
Support predicting MC
Browse files Browse the repository at this point in the history
  • Loading branch information
liujch1998 committed Nov 16, 2024
1 parent ab97e87 commit 0b4800a
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 7 deletions.
18 changes: 12 additions & 6 deletions olmo/scaling/scaling_laws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def get_mc_accuracy_keys(self):
"piqa": 0.5,
"socialiqa": 1 / 3,
"csqa": 0.2,
"boolq": 0.5,
"winogrande": 0.5,
}

maximums_rc: Dict[str, float] = {} # {"mmlu_stem": 0.9, "arc_easy": 0.85}
Expand All @@ -144,7 +146,7 @@ def get_mc_accuracy_keys(self):
mmlu_names = ["mmlu_stem", "mmlu_humanities", "mmlu_social_sciences", "mmlu_other"]

core_5shot_tasks: Dict[str, DownstreamTaskPrediction] = {
f"{key}_rc_5shot": DownstreamTaskPrediction(
f"{key}_5shot": DownstreamTaskPrediction(
task_loss_key=f"eval/downstream_bpb/{key}_rc_5shot_bpb_bpb",
task_accuracy_key=f"eval/downstream/{key}_rc_5shot_len_norm"
if key not in ["arc_easy", "winogrande", "boolq"]
Expand All @@ -158,7 +160,7 @@ def get_mc_accuracy_keys(self):
}

core_small_5shot_tasks: Dict[str, DownstreamTaskPrediction] = {
f"{key}_rc_5shot": DownstreamTaskPrediction(
f"{key}_5shot": DownstreamTaskPrediction(
task_loss_key=f"eval/downstream_bpb/{key}_rc_5shot_bpb_bpb",
task_accuracy_key=f"eval/downstream/{key}_rc_5shot_len_norm"
if key not in ["arc_easy", "winogrande", "boolq"]
Expand Down Expand Up @@ -205,6 +207,8 @@ def get_task_sets(keys):
keys = list(mmlu_subset_var_tasks.keys())
elif keys[0] == "main":
keys = list(mmlu_var_tasks.keys()) + list(core_small_5shot_tasks.keys())
elif keys[0] == "main_mc":
keys = ["mmlu_avg_var", "arc_challenge_5shot"]
elif keys[0] == "all":
keys = list(mmlu_var_tasks.keys()) + list(core_5shot_tasks.keys())
return keys
Expand Down Expand Up @@ -328,6 +332,8 @@ def get_accuracy_keys(tasks: Dict[str, DownstreamTaskPrediction]) -> List[str]:
"all-bpb": downstream_bpb,
"c4": ["eval/c4_en-validation/CrossEntropyLoss"],
}
for task_name, task in tasks.items():
KEYS_BY_KEY[task_name] = task.task_loss_key if isinstance(task.task_loss_key, list) else [task.task_loss_key]

WEIGHT_BY_KEY = {
"eval/downstream_bpb/mmlu_stem_var_bpb_bpb": 0.215,
Expand Down Expand Up @@ -358,9 +364,6 @@ def get_accuracy_keys(tasks: Dict[str, DownstreamTaskPrediction]) -> List[str]:
"13b": 91335915520,
}

for task_name, task in tasks.items():
KEYS_BY_KEY[task_name] = task.task_loss_key if isinstance(task.task_loss_key, list) else [task.task_loss_key]


def prettify(rel_error, is_percentage=True):
if is_percentage:
Expand Down Expand Up @@ -508,7 +511,10 @@ def moving_average(arr, n):


def get_length(path):
    """Extract the data-length token from a csv path like ``.../name-<length>.csv``.

    E.g. ``wandb/peteish7-4T.csv`` -> ``"4T"``. Returns ``""`` when the
    filename has no ``-`` separated suffix (e.g. annealing/eval dumps such as
    ``peteish7_eval_anneal.csv``).
    """
    try:
        return path.split("/")[-1].split(".csv")[0].split("-")[1]
    except IndexError:
        # No "-<length>" suffix in the filename; caller treats "" as unknown.
        return ""


def get_step2_data_by_name(configs, task_name, y_metric="rc_acc", moving_avg=1, skip_perc=0.0, last_n_points=-1):
Expand Down
118 changes: 118 additions & 0 deletions scripts/scaling/eval_bpb_mc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import json
import matplotlib.pyplot as plt
import numpy as np

# HuggingFace model ids whose eval dumps are plotted; commented-out entries
# are excluded from the figure (uncomment to include them).
MODELS = [
    'allenai/OLMo-7B-0724-hf',
    # 'allenai/OLMo-1B-0724-hf',
    # 'allenai/OLMo-7B-0424-hf',
    'allenai/OLMo-7B-hf',
    'allenai/OLMo-1B-hf',
    'meta-llama/Llama-3.2-3B',
    'meta-llama/Llama-3.2-1B',
    # 'meta-llama/Llama-3.1-70B',
    'meta-llama/Llama-3.1-8B',
    # 'meta-llama/Meta-Llama-3-70B',
    'meta-llama/Meta-Llama-3-8B',
    # 'meta-llama/Llama-2-70b-hf',
    # 'meta-llama/Llama-2-13b-hf',
    # 'meta-llama/Llama-2-7b-hf',
    # 'google/gemma-2-27b',
    # 'google/gemma-2-9b',
    # 'google/gemma-2-2b',
    # 'google/gemma-7b',
    # 'google/gemma-2b',
    # 'Qwen/Qwen2.5-72B',
    # 'Qwen/Qwen2.5-32B',
    'Qwen/Qwen2.5-14B',
    'Qwen/Qwen2.5-7B',
    'Qwen/Qwen2.5-3B',
    'Qwen/Qwen2.5-1.5B',
    # 'Qwen/Qwen2-72B',
    'Qwen/Qwen2-7B',
    'Qwen/Qwen2-1.5B',
    'mistralai/Mistral-Nemo-Base-2407',
    'mistralai/Mistral-7B-v0.3',
    'mistralai/Mistral-7B-v0.1',
]

# Plot color per model-family prefix; more specific prefixes (e.g. "Qwen/Qwen2.5")
# are listed before their shorter variants so they match first.
COLOR_BY_MODEL_PREFIX = {
    'allenai': 'hotpink',
    'meta-llama/Llama-3.2': 'darkblue',
    'meta-llama/Llama-3.1': 'mediumblue',
    'meta-llama/Meta-Llama-3': 'royalblue',
    'meta-llama/Llama-2': 'cornflowerblue',
    'google/gemma-2-': 'darkgreen',
    'google/gemma-': 'forestgreen',
    'Qwen/Qwen2.5': 'darkviolet',
    'Qwen/Qwen2': 'violet',
    'mistralai': 'darkorange',
}
def get_color(model):
    """Return the plot color for *model* via first matching prefix, else 'black'."""
    return next(
        (c for prefix, c in COLOR_BY_MODEL_PREFIX.items() if model.startswith(prefix)),
        'black',
    )

# Plot spec: key is '{x-format}_{y-format}_{task}' (rc vs. mc metric variants;
# x-axis is the bpb metric, y-axis the accuracy metric). Each value is a list of
# (bpb metric name, accuracy metric name, weight) triples; weights only differ
# from 1.0 for the MMLU subsets, where they set the weighted average over
# stem / humanities / social_sciences / other.
METRICS_BY_TASK = {
    'rc_rc_mmlu': [
        ('mmlu_stem_var_bpb', 'mmlu_stem_var_len_norm', 0.215),
        ('mmlu_humanities_var_bpb', 'mmlu_humanities_var_len_norm', 0.335),
        ('mmlu_social_sciences_var_bpb', 'mmlu_social_sciences_var_len_norm', 0.219),
        ('mmlu_other_var_bpb', 'mmlu_other_var_len_norm', 0.231),
    ],
    'rc_rc_hellaswag': [('hellaswag_rc_5shot_bpb', 'hellaswag_rc_5shot_len_norm', 1.0)],
    'rc_rc_arc-c': [('arc_challenge_rc_5shot_bpb', 'arc_challenge_rc_5shot_len_norm', 1.0)],
    'rc_rc_piqa': [('piqa_rc_5shot_bpb', 'piqa_rc_5shot_len_norm', 1.0)],
    'rc_rc_csqa': [('csqa_rc_5shot_bpb', 'csqa_rc_5shot_len_norm', 1.0)],
    'rc_rc_socialiqa': [('socialiqa_rc_5shot_bpb', 'socialiqa_rc_5shot_len_norm', 1.0)],
    'rc_mc_mmlu': [
        ('mmlu_stem_var_bpb', 'mmlu_stem_mc_5shot_len_norm', 0.215),
        ('mmlu_humanities_var_bpb', 'mmlu_humanities_mc_5shot_len_norm', 0.335),
        ('mmlu_social_sciences_var_bpb', 'mmlu_social_sciences_mc_5shot_len_norm', 0.219),
        ('mmlu_other_var_bpb', 'mmlu_other_mc_5shot_len_norm', 0.231),
    ],
    'rc_mc_hellaswag': [('hellaswag_rc_5shot_bpb', 'hellaswag_mc_5shot_acc', 1.0)],
    'rc_mc_arc-c': [('arc_challenge_rc_5shot_bpb', 'arc_challenge_mc_5shot_acc', 1.0)],
    'rc_mc_piqa': [('piqa_rc_5shot_bpb', 'piqa_mc_5shot_acc', 1.0)],
    'rc_mc_csqa': [('csqa_rc_5shot_bpb', 'csqa_mc_5shot_acc', 1.0)],
    'rc_mc_socialiqa': [('socialiqa_rc_5shot_bpb', 'socialiqa_mc_5shot_acc', 1.0)],
    'mc_mc_mmlu': [
        ('mmlu_stem_mc_5shot_bpb', 'mmlu_stem_mc_5shot_len_norm', 0.215),
        ('mmlu_humanities_mc_5shot_bpb', 'mmlu_humanities_mc_5shot_len_norm', 0.335),
        ('mmlu_social_sciences_mc_5shot_bpb', 'mmlu_social_sciences_mc_5shot_len_norm', 0.219),
        ('mmlu_other_mc_5shot_bpb', 'mmlu_other_mc_5shot_len_norm', 0.231),
    ],
    'mc_mc_hellaswag': [('hellaswag_mc_5shot_bpb', 'hellaswag_mc_5shot_acc', 1.0)],
    'mc_mc_arc-c': [('arc_challenge_mc_5shot_bpb', 'arc_challenge_mc_5shot_acc', 1.0)],
    'mc_mc_piqa': [('piqa_mc_5shot_bpb', 'piqa_mc_5shot_acc', 1.0)],
    'mc_mc_csqa': [('csqa_mc_5shot_bpb', 'csqa_mc_5shot_acc', 1.0)],
    'mc_mc_socialiqa': [('socialiqa_mc_5shot_bpb', 'socialiqa_mc_5shot_acc', 1.0)],
}

# 6 rows (one per task) x 3 columns (one per x/y format pairing: rc_rc, rc_mc, mc_mc).
fig, axs = plt.subplots(6, 3, figsize=(3 * 6, 6 * 4.5))

for i, (task, metrics) in enumerate(METRICS_BY_TASK.items()):
    # Column-major placement: tasks 0-5 fill column 0, 6-11 column 1, etc.
    ax = axs[i % 6, i // 6]
    for model in MODELS:
        # Per-model eval results dumped as JSON; path derived from the HF model id.
        with open(f'wandb/eval_bpb_mc/{model.replace("/", "_")}.json') as f:
            data = json.load(f)
        try:
            # Weighted averages over the task's sub-metrics (weights only differ
            # for MMLU subsets). Note the metric name already ends in "_bpb", so
            # the full key intentionally ends in "_bpb_bpb".
            rc_bpb = np.average([data[f'eval/downstream_bpb/{metric[0]}_bpb'] for metric in metrics], weights=[metric[2] for metric in metrics])
            acc = np.average([data[f'eval/downstream/{metric[1]}'] for metric in metrics], weights=[metric[2] for metric in metrics])
        except KeyError:
            # This model's dump lacks one of the required metrics — skip its point.
            continue
        color = get_color(model)
        ax.scatter([rc_bpb], [acc], color=color, s=100)
        ax.annotate(
            text=model.split('/')[1],
            xy=(float(rc_bpb), float(acc)),
            xytext=(8, -3),
            textcoords='offset points',
            fontsize=8,
        )
    # Axis labels come from the key format '{x-format}_{y-format}_{task}'.
    ax.set_xlabel(f'{task.split("_")[0]} bpb')
    ax.set_ylabel(f'{task.split("_")[1]} acc')
    ax.set_title(task)

plt.savefig(f'wandb/eval_bpb_mc/all.png', dpi=300, bbox_inches='tight')
10 changes: 9 additions & 1 deletion scripts/scaling/predict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# python scripts/scaling/predict.py -k main -c scripts/scaling/final.json -n 6887575552 -d 3945065873408 -t 7b
# python scripts/scaling/predict.py -k main_mc -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 6887575552 -d 3945065873408 -t 7b-4T-final

import argparse

import numpy as np
Expand Down Expand Up @@ -34,6 +37,7 @@ def parse_args():
help="Percentage of intermediate ckpts to skip from the beginning (for loss to accuracy fitting)",
)
parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file")
parser.add_argument("--step2-config-path", type=str, default=None, help="Path to config file for step2")
parser.add_argument("-n", "--n", type=int, required=True, help="Model size of the target model")
parser.add_argument("-d", "--d", type=int, required=True, help="Data size of the target model")
parser.add_argument("-t", "--target-name", type=str, default=None, help="Path to the csv file of the target model")
Expand All @@ -47,6 +51,10 @@ def parse_args():
def main():
args = parse_args()
configs = get_final_configs(args.config_path)
if args.step2_config_path:
step2_configs = get_final_configs(args.step2_config_path)
else:
step2_configs = configs

results = "Task Name | Prediction | Actual | Rel Error"

Expand All @@ -59,7 +67,7 @@ def main():

# Step 2
step2_data_by_name = get_step2_data_by_name(
configs, task_name, y_metric=args.y_metric, moving_avg=args.moving_avg, skip_perc=args.skip_perc
step2_configs, task_name, y_metric=args.y_metric, moving_avg=args.moving_avg, skip_perc=args.skip_perc
)
step2_coefficients, _ = fit_step2(step2_data_by_name, task_name, args.y_metric)

Expand Down
20 changes: 20 additions & 0 deletions scripts/scaling/step2_mc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"7b-4T-70k-300k": {
"paths": [
"wandb/peteish7_eval_full_70k-300k.csv"
],
"mode": "train",
"n": 6887575552,
"label": "7b-4T-70k-300k",
"color": "darkviolet"
},
"7b-4T-final": {
"paths": [
"wandb/peteish7_eval_anneal.csv"
],
"mode": "eval",
"n": 6887575552,
"label": "7b-4T-final",
"color": "magenta"
}
}

0 comments on commit 0b4800a

Please sign in to comment.