Skip to content

Commit

Permalink
Merge branch 'main' into onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 5, 2024
2 parents 43c3e69 + 5bbadf8 commit d870e9c
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 10 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ Changelog
**New features**

* Added ``scoring`` parameter to :meth:`metalearners.metalearner.MetaLearner.evaluate` and
implemented the abstract method for the :class:`metalearners.XLearner` and
* Add ``scoring`` parameter to :meth:`metalearners.metalearner.MetaLearner.evaluate` and
implement the abstract method for the :class:`metalearners.XLearner` and
:class:`metalearners.DRLearner`.

**Other changes**

* Increase lower bound on ``scikit-learn`` from 1.3 to 1.4.

* Drop the run dependency on ``git_root``.


0.5.0 (2024-06-18)
------------------

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _twins_data(rng, use_numpy=False, test_fraction=0.2):
feature_columns,
categorical_feature_columns,
true_cate_column,
) = load_twins_data(rng)
) = load_twins_data(Path(git_root()) / "data" / "twins.zip", rng)

covariates = chosen_df[feature_columns]
observed_outcomes = chosen_df[outcome_column]
Expand Down
11 changes: 7 additions & 4 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from collections.abc import Callable
from inspect import signature
from operator import le, lt
from pathlib import Path

import numpy as np
import pandas as pd
from git_root import git_root
from sklearn.base import check_array, check_X_y, is_classifier, is_regressor
from sklearn.ensemble import (
HistGradientBoostingClassifier,
Expand Down Expand Up @@ -326,10 +326,12 @@ def validate_valid_treatment_variant_not_control(
)


def load_mindset_data() -> tuple[pd.DataFrame, str, str, list[str], list[str]]:
def load_mindset_data(
path: Path,
) -> tuple[pd.DataFrame, str, str, list[str], list[str]]:
# TODO: Optionally make this function work with a URL instead of a file system reference.
# That way, we don't need to package the data for someone to be able to use this function.
df = pd.read_csv(git_root("data/learning_mindset.zip"))
df = pd.read_csv(path)
outcome_column = "achievement_score"
treatment_column = "intervention"
feature_columns = [
Expand Down Expand Up @@ -363,11 +365,12 @@ def load_mindset_data() -> tuple[pd.DataFrame, str, str, list[str], list[str]]:


def load_twins_data(
path: Path,
rng: np.random.Generator,
) -> tuple[pd.DataFrame, str, str, list[str], list[str], str]:
# TODO: Optionally make this function work with a URL instead of a file system reference.
# That way, we don't need to package the data for someone to be able to use this function.
df = pd.read_csv(git_root("data/twins.zip"))
df = pd.read_csv(path)
drop_columns = [
"bord",
"brstate_reg",
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ dependencies = [
"pandas",
"numpy",
"typing-extensions",
"git_root",
"shap",
"joblib>=1.2.0"
]
Expand Down
7 changes: 5 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright (c) QuantCo 2024-2024
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from git_root import git_root
from sklearn.calibration import CalibratedClassifierCV
from sklearn.discriminant_analysis import (
LinearDiscriminantAnalysis,
Expand Down Expand Up @@ -194,7 +197,7 @@ def rng():

@pytest.fixture(scope="session")
def mindset_data():
return load_mindset_data()
return load_mindset_data(Path(git_root()) / "data" / "learning_mindset.zip")


@pytest.fixture(scope="session")
Expand All @@ -207,7 +210,7 @@ def twins_data():
feature_columns,
categorical_feature_columns,
_,
) = load_twins_data(rng)
) = load_twins_data(Path(git_root()) / "data" / "twins.zip", rng)
return (
chosen_df,
outcome_column,
Expand Down

0 comments on commit d870e9c

Please sign in to comment.