Skip to content

Commit

Permalink
Help mypy with to_numpy.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Sep 5, 2024
1 parent 95b5058 commit 1ef9dfd
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,15 @@ def convert_and_pad_propensity_score(
propensity score per variant. The expansion assumes that the provided scores are
those for the second variant.
"""
if isinstance(propensity_scores, pd.Series) or isinstance(
propensity_scores, pd.DataFrame
):
propensity_scores = propensity_scores.to_numpy()
p_is_1d = len(propensity_scores.shape) == 1 or propensity_scores.shape[1] == 1
if isinstance(propensity_scores, np.ndarray):
np_propensity_scores = propensity_scores
else:
np_propensity_scores = propensity_scores.to_numpy()

p_is_1d = len(np_propensity_scores.shape) == 1 or np_propensity_scores.shape[1] == 1
if n_variants == 2 and p_is_1d:
propensity_scores = np.c_[1 - propensity_scores, propensity_scores]
return propensity_scores
np_propensity_scores = np.c_[1 - np_propensity_scores, np_propensity_scores]
return np_propensity_scores


def get_n_variants(propensity_scores: Matrix) -> int:
Expand Down

0 comments on commit 1ef9dfd

Please sign in to comment.