Skip to content

Commit

Permalink
Implement CrossFitEstimator.score
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 19, 2024
1 parent 9e00603 commit 5b3ed6e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
Changelog
=========

0.8.0 (2024-07-xx)
------------------

* Implement :meth:`metalearners.cross_fit_estimator.CrossFitEstimator.score`.


0.7.0 (2024-07-12)
------------------

Expand Down
25 changes: 23 additions & 2 deletions metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from sklearn.base import is_classifier
from sklearn.metrics import accuracy_score, r2_score
from sklearn.model_selection import (
KFold,
StratifiedKFold,
Expand Down Expand Up @@ -337,8 +338,28 @@ def predict_proba(
oos_method=oos_method,
)

# Placeholder: always raises NotImplementedError, regardless of the arguments passed.
def score(self, X, y, sample_weight=None, **kwargs):
raise NotImplementedError()
def score(
    self,
    X: Matrix,
    y: Vector,
    is_oos: bool,
    oos_method: OosMethod | None = None,
    sample_weight: Vector | None = None,
) -> float:
    """Score the predictions for ``X`` against the true values ``y``.

    The metric is chosen from ``self._estimator_type``: the mean accuracy
    (``accuracy_score``) for a classifier and the coefficient of determination
    (``r2_score``) for a regressor. ``is_oos`` and ``oos_method`` are forwarded
    unchanged to ``predict``; ``sample_weight`` is forwarded to the metric.

    Raises ``NotImplementedError`` for any other estimator type, before any
    prediction is computed.
    """
    estimator_kind = self._estimator_type
    if estimator_kind not in ("classifier", "regressor"):
        # Codecov flagged this branch as not covered by tests.
        raise NotImplementedError(
            "score is not implemented for this type of estimator."
        )
    metric = accuracy_score if estimator_kind == "classifier" else r2_score
    predictions = self.predict(X, is_oos, oos_method)
    return metric(y, predictions, sample_weight=sample_weight)

# Placeholder: always raises NotImplementedError; no parameters are set.
def set_params(self, **params):
raise NotImplementedError()
Expand Down
18 changes: 18 additions & 0 deletions tests/test_cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest
from lightgbm import LGBMClassifier, LGBMRegressor
from sklearn.base import is_classifier, is_regressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, log_loss
from sklearn.model_selection import KFold
Expand Down Expand Up @@ -262,3 +263,20 @@ def test_validate_data_match(n_observations, test_indices, success):
ValueError, match="rely on different numbers of observations"
):
_validate_data_match_prior_split(n_observations, test_indices)


@pytest.mark.parametrize(
    "estimator",
    [LGBMClassifier, LGBMRegressor],
)
def test_score_smoke(estimator, rng):
    """Smoke-test ``CrossFitEstimator.score`` for a classifier and a regressor.

    Fits a 5-fold ``CrossFitEstimator`` on random data and checks that scoring
    with ``is_oos=False`` runs and yields a finite number.
    """
    n_samples = 1000
    X = rng.standard_normal((n_samples, 3))

    # Inspect an instance rather than the class itself: passing a class to
    # is_classifier/is_regressor is deprecated in recent scikit-learn releases.
    base_model = estimator()
    if is_classifier(base_model):
        y = rng.integers(0, 4, n_samples)
    elif is_regressor(base_model):
        y = rng.standard_normal(n_samples)
    else:
        # Fail loudly instead of hitting an unbound ``y`` further down.
        pytest.fail(f"{estimator!r} is neither a classifier nor a regressor.")

    cfe = CrossFitEstimator(5, estimator, {"n_estimators": 3})
    cfe.fit(X, y)
    score = cfe.score(X, y, False)
    # A smoke test should at least confirm a usable numeric result came back.
    assert np.isfinite(score)

0 comments on commit 5b3ed6e

Please sign in to comment.