From 31080a9b285d13ca64428872ee9a6e8d6d6a5e36 Mon Sep 17 00:00:00 2001 From: Kevin Maik Jablonka Date: Tue, 3 Oct 2023 15:38:14 +0200 Subject: [PATCH] regressor and do not return std by default --- src/chemlift/finetune/peftmodels.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/chemlift/finetune/peftmodels.py b/src/chemlift/finetune/peftmodels.py index 4225502..4967bce 100644 --- a/src/chemlift/finetune/peftmodels.py +++ b/src/chemlift/finetune/peftmodels.py @@ -285,7 +285,7 @@ def predict( temperature=0.7, do_sample=False, formatted: Optional[pd.DataFrame] = None, - return_std: bool = True, + return_std: bool = False, ): predictions = self._predict( X=X, temperature=temperature, do_sample=do_sample, formatted=formatted @@ -334,7 +334,7 @@ def _query(self, formatted_df, temperature, do_sample): class PEFTRegressor(PEFTClassifier): - def __init__( + def __init__( self, property_name: str, extractor: RegressionExtractor = RegressionExtractor(), @@ -376,8 +376,31 @@ def __init__( self.tune_settings["per_device_train_batch_size"] = self.batch_size - __repr__ = basic_repr("property_name", "_base_model", 'num_digits') + __repr__ = basic_repr("property_name", "_base_model", "num_digits") + def predict( + self, + X: Optional[ArrayLike] = None, + temperature=0.7, + do_sample=False, + formatted: Optional[pd.DataFrame] = None, + return_std: bool = False, + ): + predictions = self._predict( + X=X, temperature=temperature, do_sample=do_sample, formatted=formatted + ) + + predictions = np.array(predictions).T + + # nan values make issues here + predictions_mean = np.array( + [try_exccept_nan(np.mean, pred) for pred in predictions.astype(int)] + ) + + if return_std: + predictions_std = np.array([np.std(pred) for pred in predictions.astype(int)]) + return predictions_mean, predictions_std + return predictions_mean class SMILESAugmentedPEFTClassifier(PEFTClassifier):