Skip to content

Commit

Permalink
nit in batch size computation
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrenodet committed Feb 21, 2025
1 parent cd773ca commit 4aabb77
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions mislabeled/probe/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,12 @@ def linearize_mlp(estimator, X, y):
if y.ndim == 1:
y = y.reshape(-1, 1)

batch_size = (
min(200, X.shape[0]) if estimator.batch_size == "auto" else estimator.batch_size
)
if estimator.solver == "lbfgs":
batch_size = X.shape[0]
elif estimator.batch_size == "auto":
batch_size = min(200, X.shape[0])
else:
batch_size = estimator.batch_size

linear = LinearModel(
coef,
Expand Down

0 comments on commit 4aabb77

Please sign in to comment.