Skip to content

Commit

Permalink
Drop support for Python 3.8 (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein authored Mar 19, 2024
1 parent dd58173 commit 9a1fda0
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 28 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ jobs:
linux-unittests:
name: "Unit tests - Python ${{ matrix.PYTHON_VERSION }}"
timeout-minutes: 30
runs-on: ubuntu-latest
runs-on: ubuntu-latest-8core
strategy:
fail-fast: false
matrix:
PYTHON_VERSION: ["3.8", "3.9", "3.10", "3.11"]
PYTHON_VERSION: ["3.9", "3.10", "3.11"]
steps:
- name: Checkout branch
uses: actions/checkout@v4
Expand Down
4 changes: 2 additions & 2 deletions conda.recipe/recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ build:

requirements:
host:
- python >=3.8
- python >=3.9
- pip
- setuptools-scm
run:
- python >=3.8
- python >=3.9
- scikit-learn >=1.3
- pandas
- numpy
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- git_root
## Python
- pip
- python>=3.8
- python>=3.9
- setuptools-scm
- setuptools>=61 # Adds support for pyproject.toml package declaration.
## Documentation
Expand Down
14 changes: 7 additions & 7 deletions metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: LicenseRef-QuantCo

from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple, Type, Union
from typing import Literal, Optional, Union

import numpy as np
from sklearn.base import is_classifier
Expand Down Expand Up @@ -63,18 +63,18 @@ class CrossFitEstimator:
"""

n_folds: int
estimator_factory: Type[_ScikitModel]
estimator_params: Dict = field(default_factory=dict)
estimator_factory: type[_ScikitModel]
estimator_params: dict = field(default_factory=dict)
enable_overall: bool = True
_estimators: List[_ScikitModel] = field(init=False)
_estimators: list[_ScikitModel] = field(init=False)
_overall_estimator: Optional[_ScikitModel] = field(init=False)
_test_indices: Optional[Tuple[np.ndarray]] = field(init=False)
_test_indices: Optional[tuple[np.ndarray]] = field(init=False)

def __post__init__(self):
_validate_n_folds(self.n_folds)
self._estimators: List[_ScikitModel] = []
self._estimators: list[_ScikitModel] = []
self._overall_estimator: Optional[_ScikitModel] = None
self._test_indices: Optional[Tuple[np.ndarray]] = None
self._test_indices: Optional[tuple[np.ndarray]] = None

def _train_overall_estimator(
self, X: Matrix, y: Union[Matrix, Vector]
Expand Down
27 changes: 14 additions & 13 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
# SPDX-License-Identifier: LicenseRef-QuantCo

from abc import ABC, abstractmethod
from typing import Collection, Dict, List, Optional, Union
from collections.abc import Collection
from typing import Optional, Union

import numpy as np
from typing_extensions import Self

from metalearners._utils import Matrix, Vector, _ScikitModel, validate_number_positive
from metalearners.cross_fit_estimator import CrossFitEstimator

Params = Dict[str, Union[int, float, str]]
Params = dict[str, Union[int, float, str]]
Features = Union[Collection[str], Collection[int]]


def _initialize_model_dict(argument, expected_names: Collection[str]) -> Dict:
def _initialize_model_dict(argument, expected_names: Collection[str]) -> dict:
if isinstance(argument, dict) and set(argument.keys()) == set(expected_names):
return argument
return {name: argument for name in expected_names}
Expand All @@ -24,19 +25,19 @@ class MetaLearner(ABC):

@classmethod
@abstractmethod
def nuisance_model_names(cls) -> List[str]: ...
def nuisance_model_names(cls) -> list[str]: ...

@classmethod
@abstractmethod
def treatment_model_names(cls) -> List[str]: ...
def treatment_model_names(cls) -> list[str]: ...

def __init__(
self,
nuisance_model_factory: Union[_ScikitModel, Dict[str, _ScikitModel]],
treatment_model_factory: Union[_ScikitModel, Dict[str, _ScikitModel]],
nuisance_model_params: Optional[Union[Params, Dict[str, Params]]] = None,
treatment_model_params: Optional[Union[Params, Dict[str, Params]]] = None,
feature_set: Optional[Union[Features, Dict[str, Features]]] = None,
nuisance_model_factory: Union[_ScikitModel, dict[str, _ScikitModel]],
treatment_model_factory: Union[_ScikitModel, dict[str, _ScikitModel]],
nuisance_model_params: Optional[Union[Params, dict[str, Params]]] = None,
treatment_model_params: Optional[Union[Params, dict[str, Params]]] = None,
feature_set: Optional[Union[Features, dict[str, Features]]] = None,
# TODO: Consider implementing selection of number of folds for various estimators.
n_folds: int = 10,
):
Expand Down Expand Up @@ -92,15 +93,15 @@ def __init__(
feature_set, nuisance_model_names + treatment_model_names
)

self._nuisance_models: Dict[str, _ScikitModel] = {
self._nuisance_models: dict[str, _ScikitModel] = {
name: CrossFitEstimator(
n_folds=self.n_folds,
estimator_factory=self.nuisance_model_factory[name],
estimator_params=self.nuisance_model_params[name],
)
for name in nuisance_model_names
}
self._treatment_models: Dict[str, _ScikitModel] = {
self._treatment_models: dict[str, _ScikitModel] = {
name: CrossFitEstimator(
n_folds=self.n_folds,
estimator_factory=self.treatment_model_factory[name],
Expand Down Expand Up @@ -159,7 +160,7 @@ def predict(self, X: Matrix) -> np.ndarray:
...

@abstractmethod
def evaluate(self, X: Matrix, y: Vector, w: Vector) -> Dict[str, Union[float, int]]:
def evaluate(self, X: Matrix, y: Vector, w: Vector) -> dict[str, Union[float, int]]:
"""Evaluate all models contained in a MetaLearner."""
...

Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@ authors = [
]
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
requires-python = ">=3.8"
requires-python = ">=3.9"

[project.urls]
repository = "https://github.com/quantco/metalearners"
Expand Down Expand Up @@ -67,7 +66,7 @@ select = [
]

[tool.mypy]
python_version = '3.8'
python_version = '3.9'
ignore_missing_imports = true
no_implicit_optional = true
check_untyped_defs = true
Expand Down

0 comments on commit 9a1fda0

Please sign in to comment.