Skip to content

Commit

Permalink
mark target and pred clearly
Browse files Browse the repository at this point in the history
  • Loading branch information
AkshitaB committed Nov 20, 2024
1 parent 36539f7 commit 74e44bd
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 30 deletions.
37 changes: 23 additions & 14 deletions scripts/scaling/step1.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ def plot_step1(
y_metric,
ax=plt.gca(),
):
# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
ax.plot(
data["ds"],
data["ys"],
color=config.color,
linestyle="--",
linewidth=1.5,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)

# plot the actual and predicted data
unsigned_rel_errors = []
for name, data in data_by_name.items():
Expand All @@ -151,15 +163,24 @@ def plot_step1(
d,
y,
color=config.color,
marker=MARKERS[i] if config.mode == "train" else "o",
s=50,
marker=MARKERS[i] if config.mode == "train" else "x",
s=50 if config.mode == "train" else 10,
label=f"{config.label} (target)" if config.mode == "eval" else None,
)

for d, y, y_pred in zip(data["ds"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y
if config.mode == "train":
unsigned_rel_errors.append(np.abs(rel_error))
else:
ax.scatter(
d,
y_pred,
color=config.color,
marker="o",
s=10,
label=f"{config.label} ({'predicted'})",
)
ax.annotate(
f"{prettify(rel_error)}",
(d, y),
Expand All @@ -172,18 +193,6 @@ def plot_step1(
)
avg_unsigned_rel_error = np.mean(unsigned_rel_errors)

# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
ax.plot(
data["ds"],
data["ys"],
color=config.color,
linestyle="--",
linewidth=1.5,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)

ax.set_xscale("log")
ax.legend(loc="upper right", ncols=1, fontsize=8)
ax.set_xlabel("Tokens (D)")
Expand Down
37 changes: 23 additions & 14 deletions scripts/scaling/step1_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ def plot_step1(
y_metric,
ax=plt.gca(),
):
# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
ax.plot(
data["fs"],
data["ys"],
color=config.color,
linestyle="--",
linewidth=1.5,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)

# plot the actual and predicted data
unsigned_rel_errors = []
for name, data in data_by_name.items():
Expand All @@ -132,15 +144,24 @@ def plot_step1(
f,
y,
color=config.color,
marker=MARKERS[i] if config.mode == "train" else "o",
s=50,
marker=MARKERS[i] if config.mode == "train" else "x",
s=50 if config.mode == "train" else 10,
label=f"{config.label} (target)" if config.mode == "eval" else None,
)

for f, y, y_pred in zip(data["fs"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y
if config.mode == "train":
unsigned_rel_errors.append(np.abs(rel_error))
else:
ax.scatter(
f,
y_pred,
color=config.color,
marker="o",
s=10,
label=f"{config.label} ({'predicted'})",
)
ax.annotate(
f"{prettify(rel_error)}",
(f, y),
Expand All @@ -153,18 +174,6 @@ def plot_step1(
)
avg_unsigned_rel_error = np.mean(unsigned_rel_errors)

# plot the fitted curve
for name, data in plotted_predicted_data_by_name.items():
config = configs[name]
ax.plot(
data["fs"],
data["ys"],
color=config.color,
linestyle="--",
linewidth=1.5,
label=f'{config.label} ({"fitted" if config.mode == "train" else "predicted"})',
)

ax.set_xscale("log")
ax.legend(loc="upper right", ncols=1, fontsize=8)
ax.set_xlabel("Flops (F)")
Expand Down
12 changes: 10 additions & 2 deletions scripts/scaling/step2.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,24 @@ def plot_step2(
data["xs"],
data["ys"],
color=config.color,
marker="o",
marker="o" if config.mode == "train" else "x",
s=10,
label=f"{config.label} ({'fitted' if config.mode == 'train' else 'predicted'})",
label=f"{config.label} ({'fitted' if config.mode == 'train' else 'target'})",
)
for x, y, y_pred in zip(data["xs"], data["ys"], predicted_data["ys"]):
rel_error = (y_pred - y) / y

if config.mode == "train":
unsigned_rel_errs.append(abs(rel_error))
else:
ax.scatter(
x,
y_pred,
color=config.color,
marker="o",
s=10,
label=f"{config.label} ({'predicted'})",
)
ax.annotate(
f"{np.abs(rel_error) * 100:.1f}%",
(x, y),
Expand Down

0 comments on commit 74e44bd

Please sign in to comment.