From adc910c2554a65cbe20ec48601451641b8e5dd53 Mon Sep 17 00:00:00 2001 From: kklein Date: Thu, 5 Sep 2024 18:20:50 +0200 Subject: [PATCH] Appease mypy. --- metalearners/_utils.py | 18 +++++++++++++++++- metalearners/data_generation.py | 9 +++++++-- metalearners/drlearner.py | 3 ++- metalearners/metalearner.py | 2 +- metalearners/rlearner.py | 10 +++++++--- metalearners/tlearner.py | 2 +- metalearners/xlearner.py | 20 +++++++++++--------- 7 files changed, 46 insertions(+), 18 deletions(-) diff --git a/metalearners/_utils.py b/metalearners/_utils.py index 87143121..65c6ef62 100644 --- a/metalearners/_utils.py +++ b/metalearners/_utils.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd +import polars as pl import scipy from sklearn.base import check_array, check_X_y, is_classifier, is_regressor from sklearn.ensemble import ( @@ -32,8 +33,15 @@ def safe_len(X: Matrix) -> int: return len(X) +def copy_matrix(matrix: Matrix) -> Matrix: + """Make a copy of a matrix.""" + if isinstance(matrix, pl.DataFrame): + return matrix.clone() + return matrix.copy() + + def index_matrix(matrix: Matrix, rows: Vector) -> Matrix: - """Subselect certain rows from a matrix.""" + """Subselect certain ows from a matrix.""" if isinstance(rows, pd.Series): rows = rows.to_numpy() if isinstance(matrix, pd.DataFrame): @@ -60,6 +68,14 @@ def are_pd_indices_equal(*args: pd.DataFrame | pd.Series) -> bool: return True +def to_np(data: Vector | Matrix) -> np.ndarray: + if isinstance(data, np.ndarray): + return data + if hasattr(data, "to_numpy"): + return data.to_numpy() + return np.array(data) + + def is_pd_df_or_series(arg) -> bool: return isinstance(arg, pd.DataFrame) or isinstance(arg, pd.Series) diff --git a/metalearners/data_generation.py b/metalearners/data_generation.py index 2e5d14ad..015b508d 100644 --- a/metalearners/data_generation.py +++ b/metalearners/data_generation.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +import polars as pl from scipy.stats import wishart from metalearners._typing import Matrix, Vector @@ -12,6 +13,7 @@ check_probability, check_propensity_score, convert_and_pad_propensity_score, + copy_matrix, default_rng, get_n_variants, sigmoid, @@ -239,8 +241,11 @@ def insert_missing( check_probability(missing_probability, zero_included=True) missing_mask = rng.binomial(1, p=missing_probability, size=X.shape).astype("bool") - masked = X.copy() - masked[missing_mask] = np.nan + masked = copy_matrix(X) + if isinstance(masked, pl.DataFrame): + raise ValueError() + else: + masked[missing_mask] = np.nan return masked diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py index f2c67e15..56916471 100644 --- a/metalearners/drlearner.py +++ b/metalearners/drlearner.py @@ -29,6 +29,7 @@ index_matrix, infer_input_dict, safe_len, + to_np, validate_valid_treatment_variant_not_control, warning_experimental_feature, ) @@ -416,7 +417,7 @@ def _pseudo_outcome( y0_estimate = y0_estimate[:, 0] y1_estimate = y1_estimate[:, 0] - pseudo_outcome = ( + pseudo_outcome = to_np( ( (y - y1_estimate) / clip_element_absolute_value_to_epsilon( diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index 4024341e..fe720e2b 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -1336,7 +1336,7 @@ def __init__( n_folds=n_folds, random_state=random_state, ) - self._treatment_variants_mask: list[np.ndarray] | None = None + self._treatment_variants_mask: list[Vector] | None = None def predict_conditional_average_outcomes( self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL diff --git a/metalearners/rlearner.py b/metalearners/rlearner.py index 95cc237e..93bfcd97 100644 --- a/metalearners/rlearner.py +++ b/metalearners/rlearner.py @@ -19,8 +19,10 @@ get_predict, get_predict_proba, index_matrix, + index_vector, infer_input_dict, safe_len, + to_np, validate_all_vectors_same_index, validate_valid_treatment_variant_not_control, warning_experimental_feature, @@ -516,14 +518,16 @@ def _pseudo_outcome_and_weights( y_residuals = y[mask] - y_estimates - w_binarized = w[mask] == treatment_variant + w_binarized = to_np(index_vector(w, mask) == treatment_variant) w_residuals = w_binarized - w_estimates_binarized w_residuals_padded = clip_element_absolute_value_to_epsilon( w_residuals, epsilon ) - pseudo_outcomes = y_residuals / w_residuals_padded - weights = np.square(w_residuals) + pseudo_outcomes = to_np(y_residuals / w_residuals_padded) + # In principle np.square could also return a scalar. + # We ensure that the type is np.ndarray. + weights = to_np(np.square(w_residuals)) return pseudo_outcomes, weights diff --git a/metalearners/tlearner.py b/metalearners/tlearner.py index 946d833f..8a947eb5 100644 --- a/metalearners/tlearner.py +++ b/metalearners/tlearner.py @@ -71,7 +71,7 @@ def fit_all_nuisance( self._validate_treatment(w) self._validate_outcome(y, w) - self._treatment_variants_mask = [] + self._treatment_variants_mask: list[Vector] = [] for v in range(self.n_variants): self._treatment_variants_mask.append(w == v) diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index f9ff5497..9e0b50d3 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -16,9 +16,11 @@ get_predict, get_predict_proba, index_matrix, + index_vector, infer_input_dict, infer_probabilities_output, safe_len, + to_np, validate_valid_treatment_variant_not_control, warning_experimental_feature, ) @@ -96,7 +98,7 @@ def fit_all_nuisance( self._validate_treatment(w) self._validate_outcome(y, w) - self._treatment_variants_mask = [] + self._treatment_variants_mask: list[Vector] = [] qualified_fit_params = self._qualified_fit_params(fit_params) @@ -421,12 +423,10 @@ def _pseudo_outcome( treatment_indices = w == treatment_variant control_indices = w == 0 - treatment_outcome = index_matrix( - conditional_average_outcome_estimates, control_indices - )[:, treatment_variant] - control_outcome = index_matrix( - conditional_average_outcome_estimates, treatment_indices - )[:, 0] + treatment_outcome = conditional_average_outcome_estimates[ + control_indices, treatment_variant + ] + control_outcome = conditional_average_outcome_estimates[treatment_indices, 0] if self.is_classification: # Get the probability of positive class, multiclass is currently not supported. @@ -436,8 +436,10 @@ def _pseudo_outcome( control_outcome = control_outcome[:, 0] treatment_outcome = treatment_outcome[:, 0] - imputed_te_treatment = y[treatment_indices] - control_outcome - imputed_te_control = treatment_outcome - y[control_indices] + imputed_te_treatment = ( + to_np(index_vector(y, treatment_indices)) - control_outcome + ) + imputed_te_control = treatment_outcome - to_np(index_vector(y, control_indices)) return imputed_te_control, imputed_te_treatment