Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
AkshitaB committed Nov 19, 2024
1 parent 272ef4a commit 9b80ad9
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 35 deletions.
15 changes: 15 additions & 0 deletions olmo/scaling/scaling_laws/fitting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ def get_coefficients(train_xs, train_ys, fitting_func, p0, bounds=(-np.inf, np.i
return coeffs


def get_std_errors(xs, ys, coefficients, cov, fitting_func, grad_fitting_func):
    """Return delta-method standard errors of the fitted predictions at ``xs``.

    Args:
        xs: Sequence of inputs at which the fitted function is evaluated.
        ys: Unused; kept so existing call sites keep working.
        coefficients: Fitted parameters; its length sets the Jacobian width.
        cov: Covariance matrix of ``coefficients``
            (``len(coefficients)`` x ``len(coefficients)``).
        fitting_func: Unused; kept so existing call sites keep working.
        grad_fitting_func: Callable invoked as ``grad_fitting_func(x, coefficients)``
            returning the gradient of the prediction with respect to the
            coefficients at ``x``.

    Returns:
        A 1-D array with one standard error per element of ``xs``.
    """
    # One Jacobian row per input: d(prediction_j) / d(coefficients).
    jacobian = np.zeros((len(xs), len(coefficients)))
    for j, x in enumerate(xs):
        jacobian[j] = grad_fitting_func(x, coefficients)

    # Var[f(x_j)] ~= J_j @ cov @ J_j^T, i.e. the *diagonal* of J @ cov @ J^T.
    # Summing whole rows of that matrix would mix in the cross-covariances
    # between different inputs, so compute the per-row quadratic form directly.
    variances = np.sum((jacobian @ np.asarray(cov, dtype=float)) * jacobian, axis=1)

    # Guard against tiny negative values caused by floating-point round-off.
    std_errors = np.sqrt(variances.clip(min=0.0))

    return std_errors


# x = flops
# p[0] = A, p[1] = B, p[2] = E
def chinchilla_flops_fit(x, p):
Expand Down Expand Up @@ -315,6 +328,8 @@ def sigmoid(x, a, x0, k, b):
o = a / (1 + np.exp(-k * (x - x0))) + b
return o

def exponential_fit(x, a, b, c):
    """Evaluate the exponential curve ``a * exp(b * x) + c`` at ``x``."""
    growth = np.exp(b * x)
    scaled = a * growth
    return scaled + c

# Scipy minimize w/ Huber loss
def get_coefficients_huber(
Expand Down
69 changes: 59 additions & 10 deletions scripts/scaling/stacked.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import argparse

import re
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
Expand Down Expand Up @@ -46,6 +46,8 @@ def parse_args():
)
parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file")
parser.add_argument("-o", "--output-path", type=str, required=True, help="Path to write output figure")
parser.add_argument("--target_n", type=str, required=False, default="")
parser.add_argument("--target_d", type=str, required=False, default="")
args = parser.parse_args()

return args
Expand All @@ -57,6 +59,45 @@ def str_chinchilla_n_d_fit(coefficients):
return f"L(N, D) = {A:.2f} / N^{alpha:.2f} + {B:.2f} / D^{beta:.2f} + {E:.2f}"


# These are updated with actual Peteish count
# Maps a model-size label (upper-case, e.g. "7B") to its parameter count.
MODEL_PARAMS = {
    "190M": 190354176,
    "370M": 371262464,
    "600M": 597382464,
    "760M": 758220288,
    "1B": 1279395840,
    "3B": 3169537280,
    "7B": 6887575552,
    "13B": 13202396160,
}


# "<digits><letters>", e.g. "2T" or "5XC"; used by parse_length.
_number_unit_re = re.compile(r"^([0-9]+)([a-zA-Z]+)$")
# Three hyphen-separated fields, none of which contains a hyphen.
_run_name_re = re.compile(r"^([^-]+)-([^-]+)-([^-]+)$")

# Plain token-count multipliers accepted by parse_length.
_LENGTH_UNIT_MULTIPLIERS = {
    "K": 1000,
    "M": 1000000,
    "B": 1000000000,
    "T": 1000000000000,
}


def parse_size(size: str) -> int:
    """Return the parameter count for a model-size label such as ``"7B"``.

    The label is stripped and upper-cased before lookup, so ``"7b"`` and
    ``" 7B "`` are accepted, matching how ``parse_length`` normalizes input.

    Raises:
        KeyError: If the label is not a key of ``MODEL_PARAMS``.
    """
    key = size.strip().upper()
    try:
        return MODEL_PARAMS[key]
    except KeyError:
        known = ", ".join(MODEL_PARAMS)
        raise KeyError(f"Unknown model size '{size}'; expected one of: {known}") from None


def parse_length(length: str, model_size: int) -> int:
    """Convert a training-length string into a number of tokens.

    Args:
        length: ``"<integer><unit>"``, case-insensitive. Units ``K``, ``M``,
            ``B`` and ``T`` multiply by 1e3, 1e6, 1e9 and 1e12 respectively;
            ``C`` and ``XC`` multiply by ``20 * model_size``.
        model_size: Parameter count, used only for the ``C`` / ``XC`` units.

    Returns:
        The length in tokens.

    Raises:
        ValueError: If ``length`` is not ``<integer><unit>`` or the unit is
            not recognized.
    """
    match = _number_unit_re.match(length.strip().upper())
    if match is None:
        # Without this guard a malformed string fails with AttributeError.
        raise ValueError(f"Could not parse length '{length}'")

    count_text, length_unit = match.groups()
    length_in_tokens = int(count_text)

    if length_unit in ("C", "XC"):
        return length_in_tokens * 20 * model_size
    if length_unit in _LENGTH_UNIT_MULTIPLIERS:
        return length_in_tokens * _LENGTH_UNIT_MULTIPLIERS[length_unit]

    # Report the offending argument itself; there is no `args` in this scope.
    raise ValueError(f"Could not parse length '{length}'")


def main():
args = parse_args()

Expand All @@ -66,9 +107,16 @@ def main():

sns.set_style("whitegrid")

if args.target_n:
pred_n = parse_size(args.target_n)
pred_d = parse_length(args.target_d, pred_n)


num_tasks = len(args.keys)
fig, axes = plt.subplots(num_tasks, 3, figsize=(6 * 3, 4.5 * num_tasks), squeeze=False)

accs = 0

results = "Task Name | Loss Error | Accuracy Error | Stacked Accuracy Error"

for r, task_name in enumerate(args.keys):
Expand Down Expand Up @@ -138,7 +186,7 @@ def main():
ax.scatter(d, y, color=config.color, marker=MARKERS.get(length, "*"), s=50)

predicted_data = predicted_data_by_name[name]
for d, y, y_pred in zip(data["ds"], data["ys"], predicted_data["xs"]):
for d, y, y_pred in zip(data["ds"], data["xs"], predicted_data["xs"]):
rel_error = (y_pred - y) / y
ax.annotate(
f"{prettify(rel_error)}",
Expand Down Expand Up @@ -175,6 +223,8 @@ def main():
ax.set_ylabel("Loss")
ax.set_title(task_name)

step1_coefficients = coefficients

# Step 2

data_by_name = get_step2_data_by_name(
Expand Down Expand Up @@ -309,6 +359,11 @@ def main():
ax.set_ylabel("Task accuracy")
ax.set_title(task_name)

if args.target_n:
predicted_loss = chinchilla_n_d_fit([pred_n, pred_d], step1_coefficients)
predicted_acc = sigmoid(predicted_loss, *coefficients)
print(f"Predicted {task_name} acc for {args.target_n}-{args.target_d}: {predicted_acc:.3f}")

# Stacked plot

ax = axes[r][2]
Expand All @@ -334,14 +389,8 @@ def main():

print(results)

# y_1b_3T = chinchilla_flops_fit([1176832000, 3e12], coefficients)
# print(f"Predicted final loss for 1b-3T: {y_1b_3T:.3f}")
# y_7b_2T = chinchilla_flops_fit([6682316800, 2e12], coefficients)
# print(f"Predicted final loss for 7b-2T: {y_7b_2T:.3f}")
# y_7b_3T = chinchilla_flops_fit([6682316800, 3e12], coefficients)
# print(f"Predicted final loss for 7b-3T: {y_7b_3T:.3f}")
# y_13b_5T = chinchilla_flops_fit([13e9, 5e12], coefficients)
# print(f"Predicted final loss for 13b-5T: {y_13b_5T:.3f}")




if __name__ == "__main__":
Expand Down
15 changes: 8 additions & 7 deletions scripts/scaling/step1.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def fit_step1(data_by_name, y_metric):
if y_metric == "rc_bpb":
p0 = [3.0, 6.0, 0.1, 0.2, 1.0]
bounds = [(0, None), (0, None), (0, None), (0, None), (0, None)]
coefficients = get_coefficients_huber(
coefficients, cov = get_coefficients_huber(
train_nds,
train_ys,
chinchilla_n_d_fit,
Expand All @@ -59,11 +59,12 @@ def fit_step1(data_by_name, y_metric):
bounds=bounds,
max_iter=1000000,
disp=False,
return_cov=True,
)
elif y_metric == "rc_acc":
p0 = [2.0, 2.0, 0.2, 0.2, 1.0]
bounds = [(0, None), (0, None), (0, None), (0, None), (0, None)]
coefficients = get_coefficients_huber(
bounds = [(0, None), (0, None), (0, None), (None, None), (None, None)]
coefficients, cov = get_coefficients_huber(
train_nds,
train_ys,
chinchilla_n_d_negated_fit,
Expand All @@ -72,11 +73,12 @@ def fit_step1(data_by_name, y_metric):
bounds=bounds,
max_iter=1000000,
disp=False,
return_cov=True
)
else:
raise ValueError(f"Unknown y_metric: {y_metric}")

return coefficients
return coefficients, cov


def predict_step1(data_by_name, coefficients, y_metric):
Expand Down Expand Up @@ -206,16 +208,15 @@ def main():
)

# fit the parameters
coefficients = fit_step1(data_by_name, args.y_metric)
coefficients, cov = fit_step1(data_by_name, args.accuracy)

# make predictions
predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_step1(
data_by_name, coefficients, y_metric=args.y_metric
)
results += f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} | {prettify(rel_error)}"

plot_step1(
configs,
plot_step1(configs,
data_by_name,
predicted_data_by_name,
plotted_predicted_data_by_name,
Expand Down
27 changes: 9 additions & 18 deletions scripts/scaling/step2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import seaborn as sns

from olmo.scaling.scaling_laws.fitting_functions import get_coefficients, sigmoid
from olmo.scaling.scaling_laws.fitting_functions import get_coefficients, sigmoid, get_std_errors
from olmo.scaling.scaling_laws.utils import (
get_final_configs,
get_step2_data_by_name,
Expand Down Expand Up @@ -106,21 +106,7 @@ def main():
"ys": [sigmoid(x, *coefficients) for x in xs],
}

# Compute standard errors for prediction
# Compute the Jacobian matrix of partial derivatives with respect to parameters
jacobian = np.zeros((len(plotted_predicted_data["xs"]), len(coefficients)))
for j, x_val in enumerate(plotted_predicted_data["xs"]):
# Partial derivatives
jacobian[j, 0] = 1 / (1 + np.exp(-k * (x_val - x0)))
jacobian[j, 1] = a * k * np.exp(-k * (x_val - x0)) / (1 + np.exp(-k * (x_val - x0))) ** 2
jacobian[j, 2] = a * (x_val - x0) * np.exp(-k * (x_val - x0)) / (1 + np.exp(-k * (x_val - x0))) ** 2
jacobian[j, 3] = 1

# Compute standard errors for predictions
intermediate = np.sum(jacobian @ cov @ jacobian.T, axis=1)
# TODO: DANGER, this approximation may be bad.
std_errors = np.sqrt(intermediate.clip(min=0.0))
# std_errors = np.sqrt(np.abs(intermediate))
std_errors = get_std_errors(plotted_predicted_data["xs"], plotted_predicted_data["ys"], coefficients, cov)

# Compute prediction intervals
plotted_y_lower = plotted_predicted_data["ys"] - 1.96 * std_errors
Expand All @@ -144,6 +130,11 @@ def main():
)
for x, y, y_pred in zip(data["xs"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y
std_error = get_std_errors([x], [y_pred], coefficients, cov)[0]
y_lower = y_pred - 1.96 * std_error
y_upper = y_pred + 1.96 * std_error
rel_error_lower = (y_lower - y) / y
rel_error_upper = (y_upper - y) / y

if config.mode == "train":
unsigned_rel_errs.append(abs(rel_error))
Expand All @@ -159,7 +150,7 @@ def main():
color=config.color,
)
results += (
f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} | {prettify(rel_error)}"
f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} +/- {prettify(1.96 * std_error, False)} | {prettify(rel_error)}"
)
avg_unsigned_rel_err = np.mean(unsigned_rel_errs)

Expand All @@ -174,7 +165,7 @@ def main():

ax.fill_between(
plotted_predicted_data["xs"], plotted_y_lower, plotted_y_upper, color="pink", alpha=0.3
) # , label="95% Prediction Interval")
)

ax.legend(loc="lower right", ncols=1, fontsize=8)
ax.set_xlabel("Task loss")
Expand Down

0 comments on commit 9b80ad9

Please sign in to comment.