Skip to content

Commit

Permalink
results for multiple tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
AkshitaB committed Nov 11, 2024
1 parent caad5c3 commit 446107d
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 81 deletions.
55 changes: 55 additions & 0 deletions olmo/scaling/scaling_laws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ class DownstreamTaskPrediction:
task_minimum: float = 0.25
task_maximum: float = 1.0

def get_loss_keys(self):
    """Return this task's loss key(s) as a list.

    ``task_loss_key`` may hold either a single key or a list of keys; a list
    is returned unchanged, anything else is wrapped in a one-element list.
    """
    loss_key = self.task_loss_key
    if isinstance(loss_key, list):
        return loss_key
    return [loss_key]

def get_accuracy_keys(self):
    """Return this task's accuracy key(s) as a list.

    ``task_accuracy_key`` may hold either a single key or a list of keys; a
    list is returned unchanged, anything else is wrapped in a one-element list.
    """
    accuracy_key = self.task_accuracy_key
    if isinstance(accuracy_key, list):
        return accuracy_key
    return [accuracy_key]



downstream_5_shot: Dict[str, DownstreamTaskPrediction] = {
f"{key}_rc_5shot": DownstreamTaskPrediction(
Expand Down Expand Up @@ -294,6 +301,13 @@ def get_accuracy_keys(tasks: Dict[str, DownstreamTaskPrediction]) -> List[str]:
KEYS_BY_KEY[task_name] = task.task_loss_key if isinstance(task.task_loss_key, list) else [task.task_loss_key]


def prettify(rel_error, is_percentage=True):
    """Format a number for the results table.

    Args:
        rel_error: The value to format.
        is_percentage: When true, the value is scaled by 100 and rendered as a
            signed percentage with one decimal place; otherwise it is rendered
            as a plain number with two decimal places.

    Returns:
        The formatted string.
    """
    if not is_percentage:
        return f"{rel_error:.2f}"
    return f"{rel_error * 100:+.1f}%"


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -420,6 +434,42 @@ def get_flops_data_by_name(configs, keys, num_to_avg=1):
return data_by_name



def get_downstream_data_by_name(configs, keys, num_to_avg=-1):
    """Collect (task loss, task accuracy) points for every configured run.

    For each CSV file listed in a run's config, the last 20 rows are read and
    each row is reduced to one point: the weighted average of the task's loss
    columns (x) and the weighted average of its accuracy columns (y). Weights
    come from ``WEIGHT_BY_KEY`` and default to 1.0 for keys not listed there.

    Args:
        configs: Mapping of run name to a config object exposing ``paths``,
            an iterable of CSV file paths.
        keys: Name of the task to look up in the module-level ``tasks``
            mapping (a single task name, despite the plural parameter name).
        num_to_avg: Currently ignored; kept so existing callers keep working.

    Returns:
        A ``defaultdict`` mapping run name to ``{"xs": [...], "ys": [...]}``
        with one entry per retained CSV row, accumulated across all paths.
    """
    # TODO: weight_by_key may not be working correctly for mmlu
    task = tasks[keys]
    loss_keys = task.get_loss_keys()
    accuracy_keys = task.get_accuracy_keys()

    # The weights depend only on the key lists, so build them once instead of
    # once per CSV row.
    loss_weights = [WEIGHT_BY_KEY.get(key, 1.0) for key in loss_keys]
    accuracy_weights = [WEIGHT_BY_KEY.get(key, 1.0) for key in accuracy_keys]

    data_by_name: Dict = defaultdict(lambda: {"xs": [], "ys": []})

    for name, config in configs.items():
        for path in config.paths:
            with open(path) as file_ref:
                rows = list(csv.DictReader(file_ref))

            # Only the final rows of each results file are used.
            rows = rows[-20:]

            xs = [np.average([float(row[key]) for key in loss_keys], weights=loss_weights) for row in rows]
            ys = [np.average([float(row[key]) for key in accuracy_keys], weights=accuracy_weights) for row in rows]

            data_by_name[name]["xs"] += xs
            data_by_name[name]["ys"] += ys

    return data_by_name



def get_ax(name):
if "1xC" in name:
return 0
Expand Down Expand Up @@ -694,6 +744,11 @@ def grad_tissue_fit(x, p):
return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_F, grad_r]


def sigmoid(x, L, x0, k, b):
    """Evaluate the shifted, scaled logistic curve ``L / (1 + exp(-k * (x - x0))) + b``.

    Args:
        x: Input value(s), passed through ``np.exp``.
        L: Scale applied to the logistic term.
        x0: Horizontal shift of the curve's midpoint.
        k: Steepness of the curve.
        b: Vertical offset added to the result.

    Returns:
        The curve evaluated at ``x``.
    """
    exponent = -k * (x - x0)
    denominator = 1 + np.exp(exponent)
    return L / denominator + b


# Scipy minimize w/ Huber loss
def get_coefficients_huber(
train_xs, train_ys, fitting_func, grad_func, p0, bounds, disp: bool = True, max_iter: int = 10000
Expand Down
203 changes: 122 additions & 81 deletions scripts/scaling/final_flops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import json

import matplotlib.pyplot as plt
Expand All @@ -13,106 +14,146 @@
get_coefficients,
grad_chinchilla_flops_fit,
parse_args,
tasks,
prettify,
)

MARKERS = ["s", "P", "p", "*"]


def parse_args():
    """Parse the command-line options of the final-FLOPs fitting script.

    Options:
        -k / --keys: One or more keys (defaults to an empty list).
        --num_to_avg: Integer number of final checkpoints to average
            (defaults to 1).
        -c / --config-path: Required path to the config file.
        -o / --output-path: Required path for the output figure.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # Each entry is (flags, keyword options); order is preserved so the
    # generated help text lists the options in the same sequence.
    option_specs = [
        (
            ("-k", "--keys"),
            {"nargs": "+", "default": [], "help": "For avg metrics. Use one of [all-val-lm, all-bpb]"},
        ),
        (
            ("--num_to_avg",),
            {"type": int, "default": 1, "help": "Number of final ckpts to average (for final loss fitting)"},
        ),
        (
            ("-c", "--config-path"),
            {"type": str, "required": True, "help": "Path to config file"},
        ),
        (
            ("-o", "--output-path"),
            {"type": str, "required": True, "help": "Path to write output figure"},
        ),
    ]

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()


def main():
args = parse_args()

with open(args.config_path) as f:
configs = json.load(f)
configs = {name: FinalConfig(**config) for name, config in configs.items()}

data_by_name = get_flops_data_by_name(configs, args.keys, num_to_avg=args.num_to_avg)
if len(args.keys) == 1 and args.keys[0] == "all":
args.keys = tasks.keys()

sns.set_style("whitegrid")

plt.figure(figsize=(6, 4.5))

train_fs, train_ys = [], []
for name, data in data_by_name.items():
config = configs[name]
if config.mode == "train":
train_fs += data["fs"]
train_ys += data["ys"]

# fit the parameters
# coefficients = get_coefficients_huber(
# train_fs,
# train_ys,
# chinchilla_flops_fit,
# grad_chinchilla_flops_fit,
# p0=[-3.0, 0.09, 0.1],
# bounds=[(None, None), (None, None), (None, None)],
# max_iter=100000,
# )

coefficients = get_coefficients(train_fs, train_ys, chinchilla_fit, p0=[-3.0, 0.09, 0.1])

a, b, E = coefficients

# make predictions
predicted_data_by_name = {}
plotted_predicted_data_by_name = {}
for name, data in data_by_name.items():
config = configs[name]
predicted_data_by_name[name] = {
"fs": data["fs"],
"ys": [chinchilla_flops_fit(flops, coefficients) for flops in data["fs"]],
}
fs = np.linspace(min(data["fs"]), max(data["fs"]), 100)
plotted_predicted_data_by_name[name] = {
"fs": fs,
"ys": [chinchilla_flops_fit(flops, coefficients) for flops in fs],
}

# plot the actual data
for name, data in data_by_name.items():
config = configs[name]
# plt.scatter(data["ds"], data["ys"], color="white", edgecolors=config.color, label=config.label, s=10)
for i, (f, y) in enumerate(zip(data["fs"], data["ys"])):
plt.scatter(f, y, color=config.color, marker=MARKERS[i], s=50)

predicted_data = predicted_data_by_name[name]
for f, y, y_pred in zip(data["fs"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y
plt.annotate(
f"{rel_error * 100:+.1f}%",
(f, y),
textcoords="offset points",
xytext=(6, 6),
ha="center",
fontsize=8,
num_tasks = len(args.keys)
fig, axes = plt.subplots(num_tasks, 1, figsize=(6, 4.5 * num_tasks), squeeze=False)

results = " Task Name | Actual Value | Predicted Value | Relative Error"

for i, task_name in enumerate(args.keys):
task = tasks[task_name]

data_by_name = get_flops_data_by_name(configs, task.get_loss_keys(), num_to_avg=args.num_to_avg)

train_fs, train_ys = [], []
for name, data in data_by_name.items():
config = configs[name]
if config.mode == "train":
train_fs += data["fs"]
train_ys += data["ys"]

# fit the parameters

# TODO: why does huber_loss fit not converge?
# coefficients = get_coefficients_huber(
# train_fs,
# train_ys,
# chinchilla_flops_fit,
# grad_chinchilla_flops_fit,
# p0=[-3.0, 0.09, 0.1],
# bounds=[(None, None), (None, None), (None, None)],
# max_iter=10000,
# )

# TODO: b always 0?
coefficients = get_coefficients(train_fs, train_ys, chinchilla_fit, p0=[-3.0, 0.09, 0.1])

a, b, E = coefficients

# make predictions
predicted_data_by_name = {}
plotted_predicted_data_by_name = {}
for name, data in data_by_name.items():
config = configs[name]
predicted_data_by_name[name] = {
"fs": data["fs"],
"ys": [chinchilla_fit(flops, *coefficients) for flops in data["fs"]],
}
fs = np.linspace(min(data["fs"]), max(data["fs"]), 100)
plotted_predicted_data_by_name[name] = {
"fs": fs,
"ys": [chinchilla_fit(flops, *coefficients) for flops in fs],
}


ax = axes[i][0]

# plot the actual data
for name, data in data_by_name.items():
config = configs[name]
# plt.scatter(data["ds"], data["ys"], color="white", edgecolors=config.color, label=config.label, s=10)
for i, (f, y) in enumerate(zip(data["fs"], data["ys"])):
ax.scatter(f, y, color=config.color, marker=MARKERS[i], s=50)

predicted_data = predicted_data_by_name[name]
for f, y, y_pred in zip(data["fs"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y
ax.annotate(
f"{rel_error * 100:+.1f}%",
(f, y),
textcoords="offset points",
xytext=(6, 6),
ha="center",
fontsize=8,
color=config.color,
)

if config.mode == "eval":
results += f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} | {prettify(rel_error)}"

# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
ax.plot(
data["fs"],
data["ys"],
color=config.color,
linestyle="--",
linewidth=2.0,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)

# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
plt.plot(
data["fs"],
data["ys"],
color=config.color,
linestyle="--",
linewidth=2.0,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
ax.text(
x=0.20,
y=0.25,
s=f"L(F) = {a:.2f} F ^ {b:.2f} + {E:.2f}",
fontsize=10,
transform=ax.transAxes,
)
plt.text(
x=0.20,
y=0.55,
s=f"L(F) = {a:.2f} F ^ {b:.2f} + {E:.2f}",
fontsize=10,
transform=plt.gca().transAxes,
)

plt.xscale("log")
plt.legend(loc="upper right", ncols=1, fontsize=10)
plt.xlabel("Flops (F)")
plt.ylabel("Loss")
plt.title(args.key)
plt.savefig(args.output_path, dpi=300, bbox_inches="tight")
ax.set_xscale("log")
ax.legend(loc="upper right", ncols=1, fontsize=10)
ax.set_xlabel("Flops (F)")
ax.set_ylabel("Loss")
ax.set_title(task_name)

fig.tight_layout()
fig.subplots_adjust(top=0.95)
fig.savefig(args.output_path, dpi=300)

print(results)

# y_1b_3T = chinchilla_flops_fit([1176832000, 3e12], coefficients)
# print(f"Predicted final loss for 1b-3T: {y_1b_3T:.3f}")
Expand Down

0 comments on commit 446107d

Please sign in to comment.