Skip to content

Commit

Permalink
Update single-step prediction
Browse files — browse the repository at this point in the history
  • Loading branch information
liujch1998 committed Nov 21, 2024
1 parent e88233a commit 1233662
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 60 deletions.
12 changes: 5 additions & 7 deletions olmo/scaling/scaling_laws/fitting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,27 +92,25 @@ def grad_chinchilla_n_d_fit(x, p):

# x[0] = n, x[1] = d
# p[0] = a = log(A), p[1] = b = log(B), p[2] = alpha, p[3] = beta, p[4] = E
# p[5] = p, p[6] = L0, p[7] = k, p[8] = q
# p[5] = p, p[6] = q
def combined_fit(x, p):
step1 = np.exp(p[0]) / x[0] ** p[2] + np.exp(p[1]) / x[1] ** p[3] + p[4]
step2 = p[5] / (1 + np.exp(-p[7] * (step1 - p[6]))) + p[8]
step2 = p[5] / (1 + np.exp(-step1)) + p[6]
step2 = max(1e-6, step2)
return step2


def grad_combined_fit(x, p):
step1 = np.exp(p[0]) / x[0] ** p[2] + np.exp(p[1]) / x[1] ** p[3] + p[4]
grad_p = 1 / (1 + np.exp(-p[7] * (step1 - p[6])))
grad_k = p[5] * grad_p * (1 - grad_p) * (step1 - p[6])
grad_L0 = p[5] * grad_p * (1 - grad_p) * (-p[7])
grad_step1 = p[5] * grad_p * (1 - grad_p) * p[7]
grad_p = 1 / (1 + np.exp(-step1))
grad_step1 = p[5] * grad_p * (1 - grad_p)
grad_q = 1
grad_a = grad_step1 * np.exp(p[0]) / x[0] ** p[2]
grad_b = grad_step1 * np.exp(p[1]) / x[1] ** p[3]
grad_alpha = grad_step1 * np.exp(p[0]) * (-np.log(x[0])) / x[0] ** p[2]
grad_beta = grad_step1 * np.exp(p[1]) * (-np.log(x[1])) / x[1] ** p[3]
grad_E = grad_step1 * 1
return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_p, grad_k, grad_L0, grad_q]
return [grad_a, grad_b, grad_alpha, grad_beta, grad_E, grad_p, grad_q]


# x[0] = n, x[1] = d
Expand Down
2 changes: 1 addition & 1 deletion olmo/scaling/scaling_laws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def get_mc_accuracy_keys(self):
"socialiqa",
"winogrande",
]
core_small_names = ["hellaswag", "arc_easy", "arc_challenge", "piqa", "csqa", "socialiqa", "openbookqa"]
core_small_names = ["hellaswag", "arc_challenge", "arc_easy", "piqa", "csqa", "socialiqa", "openbookqa"]
mmlu_names = ["mmlu_stem", "mmlu_humanities", "mmlu_social_sciences", "mmlu_other"]

core_5shot_tasks: Dict[str, DownstreamTaskPrediction] = {
Expand Down
12 changes: 0 additions & 12 deletions scripts/ladder_peteish.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,18 +452,6 @@ def flops_for_model(model_config: Union[ModelConfig, str]) -> int:
def flops_cmd(args: argparse.Namespace):
cfg = config_from_args(args)

from tqdm import tqdm

from olmo.eval import build_evaluator
from olmo.tokenizer import Tokenizer

device = torch.device("cpu")
tokenizer = Tokenizer.from_train_config(cfg)
for eval_cfg in tqdm(cfg.evaluators):
evaluator = build_evaluator(cfg, eval_cfg, tokenizer, device)
print(evaluator)
exit()

flops = flops_for_model(cfg.model)
length_in_tokens = parse_length(args.length, parse_size(args.model))
print("Expected model flops: ", round(flops * length_in_tokens / 1e18, 3), "x 10^9 GFlops")
Expand Down
44 changes: 21 additions & 23 deletions scripts/ladder_peteish.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,30 @@
./scripts/beaker/ladder_peteish-launch.sh 8 --model 760M --data olmoe-mix-0924 --length 10xC --name peteish --save_overwrite --device_batch_size 2 --batch_size_divisor 128
./scripts/beaker/ladder_peteish-launch.sh 16 --model 1B --data olmoe-mix-0924 --length 10xC --name peteish --save_overwrite --device_batch_size 1 --batch_size_divisor 128


./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 4 --batch_size_divisor 128 --alpha_f 1.0
./scripts/beaker/ladder_peteish-launch.sh 8 --model 370M --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 2 --batch_size_divisor 128 --alpha_f 1.0
./scripts/beaker/ladder_peteish-launch.sh 8 --model 600M --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 2 --batch_size_divisor 128 --alpha_f 1.0
./scripts/beaker/ladder_peteish-launch.sh 8 --model 760M --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 2 --batch_size_divisor 128 --alpha_f 1.0
./scripts/beaker/ladder_peteish-launch.sh 16 --model 1B --data olmoe-mix-0924 --length 10xC --name peteish-const --save_overwrite --device_batch_size 1 --batch_size_divisor 128 --alpha_f 1.0

./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 4 --model 370M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
# ./scripts/beaker/ladder_peteish-launch.sh 4 --model 600M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64

./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 4 --model 370M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
# ./scripts/beaker/ladder_peteish-launch.sh 4 --model 600M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64

./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 4 --model 370M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
# ./scripts/beaker/ladder_peteish-launch.sh 4 --model 600M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64

./scripts/beaker/ladder_peteish-launch.sh 4 --model 190M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 4 --model 370M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
# ./scripts/beaker/ladder_peteish-launch.sh 4 --model 600M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64
./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64

./scripts/beaker/ladder_peteish-launch.sh 2 --model 190M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
./scripts/beaker/ladder_peteish-launch.sh 2 --model 370M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64 --device_eval_batch_size 8
./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 1xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64 --device_eval_batch_size 4

./scripts/beaker/ladder_peteish-launch.sh 2 --model 190M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
./scripts/beaker/ladder_peteish-launch.sh 2 --model 370M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64 --device_eval_batch_size 8
./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 2xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64 --device_eval_batch_size 4

./scripts/beaker/ladder_peteish-launch.sh 2 --model 190M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
./scripts/beaker/ladder_peteish-launch.sh 2 --model 370M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64 --device_eval_batch_size 8
./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 5xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64 --device_eval_batch_size 4

./scripts/beaker/ladder_peteish-launch.sh 2 --model 190M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
./scripts/beaker/ladder_peteish-launch.sh 2 --model 370M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 4 --batch_size_divisor 64 --device_eval_batch_size 16
./scripts/beaker/ladder_peteish-launch.sh 4 --model 760M --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 2 --batch_size_divisor 64 --device_eval_batch_size 8
./scripts/beaker/ladder_peteish-launch.sh 8 --model 1B --data olmoe-mix-0924 --length 10xC --name peteish-moreeval --save_overwrite --device_batch_size 1 --batch_size_divisor 64 --device_eval_batch_size 4
31 changes: 14 additions & 17 deletions scripts/scaling/single_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from olmo.scaling.scaling_laws.utils import (
get_final_configs,
get_final_data_by_name,
get_step1_data_by_name,
get_task_sets,
prettify,
tasks,
Expand All @@ -23,25 +23,25 @@
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-k", "--keys", nargs="+", default=[], help="Key(s) for tasks")
parser.add_argument(
"--num_to_avg", type=int, default=1, help="Number of final ckpts to average (for final loss fitting)"
)
parser.add_argument("--moving_avg", type=int, default=1, help="Moving average for bpb loss")
parser.add_argument("-c", "--config-path", type=str, required=True, help="Path to config file")
parser.add_argument("-o", "--output-path", type=str, required=True, help="Path to write output figure")
args = parser.parse_args()

args.keys = get_task_sets(args.keys)

return args


def fit_step12(data_by_name):
def fit_step12(data_by_name, task_name):
train_nds, train_ys = [], []
for name, data in data_by_name.items():
if data["mode"] == "train":
train_nds += [[n, d] for n, d in zip(data["ns"], data["ds"])]
train_ys += data["ys"]

p0 = [3.0, 5.0, 0.2, 0.3, 1.0, -0.5, 1.0, 3.0, 1.0]
bounds = [(0, None), (0, None), (0, 1), (0, 1), (0, None), (-0.9999, 0), (0, None), (0, None), (0, 1)]
p0 = [3.0, 5.0, 0.2, 0.3, 0.0, tasks[task_name].task_minimum - 1.0, 1.0]
bounds = [(0, 10), (0, 10), (0, 1), (0, 1), (-10, 10), (-0.9999, 0), (0, 1)]
coefficients = get_coefficients_huber(
train_nds,
train_ys,
Expand Down Expand Up @@ -84,9 +84,11 @@ def predict_step12(data_by_name, coefficients):


def str_combined_fit(coefficients):
a, b, alpha, beta, E, p, L0, k, q = coefficients
a, b, alpha, beta, E, p, q = coefficients
A, B = np.exp(a), np.exp(b)
return f"L(N, D) = {A:.2f} / N^{alpha:.2f} + {B:.2f} / D^{beta:.2f} + {E:.2f}\nAcc(L) = {p:.2f} / (1 + e^(-{k:.2f} (L - {L0:.2f}))) + {q:.2f}"
return (
f"Acc(N, D) = {p:.2f} / (1 + e^-({A:.2f} / N^{alpha:.2f} \n + {B:.2f} / D^{beta:.2f} + {E:.2f})) + {q:.2f}"
)


def plot_step12(
Expand Down Expand Up @@ -143,7 +145,7 @@ def plot_step12(
)

ax.set_xscale("log")
ax.legend(loc="lower right", ncols=1, fontsize=8)
ax.legend(ncols=1, fontsize=7)
ax.set_xlabel("Tokens (D)")
ax.set_ylabel("Task accuracy")
ax.set_title(
Expand All @@ -154,11 +156,8 @@ def plot_step12(

def main():
args = parse_args()

configs = get_final_configs(args.config_path)

args.keys = get_task_sets(args.keys)

sns.set_style("whitegrid")
num_tasks = len(args.keys)
num_cols = min(4, num_tasks)
Expand All @@ -168,12 +167,10 @@ def main():
results = "Task Name | Actual Value | Predicted Value | Relative Error"

for i, task_name in enumerate(args.keys):
task = tasks[task_name]
keys = task.get_accuracy_keys()
data_by_name = get_final_data_by_name(configs, keys, num_to_avg=args.num_to_avg)
data_by_name = get_step1_data_by_name(configs, task_name, y_metric="rc_acc", moving_avg=args.moving_avg)

# fit the parameters
coefficients = fit_step12(data_by_name)
coefficients = fit_step12(data_by_name, task_name)

# make predictions
predicted_data_by_name, plotted_predicted_data_by_name, (y, y_pred, rel_error) = predict_step12(
Expand Down

0 comments on commit 1233662

Please sign in to comment.