Commit 67cc7c4

figure 1

AkshitaB committed Nov 26, 2024
1 parent 5f49408 commit 67cc7c4

Showing 8 changed files with 270 additions and 334 deletions.

olmo/scaling/scaling_laws/fitting_functions.py
7 changes: 4 additions & 3 deletions

@@ -36,7 +36,8 @@ def get_std_errors(xs, ys, coefficients, cov, fitting_func, grad_fitting_func):
         jacobian[j] = grad_fitting_func(x, coefficients)

     # Compute standard errors for predictions
-    intermediate = np.sum(jacobian @ cov @ jacobian.T, axis=1)
+    # intermediate = np.sum(jacobian @ cov @ jacobian.T, axis=1)
+    intermediate = np.diag(jacobian @ cov @ jacobian.T)
     std_errors = np.sqrt(intermediate.clip(min=0.0))

     return std_errors
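
The only behavioral change in this file is the switch from `np.sum(..., axis=1)` to `np.diag(...)` above. The diagonal of `jacobian @ cov @ jacobian.T` is the delta-method variance of each prediction, J_i Σ J_iᵀ, whereas the old row sum also folded in the off-diagonal covariances between different predictions. A minimal sketch of the difference, with hypothetical shapes (5 predictions, 3 coefficients) not taken from this repo:

    import numpy as np

    rng = np.random.default_rng(0)
    jacobian = rng.normal(size=(5, 3))  # row i: d f(x_i) / d theta
    cov = 0.01 * np.eye(3)              # coefficient covariance, e.g. from curve_fit

    full = jacobian @ cov @ jacobian.T  # 5x5 matrix: Cov[f(x_i), f(x_j)]
    var = np.diag(full)                 # per-prediction variance Var[f(x_i)]
    row_sum = np.sum(full, axis=1)      # old code: adds cross-covariances as well

    # The diagonal equals the per-row quadratic form; the row sum generally does not.
    assert np.allclose(var, [j @ cov @ j for j in jacobian])

This also makes sense of the existing `clip(min=0.0)`: the old row sums could go negative, while true variances cannot beyond floating-point noise.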
@@ -334,13 +335,13 @@ def sigmoid(x, a, x0, k, b):


 def log_sigmoid(x, a, x0, k, c=0.0):
-    y = np.log(1 - 1/(1 + np.exp(-k * (x - x0))) + c)
+    y = np.log(1 - 1 / (1 + np.exp(-k * (x - x0))) + c)
     o = (-a) * y + 1
     return o


 def log_sigmoid_fit(x, p):
-    y = np.log(1 - 1/(1 + np.exp(-p[2] * (x - p[1]))))
+    y = np.log(1 - 1 / (1 + np.exp(-p[2] * (x - p[1]))))
     o = (-p[0]) * y + 1
     return o
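
Both edits in this hunk are whitespace-only. For readers skimming the diff: `log_sigmoid_fit` is simply `log_sigmoid` with the scalar parameters packed into a vector `p = [a, x0, k]` and `c` pinned to 0, i.e. both compute o = 1 - a * log(1 - sigmoid(k * (x - x0)) + c). A quick self-contained check with hypothetical parameter values:

    import numpy as np

    def log_sigmoid(x, a, x0, k, c=0.0):
        y = np.log(1 - 1 / (1 + np.exp(-k * (x - x0))) + c)
        return (-a) * y + 1

    def log_sigmoid_fit(x, p):
        y = np.log(1 - 1 / (1 + np.exp(-p[2] * (x - p[1]))))
        return (-p[0]) * y + 1

    # Hypothetical values, chosen only for the equivalence check.
    x, a, x0, k = 1.5, 0.8, 2.0, 3.0
    assert np.isclose(log_sigmoid(x, a, x0, k), log_sigmoid_fit(x, [a, x0, k]))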

olmo/scaling/scaling_laws/utils.py
31 changes: 18 additions & 13 deletions

@@ -230,7 +230,7 @@ def get_log_soft_loss_keys(self):
             task_mc_accuracy_key=[f"eval/downstream/{key}_mc_5shot_len_norm" for key in mmlu_names],
             task_minimum=0.25,
             task_maximum=1.0,  # 0.9,
-            display_name="MMLU"
+            display_name="MMLU",
         )
     }
     mmlu_subset_var_tasks: Dict[str, DownstreamTaskPrediction] = {
@@ -241,7 +241,7 @@ def get_log_soft_loss_keys(self):
             task_mc_accuracy_key=f"eval/downstream/{key}_mc_5shot_len_norm",
             task_minimum=minimums_rc.get(key, 0.25),
             task_maximum=maximums_rc.get(key, 0.9),
-            display_name="MMLU"
+            display_name="MMLU",
         )
         for key in mmlu_names
     }
@@ -347,7 +347,7 @@ def get_log_soft_loss_keys(self):
             task_mc_accuracy_key=[f"eval/downstream/{key}_mc_5shot_len_norm" for key in v2_mmlu_val_names],
             task_minimum=v2_minimums_rc.get("mmlu_avg_val", 0.25),
             task_maximum=v2_maximums_rc.get("mmlu_avg_val", 1.0),
-            display_name="MMLU"
+            display_name="MMLU",
         )
     }

@@ -361,7 +361,7 @@ def get_log_soft_loss_keys(self):
             task_mc_accuracy_key=[f"eval/downstream/{key}_mc_5shot_len_norm" for key in v2_mmlu_test_names],
             task_minimum=v2_minimums_rc.get("mmlu_avg_test", 0.25),
             task_maximum=v2_maximums_rc.get("mmlu_avg_test", 1.0),
-            display_name="MMLU"
+            display_name="MMLU",
         )
     }

@@ -707,11 +707,13 @@ def get_step1_data_by_name(configs, task_name, y_metric="rc_bpb", moving_avg=1):
         keys = ["eval/c4_en-validation/CrossEntropyLoss"]
     elif y_metric == "rc_soft_log":
         keys = task.get_accuracy_keys()
-        keys = [key.replace("/downstream/", "/downstream_soft_log/").replace("_len_norm", "_soft_log") for key in keys]
+        keys = [
+            key.replace("/downstream/", "/downstream_soft_log/").replace("_len_norm", "_soft_log") for key in keys
+        ]
     else:
         raise ValueError(f"Invalid y_metric: {y_metric}")

-    data_by_name: Dict = defaultdict(lambda: {"ns": [], "ds": [], "ys": [], "ls": [], "fs": []})
+    data_by_name: Dict = defaultdict(lambda: {"ns": [], "ds": [], "xs": [], "ls": [], "fs": []})
     for name, config in configs.items():
         n = config.n
         for path in config.paths:
@@ -720,27 +722,27 @@ def get_step1_data_by_name(configs, task_name, y_metric="rc_bpb", moving_avg=1):
                 reader = csv.DictReader(file_ref)
                 rows = [row for row in reader]
                 rows = rows[-moving_avg:]
-                ds, ys, fs = [], [], []
+                ds, xs, fs = [], [], []
                 for row in rows:
                     if "throughput/total_tokens" in row:
                         d = int(float(row["throughput/total_tokens"]))
                     else:
                         d = int(float(row["_step"])) * int(float(row["batch_size_in_tokens"]))
                     f = float(d * MODEL_FLOPS[name.split("-")[0]])
-                    y = np.average(
+                    x = np.average(
                         [float(row[key]) for key in keys], weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in keys]
                     )
                     if y_metric == "rc_soft_log":
-                        y *= -1
+                        x *= -1
                     ds.append(d)
-                    ys.append(y)
+                    xs.append(x)
                     fs.append(f)
                 d = ds[-1]
-                y = np.mean(ys)
+                x = np.mean(xs)
                 f = fs[-1]
                 data_by_name[name]["ns"].append(n)
                 data_by_name[name]["ds"].append(d)
-                data_by_name[name]["ys"].append(y)
+                data_by_name[name]["xs"].append(x)
                 data_by_name[name]["ls"].append(length)
                 data_by_name[name]["fs"].append(f)
                 data_by_name[name]["mode"] = config.mode
@@ -794,7 +796,10 @@ def get_step2_data_by_name(
         loss_keys = task.get_loss_keys()
     elif x_metric == "rc_soft_log":
         loss_keys = task.get_accuracy_keys()
-        loss_keys = [key.replace("/downstream/", "/downstream_soft_log/").replace("_len_norm", "_soft_log") for key in loss_keys]
+        loss_keys = [
+            key.replace("/downstream/", "/downstream_soft_log/").replace("_len_norm", "_soft_log")
+            for key in loss_keys
+        ]
     elif x_metric == "c4":
         loss_keys = ["eval/c4_en-validation/CrossEntropyLoss"]
     else:
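
The same key remapping now appears, line-wrapped, in both `get_step1_data_by_name` and `get_step2_data_by_name`: for the `rc_soft_log` metric, accuracy keys are rewritten into the soft-log namespace. An illustration of what the two chained `replace` calls produce (the concrete key is invented for illustration):

    keys = ["eval/downstream/mmlu_stem_mc_5shot_len_norm"]  # hypothetical input key
    soft_log_keys = [
        key.replace("/downstream/", "/downstream_soft_log/").replace("_len_norm", "_soft_log")
        for key in keys
    ]
    # -> ["eval/downstream_soft_log/mmlu_stem_mc_5shot_soft_log"]

The `ys` -> `xs` rename in `get_step1_data_by_name` is mechanical (local variables plus the `data_by_name` key); presumably it signals that step 1's fitted quantity is consumed as the x-axis input by step 2, so any caller still reading `data_by_name[name]["ys"]` needs updating to `"xs"`.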

scripts/scaling/predict.py
10 changes: 8 additions & 2 deletions

@@ -13,7 +13,11 @@
 from step1 import fit_step1
 from step2 import fit_step2

-from olmo.scaling.scaling_laws.fitting_functions import chinchilla_n_d_fit, sigmoid, log_sigmoid
+from olmo.scaling.scaling_laws.fitting_functions import (
+    chinchilla_n_d_fit,
+    log_sigmoid,
+    sigmoid,
+)
 from olmo.scaling.scaling_laws.utils import (
     get_final_configs,
     get_step1_data_by_name,
@@ -31,7 +35,9 @@ def parse_args():
     parser.add_argument(
         "-k", "--keys", nargs="+", default=[], help="For avg metrics. Use one of [all-val-lm, all-bpb]"
     )
-    parser.add_argument("-x", "--x_metric", default="rc_bpb", choices=["rc_bpb", "c4", "rc_soft_log"], help="Metric as input")
+    parser.add_argument(
+        "-x", "--x_metric", default="rc_bpb", choices=["rc_bpb", "c4", "rc_soft_log"], help="Metric as input"
+    )
     parser.add_argument(
         "-y", "--y_metric", default="rc_acc", choices=["rc_acc", "mc_acc"], help="Metric to predict"
     )

scripts/scaling/single_step.py
17 changes: 12 additions & 5 deletions

@@ -150,7 +150,7 @@ def plot_single_step(
             f"{abs(rel_error * 100):.1f}%",
             (d, y_pred),
             textcoords="offset points",
-            xytext=(10, -5 + 10*num_eval_annotation),
+            xytext=(10, -5 + 10 * num_eval_annotation),
             ha="left",
             va="bottom",
             fontsize=FONTSIZE,
@@ -216,13 +216,20 @@ def main():
             axes[j][i].legend().remove()

     fig.tight_layout(w_pad=0.01)
-    legend = fig.legend(handles, labels, loc='upper center',
-                        ncol=10, fontsize=FONTSIZE, bbox_to_anchor=(0.5, 1.07),
-                        handletextpad=0.3, columnspacing=0.7)
+    legend = fig.legend(
+        handles,
+        labels,
+        loc="upper center",
+        ncol=10,
+        fontsize=FONTSIZE,
+        bbox_to_anchor=(0.5, 1.07),
+        handletextpad=0.3,
+        columnspacing=0.7,
+    )
     for handle in legend.legend_handles:
         handle.set_alpha(1.0)

-    fig.savefig(args.output_path, dpi=300, bbox_inches='tight')
+    fig.savefig(args.output_path, dpi=300, bbox_inches="tight")

     print(results)
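
The legend hunk is pure style reformatting, but the untouched `set_alpha(1.0)` loop next to it is worth a gloss: it forces the legend's proxy markers to stay opaque even when the plotted artists are translucent. A minimal standalone sketch of the pattern on synthetic data (`legend_handles` is the attribute name on matplotlib >= 3.7; older releases spell it `legendHandles`):

    import matplotlib.pyplot as plt
    import numpy as np

    fig, ax = plt.subplots()
    x = np.linspace(0, 1, 50)
    ax.scatter(x, x**2, alpha=0.2, label="points")  # translucent data points

    legend = fig.legend(loc="upper center", ncol=1, bbox_to_anchor=(0.5, 1.02))
    for handle in legend.legend_handles:  # matplotlib >= 3.7
        handle.set_alpha(1.0)  # legend marker stays opaque

    fig.savefig("demo.png", dpi=300, bbox_inches="tight")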
