diff --git a/scripts/scaling/step1.py b/scripts/scaling/step1.py index c4fd37fc6..061620e89 100644 --- a/scripts/scaling/step1.py +++ b/scripts/scaling/step1.py @@ -13,6 +13,7 @@ get_coefficients_huber, grad_chinchilla_n_d_fit, grad_chinchilla_n_d_negated_fit, + get_std_errors, ) from olmo.scaling.scaling_laws.utils import ( get_final_configs, @@ -138,11 +139,29 @@ def plot_step1( task_name, fit_str, y_metric, + coefficients, + cov, ax=plt.gca(), ): # plot the fitted curve for name, data in plotted_predicted_data_by_name.items(): config = configs[name] + + if config.mode == "eval": + std_errors = get_std_errors( + [[config.n, d] for d in data["ds"]], + data["ys"], + coefficients, + cov, + chinchilla_n_d_fit, + grad_chinchilla_n_d_fit, + ) + + # Compute prediction intervals + plotted_y_lower = data["ys"] - 1.96 * std_errors + plotted_y_upper = data["ys"] + 1.96 * std_errors + # ax.fill_between(data["ds"], plotted_y_lower, plotted_y_upper, color="pink", alpha=0.3) + ax.plot( data["ds"], data["ys"], @@ -179,7 +198,7 @@ def plot_step1( color=config.color, marker="o", s=10, - label=f"{config.label} ({'predicted'})", + # label=f"{config.label} ({'predicted'})", ) ax.annotate( f"{prettify(rel_error)}", @@ -254,6 +273,8 @@ def main(): task_name, str_chinchilla_n_d_fit(coefficients), args.y_metric, + coefficients, + cov, axes[i // num_cols][i % num_cols], ) diff --git a/scripts/scaling/step1_flops.py b/scripts/scaling/step1_flops.py index d766577ea..3d1f18271 100644 --- a/scripts/scaling/step1_flops.py +++ b/scripts/scaling/step1_flops.py @@ -7,7 +7,7 @@ import numpy as np import seaborn as sns -from olmo.scaling.scaling_laws.fitting_functions import get_coefficients +from olmo.scaling.scaling_laws.fitting_functions import get_coefficients, chinchilla_flops_fit, grad_chinchilla_flops_fit, get_std_errors from olmo.scaling.scaling_laws.utils import ( get_final_configs, get_step1_data_by_name, @@ -36,7 +36,7 @@ def parse_args(): return args -def chinchilla_flops_fit(x, a, b, E): +def chinchilla_flops(x, a, b, E): # return ax**b + E return a * np.pow(x, b) + E @@ -54,7 +54,7 @@ def fit_step1(data_by_name, y_metric): coefficients, cov = get_coefficients( train_fs, train_ys, - chinchilla_flops_fit, + chinchilla_flops, p0, bounds=bounds, disp=False, @@ -76,9 +76,9 @@ def predict_step1(configs, data_by_name, coefficients, y_metric): fmax = 1.2 * max([max(data["fs"]) for data in data_by_name.values()]) if y_metric == "rc_bpb": - func = chinchilla_flops_fit + func = chinchilla_flops elif y_metric == "rc_acc": - func = chinchilla_flops_fit + func = chinchilla_flops else: raise ValueError(f"Unknown y_metric: {y_metric}") @@ -119,8 +119,34 @@ def plot_step1( task_name, fit_str, y_metric, + coefficients, + cov, ax=plt.gca(), ): + + fmin = min(min(data["fs"]) for data in plotted_predicted_data_by_name.values()) + fmax = max(max(data["fs"]) for data in plotted_predicted_data_by_name.values()) + fs = np.linspace(fmin, fmax, 100) + plotted_predicted_data = { + "fs": fs, + "ys": [chinchilla_flops(f, *coefficients) for f in fs], + } + + std_errors = get_std_errors( + plotted_predicted_data["fs"], + plotted_predicted_data["ys"], + coefficients, + cov, + chinchilla_flops_fit, + grad_chinchilla_flops_fit, + ) + + # Compute prediction intervals + plotted_y_lower = plotted_predicted_data["ys"] - 1.96 * std_errors + plotted_y_upper = plotted_predicted_data["ys"] + 1.96 * std_errors + + # ax.fill_between(plotted_predicted_data["fs"], plotted_y_lower, plotted_y_upper, color="pink", alpha=0.3) + # plot the fitted curve for name, data in plotted_predicted_data_by_name.items(): config = configs[name] @@ -235,6 +261,8 @@ def main(): task_name, str_chinchilla_flops_fit(coefficients), args.y_metric, + coefficients, + cov, axes[i // num_cols][i % num_cols], )