diff --git a/scripts/scaling/predict.py b/scripts/scaling/predict.py
index 4d7597c3d..5aac0a81a 100644
--- a/scripts/scaling/predict.py
+++ b/scripts/scaling/predict.py
@@ -1,13 +1,15 @@
-# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 6887575552 -d 3945065873408 -t 7B-4T --skip_perc 0.1 --moving_avg 5
-# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 13202396160 -d 5000088518656 -t 13B-5T --skip_perc 0.1 --moving_avg 5
-# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 6887575552 -d 3945065873408 -t 7B-4T --skip_perc 0.1 --moving_avg 5 --x_metric c4
-# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -n 13202396160 -d 5000088518656 -t 13B-5T --skip_perc 0.1 --moving_avg 5 --x_metric c4
-# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 6887575552 -d 3945065873408 -t 7B-4T --moving_avg 5
-# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -y mc_acc -n 13202396160 -d 5000088518656 -t 13B-5T --moving_avg 5
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -o figure/peteish-moreeval/chained_main.pdf -n 6887575552 -d 3945065873408 -t 7B-4T --skip_perc 0.1 --moving_avg 5
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -o figure/peteish-moreeval/chained_main.pdf -n 13202396160 -d 5000088518656 -t 13B-5T --skip_perc 0.1 --moving_avg 5
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -o figure/peteish-moreeval/chained_c4_main.pdf -n 6887575552 -d 3945065873408 -t 7B-4T --skip_perc 0.1 --moving_avg 5 --x_metric c4
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2.json -o figure/peteish-moreeval/chained_c4_main.pdf -n 13202396160 -d 5000088518656 -t 13B-5T --skip_perc 0.1 --moving_avg 5 --x_metric c4
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -o figure/peteish-moreeval/chained_mc_main.pdf -y mc_acc -n 6887575552 -d 3945065873408 -t 7B-4T --moving_avg 5
+# python scripts/scaling/predict.py -k v2_main -c scripts/scaling/final.json --step2-config-path scripts/scaling/step2_mc.json -o figure/peteish-moreeval/chained_mc_main.pdf -y mc_acc -n 13202396160 -d 5000088518656 -t 13B-5T --moving_avg 5
 
 import argparse
 
 import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
 from step1 import fit_step1
 from step2 import fit_step2
 
@@ -17,8 +19,12 @@
     get_step1_data_by_name,
     get_step2_data_by_name,
     get_task_sets,
+    tasks,
 )
 
+MARKERS = ["s", "P", "p", "*", "o"]
+FONTSIZE = 9
+
 
 def parse_args():
     parser = argparse.ArgumentParser()
@@ -38,6 +44,7 @@ def parse_args():
     )
     parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file")
     parser.add_argument("--step2-config-path", type=str, default=None, help="Path to config file for step2")
+    parser.add_argument("-o", "--output-path", type=str, required=True, help="Path to write output figure")
     parser.add_argument("-n", "--n", type=int, required=True, help="Model size of the target model")
     parser.add_argument("-d", "--d", type=int, required=True, help="Data size of the target model")
     parser.add_argument(
@@ -51,6 +58,116 @@ def parse_args():
     return args
 
 
+def predict_chained(data_by_name, step1_coefficients, step2_coefficients):
+    predicted_data_by_name = {}
+    plotted_predicted_data_by_name = {}
+
+    dmin = 0.8 * min([min(data["ds"]) for data in data_by_name.values()])
+    dmax = 1.5 * max([max(data["ds"]) for data in data_by_name.values()])
+
+    for name, data in data_by_name.items():
+        predicted_data_by_name[name] = {
+            "ds": data["ds"],
+            "ys": [sigmoid(chinchilla_n_d_fit([n, d], step1_coefficients), *step2_coefficients) for n, d in zip(data["ns"], data["ds"])],
+        }
+        ds = np.exp(np.linspace(np.log(dmin), np.log(dmax), 100))
+        ns = [data["ns"][0]] * len(ds)
+        plotted_predicted_data_by_name[name] = {
+            "ds": ds,
+            "ys": [sigmoid(chinchilla_n_d_fit([n, d], step1_coefficients), *step2_coefficients) for n, d in zip(ns, ds)],
+        }
+
+        if data["mode"] == "eval":
+            predicted_data = predicted_data_by_name[name]
+            for d, y, y_pred in zip(data["ds"], data["ys"], predicted_data["ys"]):
+                rel_error = (y_pred - y) / y
+
+    return predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error)
+
+
+def str_chained_fit(step1_coefficients, step2_coefficients):
+    a, b, alpha, beta, E = step1_coefficients
+    A, B = np.exp(a), np.exp(b)
+    a, x0, k, b = step2_coefficients
+    return (
+        f"L(N, D) = {A:.2f} / N^{alpha:.2f} + {B:.2f} / D^{beta:.2f} + {E:.2f}; Acc(L) = {a:.2f} / (1 + e^(-{k:.2f}(L - {x0:.2f}))) + {b:.2f}"
+    )
+
+
+def plot_chained(
+    configs,
+    data_by_name,
+    predicted_data_by_name,
+    plotted_predicted_data_by_name,
+    task_name,
+    fit_str,
+    ax=plt.gca(),
+):
+    # plot the fitted curve
+    for name, data in plotted_predicted_data_by_name.items():
+        config = configs[name]
+        ax.plot(
+            data["ds"],
+            data["ys"],
+            color=config.color,
+            linestyle="--",
+            alpha=0.7,
+            linewidth=1.5,
+            label=f'{config.label} (fitted)' if config.mode == "train" else None,
+        )
+
+    # plot the actual and predicted data
+    num_eval_annotation = 0
+    for name, data in data_by_name.items():
+        config = configs[name]
+        predicted_data = predicted_data_by_name[name]
+
+        for i, (d, y) in enumerate(zip(data["ds"], data["ys"])):
+            ax.scatter(
+                d,
+                y,
+                color=config.color,
+                marker=MARKERS[i] if config.mode == "train" else "o",
+                s=50 if config.mode == "train" else 20,
+                label=f"{config.label} (target)" if config.mode == "eval" else None,
+            )
+
+        for d, y, y_pred in zip(data["ds"], data["ys"], predicted_data["ys"]):
+            rel_error = (y_pred - y) / y
+            if config.mode == "train":
+                pass
+            else:
+                ax.scatter(
+                    d,
+                    y_pred,
+                    color=config.color,
+                    marker="x",
+                    s=20,
+                    label=f"{config.label} (predicted)",
+                )
+                ax.annotate(
+                    f"{abs(rel_error * 100):.1f}%",
+                    (d, y_pred),
+                    textcoords="offset points",
+                    xytext=(10, -5 + 10*num_eval_annotation),
+                    ha="left",
+                    va="bottom",
+                    fontsize=FONTSIZE,
+                    color=config.color,
+                )
+                num_eval_annotation += 1
+
+    ax.set_xscale("log")
+    ax.legend(loc="upper right", ncols=1, fontsize=FONTSIZE)
+    ax.set_xlabel("Tokens (D)", fontsize=FONTSIZE)
+    ax.set_ylabel("Task RC accuracy", fontsize=FONTSIZE)
+    ax.set_title(
+        f"{tasks[task_name].display_name}",
+        fontsize=FONTSIZE,
+        fontweight="bold",
+    )
+
+
 def main():
     args = parse_args()
     configs = get_final_configs(args.config_path)
@@ -59,16 +176,18 @@ def main():
     else:
         step2_configs = configs
 
+    sns.set_style("whitegrid")
+    num_tasks = len(args.keys)
+    num_cols = min(4, num_tasks)
+    num_rows = (num_tasks + num_cols - 1) // num_cols
+    fig, axes = plt.subplots(num_rows, num_cols, figsize=(2.75 * num_cols, 2.25 * num_rows), squeeze=False)
+
     results = "Task Name | Prediction | Actual | Rel Error"
 
     for r, task_name in enumerate(args.keys):
-        # Step 1
         step1_data_by_name = get_step1_data_by_name(
             configs, task_name, y_metric=args.x_metric, moving_avg=args.moving_avg
         )
-        step1_coefficients, _ = fit_step1(step1_data_by_name, y_metric=args.x_metric)
-
-        # Step 2
         step2_data_by_name = get_step2_data_by_name(
             step2_configs,
             task_name,
@@ -77,8 +196,29 @@ def main():
             moving_avg=args.moving_avg,
             skip_perc=args.skip_perc,
         )
+        single_step_data_by_name = get_step1_data_by_name(
+            configs, task_name, y_metric="rc_acc", moving_avg=args.moving_avg
+        )
+
+        # fit the parameters
+        step1_coefficients, _ = fit_step1(step1_data_by_name, y_metric=args.x_metric)
         step2_coefficients, _ = fit_step2(step2_data_by_name, task_name, args.y_metric, args.use_log_sigmoid)
 
+        # make predictions
+        predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_chained(
+            single_step_data_by_name, step1_coefficients, step2_coefficients
+        )
+
+        plot_chained(
+            configs,
+            single_step_data_by_name,
+            predicted_data_by_name,
+            plotted_predicted_data_by_name,
+            task_name,
+            str_chained_fit(step1_coefficients, step2_coefficients),
+            axes[r // num_cols][r % num_cols],
+        )
+
         # make predictions
         pred_loss = chinchilla_n_d_fit([args.n, args.d], step1_coefficients)
         fit_fn = log_sigmoid if args.use_log_sigmoid else sigmoid
@@ -91,6 +231,26 @@ def main():
         else:
             results += f"\n{task_name} | {pred_acc * 100:.1f} | - | -"
 
+    handles, labels = axes[-1][-1].get_legend_handles_labels()
+    # delete x-axis labels for all but the bottom row
+    for i in range(num_cols):
+        for j in range(num_rows):
+            if j != num_rows - 1:
+                axes[j][i].set_xlabel("")
+            if i != 0:
+                axes[j][i].set_ylabel("")
+
+            axes[j][i].legend().remove()
+
+    fig.tight_layout(w_pad=0.01)
+    legend = fig.legend(handles, labels, loc='upper center',
+                        ncol=10, fontsize=FONTSIZE, bbox_to_anchor=(0.5, 1.07),
+                        handletextpad=0.3, columnspacing=0.7)
+    for handle in legend.legend_handles:
+        handle.set_alpha(1.0)
+
+    fig.savefig(args.output_path, dpi=300, bbox_inches='tight')
+
     print(results)
 
 
diff --git a/scripts/scaling/single_step.py b/scripts/scaling/single_step.py
index 42ec71bd9..5315a2b5e 100644
--- a/scripts/scaling/single_step.py
+++ b/scripts/scaling/single_step.py
@@ -36,7 +36,7 @@ def parse_args():
     return args
 
 
-def fit_step12(data_by_name, task_name):
+def fit_single_step(data_by_name, task_name):
    train_nds, train_ys = [], []
    for name, data in data_by_name.items():
        if data["mode"] == "train":
@@ -59,7 +59,7 @@ def fit_step12(data_by_name, task_name):
     return coefficients
 
 
-def predict_step12(data_by_name, coefficients):
+def predict_single_step(data_by_name, coefficients):
     predicted_data_by_name = {}
     plotted_predicted_data_by_name = {}
 
@@ -94,7 +94,7 @@ def str_combined_fit(coefficients):
     )
 
 
-def plot_step12(
+def plot_single_step(
     configs,
     data_by_name,
     predicted_data_by_name,
@@ -162,7 +162,7 @@ def plot_step12(
     ax.set_xscale("log")
     ax.legend(loc="upper right", ncols=1, fontsize=FONTSIZE)
     ax.set_xlabel("Tokens (D)", fontsize=FONTSIZE)
-    ax.set_ylabel("Task accuracy", fontsize=FONTSIZE)
+    ax.set_ylabel("Task RC accuracy", fontsize=FONTSIZE)
     ax.set_title(
         f"{tasks[task_name].display_name} ({avg_unsigned_rel_error * 100:.2f}%)",
         fontsize=FONTSIZE,
@@ -186,15 +186,15 @@ def main():
         data_by_name = get_step1_data_by_name(configs, task_name, y_metric="rc_acc", moving_avg=args.moving_avg)
 
         # fit the parameters
-        coefficients = fit_step12(data_by_name, task_name)
+        coefficients = fit_single_step(data_by_name, task_name)
 
         # make predictions
-        predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_step12(
+        predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_single_step(
             data_by_name, coefficients
         )
         results += f"\n{task_name} | {prettify(y, False)} | {prettify(y_pred, False)} | {prettify(rel_error)}"
 
-        plot_step12(
+        plot_single_step(
             configs,
             data_by_name,
             predicted_data_by_name,
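
Note on the new chained prediction: predict_chained() simply composes the step-1 loss fit with the step-2 accuracy fit, i.e. Acc(N, D) = sigmoid(L(N, D)). The snippet below is a minimal, self-contained sketch of that composition using the functional forms printed by str_chained_fit(); the local chinchilla_n_d_fit() and sigmoid() definitions and the coefficient values are illustrative stand-ins, not the repo's fitting utilities or fitted results.

    import numpy as np

    def chinchilla_n_d_fit(nd, coefficients):
        # Step 1 (assumed form): task loss from model size N and token count D,
        # L(N, D) = A / N^alpha + B / D^beta + E, with A = e^a and B = e^b.
        n, d = nd
        a, b, alpha, beta, E = coefficients
        return np.exp(a) / n**alpha + np.exp(b) / d**beta + E

    def sigmoid(x, a, x0, k, b):
        # Step 2 (assumed form): task accuracy from task loss,
        # Acc(L) = a / (1 + e^(-k(L - x0))) + b.
        return a / (1 + np.exp(-k * (x - x0))) + b

    # Placeholder coefficients in the layout the script unpacks
    # (step 1: a, b, alpha, beta, E; step 2: a, x0, k, b) -- not fitted values.
    step1_coefficients = [5.0, 10.0, 0.3, 0.3, 0.5]
    step2_coefficients = [-0.7, 2.0, 2.0, 0.95]

    n, d = 6887575552, 3945065873408  # the 7B-4T target from the usage comments
    pred_loss = chinchilla_n_d_fit([n, d], step1_coefficients)
    pred_acc = sigmoid(pred_loss, *step2_coefficients)
    print(f"predicted loss {pred_loss:.3f} -> predicted accuracy {pred_acc:.3f}")

In the script itself the coefficients come from fit_step1()/fit_step2(), and the same composition produces the pred_acc reported for the target (args.n, args.d) in the results table, with log_sigmoid substituted for sigmoid when --use_log_sigmoid is set.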