Skip to content

Commit

Permalink
Merge branch 'main' into fix_65
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 22, 2024
2 parents e5a0c58 + 95e2cfe commit db9c46c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ Changelog
0.8.0 (2024-07-xx)
------------------

* Implement :meth:`metalearners.cross_fit_estimator.CrossFitEstimator.score`.

**Bug fixes**

* Fixed a bug in :meth:`metalearners.metalearner.MetaLearner.evaluate` where it failed
in the case of ``feature_set`` being different from ``None``.


0.7.0 (2024-07-12)
------------------

Expand Down
27 changes: 24 additions & 3 deletions metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from functools import partial

import numpy as np
from sklearn.base import is_classifier
from sklearn.base import is_classifier, is_regressor
from sklearn.metrics import accuracy_score, r2_score
from sklearn.model_selection import (
KFold,
StratifiedKFold,
Expand Down Expand Up @@ -337,8 +338,28 @@ def predict_proba(
oos_method=oos_method,
)

def score(self, X, y, sample_weight=None, **kwargs):
raise NotImplementedError()
def score(
self,
X: Matrix,
y: Vector,
is_oos: bool,
oos_method: OosMethod | None = None,
sample_weight: Vector | None = None,
) -> float:
"""Return the coefficient of determination of the prediction if the estimator is
a regressor or the mean accuracy if it is a classifier."""
if is_classifier(self):
return accuracy_score(
y, self.predict(X, is_oos, oos_method), sample_weight=sample_weight
)
elif is_regressor(self):
return r2_score(
y, self.predict(X, is_oos, oos_method), sample_weight=sample_weight
)
else:
raise NotImplementedError(
"score is not implemented for this type of estimator."
)

def set_params(self, **params):
raise NotImplementedError()
Expand Down
18 changes: 18 additions & 0 deletions tests/test_cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest
from lightgbm import LGBMClassifier, LGBMRegressor
from sklearn.base import is_classifier, is_regressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, log_loss
from sklearn.model_selection import KFold
Expand Down Expand Up @@ -262,3 +263,20 @@ def test_validate_data_match(n_observations, test_indices, success):
ValueError, match="rely on different numbers of observations"
):
_validate_data_match_prior_split(n_observations, test_indices)


@pytest.mark.parametrize(
"estimator",
[LGBMClassifier, LGBMRegressor],
)
def test_score_smoke(estimator, rng):
n_samples = 1000
X = rng.standard_normal((n_samples, 3))
if is_classifier(estimator):
y = rng.integers(0, 4, n_samples)
elif is_regressor(estimator):
y = rng.standard_normal(n_samples)

cfe = CrossFitEstimator(5, estimator, {"n_estimators": 3})
cfe.fit(X, y)
cfe.score(X, y, False)

0 comments on commit db9c46c

Please sign in to comment.