From 4aabb771eb4d6a7a6a0450d8d9e6b0e390b91788 Mon Sep 17 00:00:00 2001 From: Pierre Nodet Date: Fri, 21 Feb 2025 01:57:42 +0100 Subject: [PATCH] nit in batch size computation --- mislabeled/probe/_linear.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mislabeled/probe/_linear.py b/mislabeled/probe/_linear.py index d9112bf..a33c0c3 100644 --- a/mislabeled/probe/_linear.py +++ b/mislabeled/probe/_linear.py @@ -364,9 +364,12 @@ def linearize_mlp(estimator, X, y): if y.ndim == 1: y = y.reshape(-1, 1) - batch_size = ( - min(200, X.shape[0]) if estimator.batch_size == "auto" else estimator.batch_size - ) + if estimator.solver == "lbfgs": + batch_size = X.shape[0] + elif estimator.batch_size == "auto": + batch_size = min(200, X.shape[0]) + else: + batch_size = estimator.batch_size linear = LinearModel( coef,