Skip to content

Commit

Permalink
Add type hints for cv-split-related attributes.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Aug 16, 2024
1 parent 6803096 commit b005eb7
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
from joblib import Parallel, delayed
from typing_extensions import Self

from metalearners._typing import Matrix, OosMethod, Scoring, Vector, _ScikitModel
from metalearners._typing import (
Matrix,
OosMethod,
Scoring,
SplitIndices,
Vector,
_ScikitModel,
)
from metalearners._utils import (
check_spox_installed,
copydoc,
Expand Down Expand Up @@ -105,8 +112,8 @@ def fit_all_nuisance(
"The X-Learner does not support synchronize_cross_fitting=False."
)

self._cv_split_indices = self._split(X)
self._treatment_cv_split_indices = {}
self._cv_split_indices: SplitIndices = self._split(X)
self._treatment_cv_split_indices: dict[int, SplitIndices] = {}

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
Expand Down

0 comments on commit b005eb7

Please sign in to comment.