Skip to content

Commit

Permalink
Remove git_root from packaged modules.
Browse files: browse the repository at this point in the history
  • Loading branch information
kklein committed Jul 4, 2024
1 parent 274f40b commit a8e5fca
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _twins_data(rng, use_numpy=False, test_fraction=0.2):
feature_columns,
categorical_feature_columns,
true_cate_column,
) = load_twins_data(rng)
) = load_twins_data(Path(git_root()) / "data" / "twins.zip", rng)

covariates = chosen_df[feature_columns]
observed_outcomes = chosen_df[outcome_column]
Expand Down
11 changes: 7 additions & 4 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from collections.abc import Callable
from inspect import signature
from operator import le, lt
from pathlib import Path

import numpy as np
import pandas as pd
from git_root import git_root
from sklearn.base import check_array, check_X_y, is_classifier, is_regressor
from sklearn.ensemble import (
HistGradientBoostingClassifier,
Expand Down Expand Up @@ -326,10 +326,12 @@ def validate_valid_treatment_variant_not_control(
)


def load_mindset_data() -> tuple[pd.DataFrame, str, str, list[str], list[str]]:
def load_mindset_data(
path: Path,
) -> tuple[pd.DataFrame, str, str, list[str], list[str]]:
# TODO: Optionally make this function work with a URL instead of a file system reference.
# That way, we don't need to package the data for someone to be able to use this function.
df = pd.read_csv(git_root("data/learning_mindset.zip"))
df = pd.read_csv(path)
outcome_column = "achievement_score"
treatment_column = "intervention"
feature_columns = [
Expand Down Expand Up @@ -363,11 +365,12 @@ def load_mindset_data() -> tuple[pd.DataFrame, str, str, list[str], list[str]]:


def load_twins_data(
path: Path,
rng: np.random.Generator,
) -> tuple[pd.DataFrame, str, str, list[str], list[str], str]:
# TODO: Optionally make this function work with a URL instead of a file system reference.
# That way, we don't need to package the data for someone to be able to use this function.
df = pd.read_csv(git_root("data/twins.zip"))
df = pd.read_csv(path)
drop_columns = [
"bord",
"brstate_reg",
Expand Down
7 changes: 5 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright (c) QuantCo 2024-2024
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from git_root import git_root

from metalearners._utils import get_linear_dimension, load_mindset_data, load_twins_data
from metalearners.data_generation import (
Expand Down Expand Up @@ -71,7 +74,7 @@ def rng():

@pytest.fixture(scope="session")
def mindset_data():
return load_mindset_data()
return load_mindset_data(Path(git_root()) / "data" / "learning_mindset.zip")


@pytest.fixture(scope="session")
Expand All @@ -84,7 +87,7 @@ def twins_data():
feature_columns,
categorical_feature_columns,
_,
) = load_twins_data(rng)
) = load_twins_data(Path(git_root()) / "data" / "twins.zip", rng)
return (
chosen_df,
outcome_column,
Expand Down

0 comments on commit a8e5fca

Please sign in to comment.