Skip to content

Commit

Permalink
fix sgd squared_error loss name and default batch size in mlp
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrenodet committed Feb 21, 2025
1 parent bf6633e commit 735dc7b
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions mislabeled/probe/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,9 @@ def linearize_linear_model_sgd(estimator, X, y):
else:
raise NotImplementedError("lasso not implemented yet.")

linear = LinearModel(coef, intercept, loss=estimator.loss, regul=regul)
loss = "l2" if estimator.loss == "squared_error" else estimator.loss

linear = LinearModel(coef, intercept, loss=loss, regul=regul)
return linear, X, y


Expand Down Expand Up @@ -362,11 +364,15 @@ def linearize_mlp(estimator, X, y):
if y.ndim == 1:
y = y.reshape(-1, 1)

batch_size = (
min(200, X.shape[0]) if estimator.batch_size == "auto" else estimator.batch_size
)

linear = LinearModel(
coef,
intercept,
loss=loss,
regul=estimator.alpha * estimator.batch_size / X.shape[0],
regul=estimator.alpha * batch_size / X.shape[0],
)

return linear, activation, y
Expand Down

0 comments on commit 735dc7b

Please sign in to comment.