From 446107db572272baa33f867ee867fb26722d16db Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 11 Nov 2024 14:47:36 -0800 Subject: [PATCH] results for multiple tasks --- olmo/scaling/scaling_laws/utils.py | 55 ++++++++ scripts/scaling/final_flops.py | 203 +++++++++++++++++------------ 2 files changed, 177 insertions(+), 81 deletions(-) diff --git a/olmo/scaling/scaling_laws/utils.py b/olmo/scaling/scaling_laws/utils.py index 97d2cbc33..eed3eb2c9 100644 --- a/olmo/scaling/scaling_laws/utils.py +++ b/olmo/scaling/scaling_laws/utils.py @@ -120,6 +120,13 @@ class DownstreamTaskPrediction: task_minimum: float = 0.25 task_maximum: float = 1.0 + def get_loss_keys(self): + return self.task_loss_key if isinstance(self.task_loss_key, list) else [self.task_loss_key] + + def get_accuracy_keys(self): + return self.task_accuracy_key if isinstance(self.task_accuracy_key, list) else [self.task_accuracy_key] + + downstream_5_shot: Dict[str, DownstreamTaskPrediction] = { f"{key}_rc_5shot": DownstreamTaskPrediction( @@ -294,6 +301,13 @@ def get_accuracy_keys(tasks: Dict[str, DownstreamTaskPrediction]) -> List[str]: KEYS_BY_KEY[task_name] = task.task_loss_key if isinstance(task.task_loss_key, list) else [task.task_loss_key] +def prettify(rel_error, is_percentage=True): + if is_percentage: + return f"{rel_error * 100:+.1f}%" + else: + return f"{rel_error:.2f}" + + def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -420,6 +434,42 @@ def get_flops_data_by_name(configs, keys, num_to_avg=1): return data_by_name + +def get_downstream_data_by_name(configs, keys, num_to_avg=-1): + # TODO: weight_by_key may not be working correctly for mmlu + loss_keys = tasks[keys].get_loss_keys() + accuracy_keys = tasks[keys].get_accuracy_keys() + data_by_name: Dict = defaultdict(lambda: {"xs": [], "ys": []}) + + for name, config in configs.items(): + n = config.n + for path in config.paths: + with open(path) as file_ref: + reader = csv.DictReader(file_ref) + rows = [row for row in reader] + rows = rows[-20:] + xs, ys = [], [] + for row in rows: + x = np.average( + [float(row[key]) for key in loss_keys], weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in loss_keys] + ) + y = np.average( + [float(row[key]) for key in accuracy_keys], weights=[WEIGHT_BY_KEY.get(key, 1.0) for key in accuracy_keys] + ) + xs.append(x) + ys.append(y) + # x = np.mean(xs) + # y = np.mean(ys) + # data_by_name[name]["xs"].append(x) + # data_by_name[name]["ys"].append(y) + + data_by_name[name]["xs"] += xs + data_by_name[name]["ys"] += ys + + return data_by_name + + + def get_ax(name): if "1xC" in name: return 0 @@ -694,6 +744,11 @@ def grad_tissue_fit(x, p): return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_F, grad_r] +def sigmoid(x, L, x0, k, b): + o = L / (1 + np.exp(-k * (x - x0))) + b + return o + + # Scipy minimize w/ Huber loss def get_coefficients_huber( train_xs, train_ys, fitting_func, grad_func, p0, bounds, disp: bool = True, max_iter: int = 10000 diff --git a/scripts/scaling/final_flops.py b/scripts/scaling/final_flops.py index 09d9951cd..bad34ac9a 100644 --- a/scripts/scaling/final_flops.py +++ b/scripts/scaling/final_flops.py @@ -1,3 +1,4 @@ +import argparse import json import matplotlib.pyplot as plt @@ -13,11 +14,28 @@ get_coefficients, grad_chinchilla_flops_fit, parse_args, + tasks, + prettify, ) MARKERS = ["s", "P", "p", "*"] +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-k", "--keys", nargs="+", default=[], help="For avg metrics. Use one of [all-val-lm, all-bpb]" + ) + parser.add_argument( + "--num_to_avg", type=int, default=1, help="Number of final ckpts to average (for final loss fitting)" + ) + parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file") + parser.add_argument("-o", "--output-path", type=str, required=True, help="Path to write output figure") + args = parser.parse_args() + + return args + + def main(): args = parse_args() @@ -25,94 +43,117 @@ def main(): configs = json.load(f) configs = {name: FinalConfig(**config) for name, config in configs.items()} - data_by_name = get_flops_data_by_name(configs, args.keys, num_to_avg=args.num_to_avg) + if len(args.keys) == 1 and args.keys[0] == "all": + args.keys = tasks.keys() sns.set_style("whitegrid") - plt.figure(figsize=(6, 4.5)) - - train_fs, train_ys = [], [] - for name, data in data_by_name.items(): - config = configs[name] - if config.mode == "train": - train_fs += data["fs"] - train_ys += data["ys"] - - # fit the parameters - # coefficients = get_coefficients_huber( - # train_fs, - # train_ys, - # chinchilla_flops_fit, - # grad_chinchilla_flops_fit, - # p0=[-3.0, 0.09, 0.1], - # bounds=[(None, None), (None, None), (None, None)], - # max_iter=100000, - # ) - - coefficients = get_coefficients(train_fs, train_ys, chinchilla_fit, p0=[-3.0, 0.09, 0.1]) - - a, b, E = coefficients - - # make predictions - predicted_data_by_name = {} - plotted_predicted_data_by_name = {} - for name, data in data_by_name.items(): - config = configs[name] - predicted_data_by_name[name] = { - "fs": data["fs"], - "ys": [chinchilla_flops_fit(flops, coefficients) for flops in data["fs"]], - } - fs = np.linspace(min(data["fs"]), max(data["fs"]), 100) - plotted_predicted_data_by_name[name] = { - "fs": fs, - "ys": [chinchilla_flops_fit(flops, coefficients) for flops in fs], - } - - # plot the actual data - for name, data in data_by_name.items(): - config = configs[name] - # plt.scatter(data["ds"], data["ys"], color="white", edgecolors=config.color, label=config.label, s=10) - for i, (f, y) in enumerate(zip(data["fs"], data["ys"])): - plt.scatter(f, y, color=config.color, marker=MARKERS[i], s=50) - - predicted_data = predicted_data_by_name[name] - for f, y, y_pred in zip(data["fs"], data["ys"], predicted_data["ys"]): - rel_error = (y_pred - y) / y - plt.annotate( - f"{rel_error * 100:+.1f}%", - (f, y), - textcoords="offset points", - xytext=(6, 6), - ha="center", - fontsize=8, + num_tasks = len(args.keys) + fig, axes = plt.subplots(num_tasks, 1, figsize=(6, 4.5 * num_tasks), squeeze=False) + + results = " Task Name | Actual Value | Predicted Value | Relative Error" + + for i, task_name in enumerate(args.keys): + task = tasks[task_name] + + data_by_name = get_flops_data_by_name(configs, task.get_loss_keys(), num_to_avg=args.num_to_avg) + + train_fs, train_ys = [], [] + for name, data in data_by_name.items(): + config = configs[name] + if config.mode == "train": + train_fs += data["fs"] + train_ys += data["ys"] + + # fit the parameters + + # TODO: why does huber_loss fit not converge? + # coefficients = get_coefficients_huber( + # train_fs, + # train_ys, + # chinchilla_flops_fit, + # grad_chinchilla_flops_fit, + # p0=[-3.0, 0.09, 0.1], + # bounds=[(None, None), (None, None), (None, None)], + # max_iter=10000, + # ) + + # TODO: b always 0? + coefficients = get_coefficients(train_fs, train_ys, chinchilla_fit, p0=[-3.0, 0.09, 0.1]) + + a, b, E = coefficients + + # make predictions + predicted_data_by_name = {} + plotted_predicted_data_by_name = {} + for name, data in data_by_name.items(): + config = configs[name] + predicted_data_by_name[name] = { + "fs": data["fs"], + "ys": [chinchilla_fit(flops, *coefficients) for flops in data["fs"]], + } + fs = np.linspace(min(data["fs"]), max(data["fs"]), 100) + plotted_predicted_data_by_name[name] = { + "fs": fs, + "ys": [chinchilla_fit(flops, *coefficients) for flops in fs], + } + + + ax = axes[i][0] + + # plot the actual data + for name, data in data_by_name.items(): + config = configs[name] + # plt.scatter(data["ds"], data["ys"], color="white", edgecolors=config.color, label=config.label, s=10) + for i, (f, y) in enumerate(zip(data["fs"], data["ys"])): + ax.scatter(f, y, color=config.color, marker=MARKERS[i], s=50) + + predicted_data = predicted_data_by_name[name] + for f, y, y_pred in zip(data["fs"], data["ys"], predicted_data["ys"]): + rel_error = (y_pred - y) / y + ax.annotate( + f"{rel_error * 100:+.1f}%", + (f, y), + textcoords="offset points", + xytext=(6, 6), + ha="center", + fontsize=8, + color=config.color, + ) + + if config.mode == "eval": + results += f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} | {prettify(rel_error)}" + + # plot the fitted curve + for name, data in plotted_predicted_data_by_name.items(): + config = configs[name] + ax.plot( + data["fs"], + data["ys"], color=config.color, + linestyle="--", + linewidth=2.0, + label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})', ) - - # plot the fitted curve - for name, data in plotted_predicted_data_by_name.items(): - config = configs[name] - plt.plot( - data["fs"], - data["ys"], - color=config.color, - linestyle="--", - linewidth=2.0, - label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})', + ax.text( + x=0.20, + y=0.25, + s=f"L(F) = {a:.2f} F ^ {b:.2f} + {E:.2f}", + fontsize=10, + transform=ax.transAxes, ) - plt.text( - x=0.20, - y=0.55, - s=f"L(F) = {a:.2f} F ^ {b:.2f} + {E:.2f}", - fontsize=10, - transform=plt.gca().transAxes, - ) - plt.xscale("log") - plt.legend(loc="upper right", ncols=1, fontsize=10) - plt.xlabel("Flops (F)") - plt.ylabel("Loss") - plt.title(args.key) - plt.savefig(args.output_path, dpi=300, bbox_inches="tight") + ax.set_xscale("log") + ax.legend(loc="upper right", ncols=1, fontsize=10) + ax.set_xlabel("Flops (F)") + ax.set_ylabel("Loss") + ax.set_title(task_name) + + fig.tight_layout() + fig.subplots_adjust(top=0.95) + fig.savefig(args.output_path, dpi=300) + + print(results) # y_1b_3T = chinchilla_flops_fit([1176832000, 3e12], coefficients) # print(f"Predicted final loss for 1b-3T: {y_1b_3T:.3f}")