From 47ec41767b30ae19a552afd58ae60f0d8c2bb60f Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Wed, 20 Sep 2023 09:42:06 +0300 Subject: [PATCH] Added '_BaseEnsemble.recaller' attribute --- sklearn2pmml/ensemble/__init__.py | 36 ++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/sklearn2pmml/ensemble/__init__.py b/sklearn2pmml/ensemble/__init__.py index 5a25ea8..bfa2092 100644 --- a/sklearn2pmml/ensemble/__init__.py +++ b/sklearn2pmml/ensemble/__init__.py @@ -181,7 +181,7 @@ def predict_proba(self, X, **predict_proba_params): class _BaseEnsemble(_BaseComposition): - def __init__(self, steps): + def __init__(self, steps, recaller): for step in steps: if type(step) is not tuple: raise TypeError("Step is not a tuple") @@ -191,6 +191,7 @@ def __init__(self, steps): if not isinstance(predicate, (str, Predicate)): raise TypeError() self.steps = steps + self.recaller = recaller @property def _steps(self): @@ -207,6 +208,11 @@ def set_params(self, **kwargs): self._set_params("_steps", **kwargs) return self + def _to_evaluation_dataset(self, X): + if self.recaller is not None: + return self.recaller.transform(X) + return X + def _to_sparse(X, step_mask, step_result): # Make array if len(step_result.shape) == 1: @@ -256,8 +262,8 @@ def augment(self, X): class EstimatorChain(_BaseEnsemble): - def __init__(self, steps, multioutput = True): - super(EstimatorChain, self).__init__(steps) + def __init__(self, steps, recaller = None, multioutput = True): + super(EstimatorChain, self).__init__(steps, recaller) self.multioutput = multioutput def fit(self, X, y, **fit_params): @@ -265,9 +271,10 @@ def fit(self, X, y, **fit_params): if len(self.steps) != y.shape[1]: raise ValueError() y = numpy.asarray(y) + X_eval = self._to_evaluation_dataset(X) i = 0 for name, estimator, predicate in self.steps: - step_mask = eval_expr_rows(X, predicate, dtype = bool) + step_mask = eval_expr_rows(X_eval, predicate, dtype = bool) if numpy.sum(step_mask) < 1: raise ValueError(predicate) step_X = X[step_mask] @@ -286,8 +293,9 @@ def fit(self, X, y, **fit_params): def _predict(self, X, predict_method): result = None + X_eval = self._to_evaluation_dataset(X) for name, estimator, predicate in self.steps: - step_mask = eval_expr_rows(X, predicate, dtype = bool) + step_mask = eval_expr_rows(X_eval, predicate, dtype = bool) if numpy.sum(step_mask) < 1: continue step_X = X[step_mask] @@ -316,13 +324,14 @@ def predict_proba(self, X): class SelectFirstEstimator(_BaseEnsemble): - def __init__(self, steps): - super(SelectFirstEstimator, self).__init__(steps) + def __init__(self, steps, recaller): + super(SelectFirstEstimator, self).__init__(steps, recaller) def fit(self, X, y, **fit_params): + X_eval = self._to_evaluation_dataset(X) mask = numpy.zeros(X.shape[0], dtype = bool) for name, estimator, predicate in self.steps: - step_mask = eval_expr_rows(X, predicate, dtype = bool) + step_mask = eval_expr_rows(X_eval, predicate, dtype = bool) step_mask[mask] = False if numpy.sum(step_mask) < 1: raise ValueError(predicate) @@ -335,9 +344,10 @@ def fit(self, X, y, **fit_params): def _predict(self, X, predict_method): result = None + X_eval = self._to_evaluation_dataset(X) mask = numpy.zeros(X.shape[0], dtype = bool) for name, estimator, predicate in self.steps: - step_mask = eval_expr_rows(X, predicate, dtype = bool) + step_mask = eval_expr_rows(X_eval, predicate, dtype = bool) step_mask[mask] = False if numpy.sum(step_mask) < 1: continue @@ -359,13 +369,13 @@ def predict(self, X): class SelectFirstRegressor(SelectFirstEstimator, RegressorMixin): - def __init__(self, steps): - super(SelectFirstRegressor, self).__init__(steps) + def __init__(self, steps, recaller = None): + super(SelectFirstRegressor, self).__init__(steps, recaller) class SelectFirstClassifier(SelectFirstEstimator, ClassifierMixin): - def __init__(self, steps): - super(SelectFirstClassifier, self).__init__(steps) + def __init__(self, steps, recaller = None): + super(SelectFirstClassifier, self).__init__(steps, recaller) def predict_proba(self, X): return self._predict(X, "predict_proba")