Skip to content

Commit

Permalink
Merge pull request #13 from Morisset/devel
Browse files Browse the repository at this point in the history
Devel
  • Loading branch information
Morisset authored May 8, 2023
2 parents 0dd71f1 + e565b51 commit fdfd6b6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions ai4neb/Regressor/RegressionModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,10 +597,10 @@ def _set_scaler(self, force=False):
else:
self.scaler = StandardScaler()
self.scaler.fit(self.X_train)
# ToDO: PCA needs to be applied to scaled data. This is not the case.

if self.pca_N != 0:
self.pca = PCA(n_components=self.pca_N)
self.pca.fit(self.X_train)
self.pca.fit(self.scaler.transform(self.X_train))

def _set_scaler_y(self, force=False):
if ((self.scaler_y is None) or force) and self.scaling_y and self.y_train is not None:
Expand Down Expand Up @@ -769,7 +769,7 @@ def plot_loss(self, ax=None, i_RM=0, **kwargs):
ax.plot(val_loss_values, label='Validation loss')
ax.set_yscale('log')

def predict(self, scoring=False, reduce_by=None):
def predict(self, scoring=False, reduce_by=None, **kwargs):
"""
Compute the prediction using self.X_test
Results are stored into self.pred
Expand All @@ -786,7 +786,7 @@ def predict(self, scoring=False, reduce_by=None):
if self.predict_functional:
self.pred = self.RMs[0](self.X_test)[0,::]
else:
self.pred = self.RMs[0].predict(self.X_test)
self.pred = self.RMs[0].predict(self.X_test, **kwargs)
else:
self.pred = []
for i_RM, RM in enumerate(self.RMs):
Expand All @@ -797,7 +797,7 @@ def predict(self, scoring=False, reduce_by=None):
if self.predict_functional:
self.pred.append(RM(to_predict))[0,::]
else:
self.pred.append(RM.predict(to_predict))
self.pred.append(RM.predict(to_predict, **kwargs))
self.pred = np.array(self.pred).T
if scoring:
if self.N_test != self.N_test_y:
Expand Down
2 changes: 1 addition & 1 deletion ai4neb/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
# ai4neb version

__version__="0.2.12"
__version__="0.2.13"

0 comments on commit fdfd6b6

Please sign in to comment.