Commit
Merge branch 'ladder-1xC-task_ce' into ladder-1xC
CodeCreator committed Nov 25, 2024
2 parents e643734 + c04ae4a commit 9bd336b
Showing 4 changed files with 220 additions and 72 deletions.
43 changes: 38 additions & 5 deletions olmo/scaling/scaling_laws/fitting_functions.py
@@ -78,14 +78,20 @@ def chinchilla_n_d_fit_e(x, p0, p1, p2, p3, p4):
 # p[0] = a = log(A), p[1] = b = log(B), p[2] = alpha, p[3] = beta, p[4] = E
 def chinchilla_n_d_fit(x, p):
     # return e**a / x[0]**alpha + e**b / x[1]**beta + E
-    return np.exp(p[0]) / x[0] ** p[2] + np.exp(p[1]) / x[1] ** p[3] + p[4]
+    return np.exp(p[0] - np.log(x[0]) * p[2]) + np.exp(p[1] - np.log(x[1]) * p[3]) + p[4]
+    # return np.exp(p[0]) / x[0] ** p[2] + np.exp(p[1]) / x[1] ** p[3] + p[4]


 def grad_chinchilla_n_d_fit(x, p):
-    grad_a = np.exp(p[0]) / x[0] ** p[2]
-    grad_b = np.exp(p[1]) / x[1] ** p[3]
-    grad_alpha = np.exp(p[0]) * (-np.log(x[0])) / x[0] ** p[2]
-    grad_beta = np.exp(p[1]) * (-np.log(x[1])) / x[1] ** p[3]
+    grad_a = np.exp(p[0] - np.log(x[0]) * p[2])
+    grad_b = np.exp(p[1] - np.log(x[1]) * p[3])
+    grad_alpha = np.exp(p[0] - np.log(x[0]) * p[2]) * (-np.log(x[0]))
+    grad_beta = np.exp(p[1] - np.log(x[1]) * p[3]) * (-np.log(x[1]))
+    # grad_a = np.exp(p[0]) / x[0] ** p[2]
+    # grad_b = np.exp(p[1]) / x[1] ** p[3]
+    # grad_alpha = np.exp(p[0]) * (-np.log(x[0])) / x[0] ** p[2]
+    # grad_beta = np.exp(p[1]) * (-np.log(x[1])) / x[1] ** p[3]
     grad_E = 1
     return [grad_a, grad_b, grad_alpha, grad_beta, grad_E]

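Both forms of chinchilla_n_d_fit are algebraically identical (e^a / N^alpha = exp(a - alpha * log N)); the log-space rewrite is presumably for numerical robustness, since an optimizer can probe parameter values where np.exp(p[0]) or x[0] ** p[2] overflows on its own even though the full term is representable. A minimal sketch of that failure mode, assuming only NumPy (the parameter values are made up):

import numpy as np

# Hypothetical values an optimizer might probe: a large log-prefactor `a`
# with a moderate exponent `alpha` at N = 1e9.
a, alpha, N = 600.0, 15.0, 1e9

direct = np.exp(a) / N**alpha           # np.exp(600) overflows to inf, so the term becomes inf
stable = np.exp(a - alpha * np.log(N))  # exponent is ~289, well inside float64 range

print(direct)  # inf (with an overflow warning)
print(stable)  # ~3.8e125, the mathematically intended value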
@@ -327,6 +333,33 @@ def sigmoid(x, a, x0, k, b):
     return o


+def log_sigmoid(x, a, x0, k, c=0.0):
+    y = np.log(1 - 1/(1 + np.exp(-k * (x - x0))) + c)
+    o = (-a) * y + 1
+    return o
+
+
+def log_sigmoid_fit(x, p):
+    y = np.log(1 - 1/(1 + np.exp(-p[2] * (x - p[1]))))
+    o = (-p[0]) * y + 1
+    return o
+
+
+def grad_log_sigmoid_fit(x, p):
+    # Pre-compute common terms
+    sigmoid = 1 / (1 + np.exp(-p[2] * (x - p[1])))
+    log_term = np.log(1 - sigmoid)  # inside of the log function
+    d_log_term = -sigmoid  # derivative of log(1 - sigmoid) w.r.t. the logit k * (x - x0)
+
+    # Gradients
+    grad_a = -log_term  # derivative w.r.t. p[0]
+    grad_x0 = (-p[0]) * d_log_term * (-p[2])  # derivative w.r.t. p[1]
+    grad_k = (-p[0]) * d_log_term * (x - p[1])  # derivative w.r.t. p[2]
+    grad_b = 1  # unused: log_sigmoid_fit has no offset parameter p[3]
+
+    return [grad_a, grad_x0, grad_k]
+
+
 def sigmoid_fit(x, p):
     o = p[0] / (1 + np.exp(-p[2] * (x - p[1]))) + p[3]
     return o
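A quick sanity check for grad_log_sigmoid_fit is to compare it against central finite differences; a sketch, assuming NumPy and the two functions above are in scope:

import numpy as np

x = 1.5
p = [0.8, 1.0, 2.0]  # a, x0, k
eps = 1e-6

analytic = grad_log_sigmoid_fit(x, p)
for i in range(3):
    hi, lo = list(p), list(p)
    hi[i] += eps
    lo[i] -= eps
    # Central difference approximates the i-th partial derivative.
    numeric = (log_sigmoid_fit(x, hi) - log_sigmoid_fit(x, lo)) / (2 * eps)
    assert abs(numeric - analytic[i]) < 1e-5, (i, numeric, analytic[i])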
33 changes: 32 additions & 1 deletion olmo/scaling/scaling_laws/utils.py
@@ -104,6 +104,7 @@ class DownstreamTaskPrediction:
     task_accuracy_key: Union[str, List[str]]
     task_mc_loss_key: Union[str, List[str]]
     task_mc_accuracy_key: Union[str, List[str]]
+    display_name: str
     task_soft_loss_key: Union[str, List[str]] = ""
     task_log_soft_loss_key: Union[str, List[str]] = ""
     task_minimum: float = 0.25
@@ -159,6 +160,18 @@ def get_log_soft_loss_keys(self):
 core_small_names = ["hellaswag", "arc_challenge", "arc_easy", "piqa", "csqa", "socialiqa", "openbookqa"]
 mmlu_names = ["mmlu_stem", "mmlu_humanities", "mmlu_social_sciences", "mmlu_other"]

+display_names = {
+    "hellaswag": "HellaSwag",
+    "arc_easy": "ARC-Easy",
+    "arc_challenge": "ARC-Challenge",
+    "boolq": "BoolQ",
+    "csqa": "CommonsenseQA",
+    "openbookqa": "OpenBookQA",
+    "piqa": "PIQA",
+    "socialiqa": "Social IQa",
+    "winogrande": "Winogrande",
+}
+
 core_5shot_tasks: Dict[str, DownstreamTaskPrediction] = {
     f"{key}_5shot": DownstreamTaskPrediction(
         task_loss_key=f"eval/downstream_bpb/{key}_rc_5shot_bpb_bpb",
@@ -171,6 +184,7 @@ def get_log_soft_loss_keys(self):
         task_mc_accuracy_key=f"eval/downstream/{key}_mc_5shot_acc",
         task_minimum=minimums_rc.get(key, 0.25),
         task_maximum=maximums_rc.get(key, 1.0),
+        display_name=display_names.get(key, key),
     )
     for key in core_names
 }
@@ -186,6 +200,7 @@ def get_log_soft_loss_keys(self):
         task_mc_accuracy_key=f"eval/downstream/{key}_mc_5shot_acc",
         task_minimum=minimums_rc.get(key, 0.25),
         task_maximum=maximums_rc.get(key, 1.0),
+        display_name=display_names.get(key, key),
     )
     for key in core_small_names
 }
@@ -204,6 +219,7 @@ def get_log_soft_loss_keys(self):
         task_mc_accuracy_key=[f"eval/downstream/{key}_mc_5shot_acc" for key in core_small_names],
         task_minimum=0.25,
         task_maximum=1.0,
+        display_name="core_small_avg",
     )
 }
 mmlu_var_tasks: Dict[str, DownstreamTaskPrediction] = {
@@ -214,6 +230,7 @@ def get_log_soft_loss_keys(self):
         task_mc_accuracy_key=[f"eval/downstream/{key}_mc_5shot_len_norm" for key in mmlu_names],
         task_minimum=0.25,
         task_maximum=1.0,  # 0.9,
+        display_name="MMLU"
     )
 }
 mmlu_subset_var_tasks: Dict[str, DownstreamTaskPrediction] = {
@@ -224,6 +241,7 @@ def get_log_soft_loss_keys(self):
         task_mc_accuracy_key=f"eval/downstream/{key}_mc_5shot_len_norm",
         task_minimum=minimums_rc.get(key, 0.25),
         task_maximum=maximums_rc.get(key, 0.9),
+        display_name="MMLU"
     )
     for key in mmlu_names
 }
@@ -296,6 +314,7 @@ def get_log_soft_loss_keys(self):
         task_mc_accuracy_key=f"eval/downstream/{key}_mc_5shot_acc",
         task_minimum=v2_minimums_rc.get(key, 0.25),
         task_maximum=v2_maximums_rc.get(key, 1.0),
+        display_name=display_names.get(key.removesuffix("_val").removesuffix("_test"), key),
     )
     for key in v2_core_names
 }
@@ -313,6 +332,7 @@ def get_log_soft_loss_keys(self):
         task_mc_accuracy_key=f"eval/downstream/{key}_mc_5shot_acc",
         task_minimum=v2_minimums_rc.get(key, 0.25),
         task_maximum=v2_maximums_rc.get(key, 1.0),
+        display_name=display_names.get(key.removesuffix("_val").removesuffix("_test"), key),
     )
     for key in v2_core_small_names
 }
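Note that str.removesuffix requires Python 3.9+; on older interpreters an equivalent guarded slice does the same job. A small sketch (the helper name is illustrative, not part of the codebase):

# Equivalent of key.removesuffix("_val") for Python < 3.9.
def remove_suffix(s, suffix):
    return s[: -len(suffix)] if suffix and s.endswith(suffix) else s

assert remove_suffix("mmlu_avg_val", "_val") == "mmlu_avg"
assert remove_suffix("hellaswag", "_val") == "hellaswag"  # no-op when the suffix is absent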
@@ -327,6 +347,7 @@ def get_log_soft_loss_keys(self):
         task_mc_accuracy_key=[f"eval/downstream/{key}_mc_5shot_len_norm" for key in v2_mmlu_val_names],
         task_minimum=v2_minimums_rc.get("mmlu_avg_val", 0.25),
         task_maximum=v2_maximums_rc.get("mmlu_avg_val", 1.0),
+        display_name="MMLU"
     )
 }

@@ -340,6 +361,7 @@ def get_log_soft_loss_keys(self):
         task_mc_accuracy_key=[f"eval/downstream/{key}_mc_5shot_len_norm" for key in v2_mmlu_test_names],
         task_minimum=v2_minimums_rc.get("mmlu_avg_test", 0.25),
         task_maximum=v2_maximums_rc.get("mmlu_avg_test", 1.0),
+        display_name="MMLU"
     )
 }

@@ -744,7 +766,7 @@ def get_flops_data_by_name(configs, keys, num_to_avg=1):
 def moving_average(arr, n):
     ret = np.cumsum(arr, dtype=float)
     ret[n:] = ret[n:] - ret[:-n]
-    return np.concat([ret[: n - 1] / np.arange(1, n), ret[n - 1 :] / n])
+    return np.concatenate([ret[: n - 1] / np.arange(1, n), ret[n - 1 :] / n])


 def get_length(path):
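np.concat exists only in newer NumPy releases (it was added as an alias of np.concatenate), so np.concatenate is the spelling that works everywhere. For reference, the function returns an array the same length as its input, averaging a growing prefix until a full window of n elements is available; a small usage sketch restating the fixed function, assuming only NumPy:

import numpy as np

def moving_average(arr, n):
    ret = np.cumsum(arr, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return np.concatenate([ret[: n - 1] / np.arange(1, n), ret[n - 1 :] / n])

# The first n-1 positions average only the elements seen so far; the rest use a full window.
print(moving_average([1, 2, 3, 4], n=2))  # [1.  1.5 2.5 3.5]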
@@ -760,6 +782,9 @@ def get_step2_data_by_name(
     task = tasks[task_name]
     if x_metric == "rc_bpb":
         loss_keys = task.get_loss_keys()
+    elif x_metric == "rc_soft_log":
+        loss_keys = task.get_accuracy_keys()
+        loss_keys = [key.replace("/downstream/", "/downstream_soft_log/").replace("_len_norm", "_soft_log") for key in loss_keys]
     elif x_metric == "c4":
         loss_keys = ["eval/c4_en-validation/CrossEntropyLoss"]
     else:
@@ -805,10 +830,16 @@ def get_step2_data_by_name(
                 [float(row[key]) for key in loss_keys],
                 weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in loss_keys],
             )
+            if x_metric == "rc_soft_log":
+                x *= -1
+
             y = np.average(
                 [float(row[key]) for key in accuracy_keys],
                 weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in accuracy_keys],
             )
+            if y_metric == "rc_soft_log":
+                y *= -1
+
             xs.append(x)
             ys.append(y)
             ds.append(d)
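For rc_soft_log, the loss keys are derived from the accuracy keys by switching to the downstream_soft_log metric group, and both x and y are then negated. Those metrics appear to be log-probabilities of the correct answer (hence nonpositive), so the sign flip turns them into positive, lower-is-better quantities; a small illustrative sketch (the concrete key name is hypothetical, and the sign convention is inferred from the metric name):

# Key rewrite used above, applied to an illustrative accuracy key.
key = "eval/downstream/hellaswag_rc_5shot_len_norm"
soft_log_key = key.replace("/downstream/", "/downstream_soft_log/").replace("_len_norm", "_soft_log")
print(soft_log_key)  # eval/downstream_soft_log/hellaswag_rc_5shot_soft_log

# A soft-log value is a log-probability, e.g. log(0.5) ~ -0.69, so negating it
# yields a positive "loss" that shrinks as the model improves.
soft_log_value = -0.69
loss = -soft_log_value  # 0.69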
51 changes: 37 additions & 14 deletions scripts/scaling/step1.py
@@ -21,10 +21,11 @@
     get_step1_data_by_name,
     get_task_sets,
     prettify,
+    tasks
 )

 MARKERS = ["s", "P", "p", "*", "o"]

+FONTSIZE=10

 def parse_args():
     parser = argparse.ArgumentParser()
@@ -198,12 +199,14 @@ def plot_step1(
data["ys"],
color=config.color,
linestyle="--",
alpha=0.7,
linewidth=1.5,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)

# plot the actual and predicted data
unsigned_rel_errors = []
num_eval_annotation = 0
for name, data in data_by_name.items():
config = configs[name]
predicted_data = predicted_data_by_name[name]
Expand Down Expand Up @@ -235,28 +238,31 @@ def plot_step1(
f"{abs(rel_error) * 100:.1f}%",
(d, y),
textcoords="offset points",
xytext=(3, 3),
xytext=(10, 1 - 10*num_eval_annotation),
ha="left",
va="bottom",
fontsize=10,
fontsize=FONTSIZE,
color=config.color,
)
num_eval_annotation += 1
avg_unsigned_rel_error = np.mean(unsigned_rel_errors)

     ax.set_xscale("log")
-    ax.legend(loc="upper right", ncols=1, fontsize=8)
-    ax.set_xlabel("Tokens (D)")
+    ax.legend(loc="upper right", ncols=1, fontsize=FONTSIZE)
+    ax.set_xlabel("Tokens (D)", fontsize=FONTSIZE)

     if y_metric == "rc_bpb":
-        ax.set_ylabel("Task loss")
+        ax.set_ylabel("Task loss", fontsize=FONTSIZE)
     elif y_metric == "rc_acc":
-        ax.set_ylabel("Task RC accuracy")
+        ax.set_ylabel("Task RC accuracy", fontsize=FONTSIZE)
     elif y_metric == "c4":
-        ax.set_ylabel("C4 loss")
+        ax.set_ylabel("C4 loss", fontsize=FONTSIZE)
     else:
         raise ValueError(f"Unknown y_metric: {y_metric}")
     ax.set_title(
-        f"{task_name}\n{fit_str}\navg rel error on fitting = {avg_unsigned_rel_error * 100:.2f}%",
-        fontsize=9,
+        f"{tasks[task_name].display_name} ({avg_unsigned_rel_error * 100:.2f}%)",
+        fontsize=FONTSIZE,
+        fontweight="bold",
     )


@@ -266,13 +272,13 @@ def main():

     sns.set_style("whitegrid")
     num_tasks = len(args.keys)
-    num_cols = min(3, num_tasks)
+    num_cols = min(4, num_tasks)
     num_rows = (num_tasks + num_cols - 1) // num_cols

     fitting_error = 0

     if args.output_path:
-        fig, axes = plt.subplots(num_rows, num_cols, figsize=(3.75 * num_cols, 3.25 * num_rows), squeeze=False)
+        fig, axes = plt.subplots(num_rows, num_cols, figsize=(2.75 * num_cols, 2.25 * num_rows), squeeze=False)

     results = "Task Name | Actual Value | Predicted Value | Relative Error"

@@ -313,9 +319,26 @@ def main():
             axes[i // num_cols][i % num_cols],
         )

+    handles, labels = axes[-1][-1].get_legend_handles_labels()
+    # keep axis labels only on the bottom row / left column, and drop per-axes legends
+    for i in range(num_cols):
+        for j in range(num_rows):
+            if j != num_rows - 1:
+                axes[j][i].set_xlabel("")
+            if i != 0:
+                axes[j][i].set_ylabel("")
+
+            axes[j][i].legend().remove()
+
+    fig.tight_layout()
+    legend = fig.legend(handles, labels, loc='upper center',
+                        ncol=10, fontsize=FONTSIZE, bbox_to_anchor=(0.5, 1.07),
+                        handletextpad=0.3, columnspacing=0.7)
+    for handle in legend.legend_handles:
+        handle.set_alpha(1.0)
+
     if args.output_path:
-        fig.tight_layout()
-        fig.savefig(args.output_path, dpi=300)
+        fig.savefig(args.output_path, dpi=300, bbox_inches='tight')

     print(results)
     print("Total fitting error: ", prettify(fitting_error / num_tasks))
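Grabbing the handles from one axes and drawing a single fig.legend is the usual matplotlib pattern for a shared legend over a subplot grid; resetting each handle's alpha keeps the legend swatches fully opaque even though the fitted curves are drawn at alpha=0.7, and bbox_inches='tight' at save time keeps the above-axes legend from being cropped. A condensed, self-contained sketch of the pattern (assumes matplotlib >= 3.7 for legend_handles; the filename is illustrative):

import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, squeeze=False)
for row in axes:
    for ax in row:
        ax.plot([1, 2], [1, 2], alpha=0.7, label="fitted")

# One shared legend above the grid instead of a legend per axes.
handles, labels = axes[-1][-1].get_legend_handles_labels()
legend = fig.legend(handles, labels, loc="upper center", ncol=10, bbox_to_anchor=(0.5, 1.07))
for handle in legend.legend_handles:
    handle.set_alpha(1.0)  # full opacity in the legend only

fig.savefig("grid.png", dpi=300, bbox_inches="tight")  # 'tight' keeps the legend in frame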
