From a8e5fca185cac229aa08aab23b093650ec1c7c68 Mon Sep 17 00:00:00 2001 From: kklein Date: Thu, 4 Jul 2024 18:56:18 +0200 Subject: [PATCH] Remove git_root from packaged modules. --- benchmarks/benchmark.py | 2 +- metalearners/_utils.py | 11 +++++++---- tests/conftest.py | 7 +++++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 7fa56b6..e0c378d 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -115,7 +115,7 @@ def _twins_data(rng, use_numpy=False, test_fraction=0.2): feature_columns, categorical_feature_columns, true_cate_column, - ) = load_twins_data(rng) + ) = load_twins_data(Path(git_root()) / "data" / "twins.zip", rng) covariates = chosen_df[feature_columns] observed_outcomes = chosen_df[outcome_column] diff --git a/metalearners/_utils.py b/metalearners/_utils.py index a4c7187..60eb5d5 100644 --- a/metalearners/_utils.py +++ b/metalearners/_utils.py @@ -4,10 +4,10 @@ from collections.abc import Callable from inspect import signature from operator import le, lt +from pathlib import Path import numpy as np import pandas as pd -from git_root import git_root from sklearn.base import check_array, check_X_y, is_classifier, is_regressor from sklearn.ensemble import ( HistGradientBoostingClassifier, @@ -326,10 +326,12 @@ def validate_valid_treatment_variant_not_control( ) -def load_mindset_data() -> tuple[pd.DataFrame, str, str, list[str], list[str]]: +def load_mindset_data( + path: Path, +) -> tuple[pd.DataFrame, str, str, list[str], list[str]]: # TODO: Optionally make this function work with a URL instead of a file system reference. # That way, we don't need to package the data for someone to be able to use this function. - df = pd.read_csv(git_root("data/learning_mindset.zip")) + df = pd.read_csv(path) outcome_column = "achievement_score" treatment_column = "intervention" feature_columns = [ @@ -363,11 +365,12 @@ def load_mindset_data() -> tuple[pd.DataFrame, str, str, list[str], list[str]]: def load_twins_data( + path: Path, rng: np.random.Generator, ) -> tuple[pd.DataFrame, str, str, list[str], list[str], str]: # TODO: Optionally make this function work with a URL instead of a file system reference. # That way, we don't need to package the data for someone to be able to use this function. - df = pd.read_csv(git_root("data/twins.zip")) + df = pd.read_csv(path) drop_columns = [ "bord", "brstate_reg", diff --git a/tests/conftest.py b/tests/conftest.py index fa44727..2e6bb02 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,12 @@ # Copyright (c) QuantCo 2024-2024 # SPDX-License-Identifier: BSD-3-Clause +from pathlib import Path + import numpy as np import pandas as pd import pytest +from git_root import git_root from metalearners._utils import get_linear_dimension, load_mindset_data, load_twins_data from metalearners.data_generation import ( @@ -71,7 +74,7 @@ def rng(): @pytest.fixture(scope="session") def mindset_data(): - return load_mindset_data() + return load_mindset_data(Path(git_root()) / "data" / "learning_mindset.zip") @pytest.fixture(scope="session") @@ -84,7 +87,7 @@ def twins_data(): feature_columns, categorical_feature_columns, _, - ) = load_twins_data(rng) + ) = load_twins_data(Path(git_root()) / "data" / "twins.zip", rng) return ( chosen_df, outcome_column,