Skip to content

Commit

Permalink
add support for csr matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvalal committed Aug 16, 2024
1 parent fccde2c commit c1351d0
Show file tree
Hide file tree
Showing 12 changed files with 596 additions and 24 deletions.
550 changes: 550 additions & 0 deletions docs/examples/example_sparse_inputs.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/examples/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ Examples
Estimating CATEs for survival analysis <example_survival.ipynb>
What if I know the propensity score? <example_propensity.ipynb>
Converting a MetaLearner to ONNX <example_onnx.ipynb>
Using Sparse Covariate Matrices <example_sparse_inputs.ipynb>
3 changes: 2 additions & 1 deletion metalearners/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
import scipy.sparse as sps

PredictMethod = Literal["predict", "predict_proba"]

Expand All @@ -21,7 +22,7 @@

# ruff is not happy about the usage of Union.
Vector = Union[pd.Series, np.ndarray] # noqa
Matrix = Union[pd.DataFrame, np.ndarray] # noqa
Matrix = Union[pd.DataFrame, np.ndarray, sps.csr_matrix] # noqa


class _ScikitModel(Protocol):
Expand Down
7 changes: 7 additions & 0 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
import scipy
from sklearn.base import check_array, check_X_y, is_classifier, is_regressor
from sklearn.ensemble import (
HistGradientBoostingClassifier,
Expand All @@ -24,6 +25,12 @@
default_rng = np.random.default_rng()


def safe_len(X) -> int:
    """Return the number of observations (rows) in ``X``.

    Sparse matrices cannot be passed to :func:`len` (SciPy raises a
    ``TypeError`` because the length of a sparse matrix is ambiguous), so the
    row count is read from ``X.shape[0]`` for sparse inputs. Every other input
    is delegated to the builtin :func:`len`.

    Parameters
    ----------
    X
        A sparse matrix or any object supporting ``len()``.

    Returns
    -------
    int
        ``X.shape[0]`` if ``X`` is sparse, ``len(X)`` otherwise.
    """
    # Import the submodule explicitly: a bare ``import scipy`` only exposes
    # ``scipy.sparse`` when SciPy lazily loads submodules or when some other
    # package has already imported it. This keeps the helper self-contained.
    import scipy.sparse

    if scipy.sparse.issparse(X):
        return X.shape[0]
    return len(X)


def index_matrix(matrix: Matrix, rows: Vector) -> Matrix:
"""Subselect certain rows from a matrix."""
if isinstance(rows, pd.Series):
Expand Down
19 changes: 12 additions & 7 deletions metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from typing_extensions import Self

from metalearners._typing import Matrix, OosMethod, PredictMethod, SplitIndices, Vector
from metalearners._utils import _ScikitModel, index_matrix, validate_number_positive
from metalearners._utils import (
_ScikitModel,
index_matrix,
safe_len,
validate_number_positive,
)

OVERALL: OosMethod = "overall"
MEDIAN: OosMethod = "median"
Expand Down Expand Up @@ -157,7 +162,7 @@ def fit(
(train_indices, test_indices) tuples indicating how to split the data at hand
into train and test/estimation sets for different folds.
"""
_validate_data_match_prior_split(len(X), self._test_indices)
_validate_data_match_prior_split(safe_len(X), self._test_indices)

if fit_params is None:
fit_params = dict()
Expand Down Expand Up @@ -215,13 +220,13 @@ def _n_outputs(self, method: PredictMethod) -> int:
def _predict_all(self, X: Matrix, method: PredictMethod) -> np.ndarray:
n_outputs = self._n_outputs(method)
predictions = self._initialize_prediction_tensor(
n_observations=len(X),
n_observations=safe_len(X),
n_outputs=n_outputs,
n_folds=self.n_folds,
)
for i, estimator in enumerate(self._estimators):
predictions[:, :, i] = np.reshape(
getattr(estimator, method)(X), (len(X), n_outputs)
getattr(estimator, method)(X), (safe_len(X), n_outputs)
)
if n_outputs == 1:
return predictions[:, 0, :]
Expand All @@ -242,15 +247,15 @@ def _predict_in_sample(
) -> np.ndarray:
if not self._test_indices:
raise ValueError()
if len(X) != sum(len(fold) for fold in self._test_indices):
if safe_len(X) != sum(len(fold) for fold in self._test_indices):
raise ValueError(
"Trying to predict in-sample on data that is unlike data encountered in training. "
f"Training data included {sum(len(fold) for fold in self._test_indices)} "
f"observations while prediction data includes {len(X)} observations."
f"observations while prediction data includes {safe_len(X)} observations."
)
n_outputs = self._n_outputs(method)
predictions = self._initialize_prediction_tensor(
n_observations=len(X),
n_observations=safe_len(X),
n_outputs=n_outputs,
n_folds=1,
)
Expand Down
7 changes: 4 additions & 3 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
get_predict_proba,
index_matrix,
infer_input_dict,
safe_len,
validate_valid_treatment_variant_not_control,
warning_experimental_feature,
)
Expand Down Expand Up @@ -253,7 +254,7 @@ def predict(
oos_method: OosMethod = OVERALL,
) -> np.ndarray:
n_outputs = 2 if self.is_classification else 1
estimates = np.zeros((len(X), self.n_variants - 1, n_outputs))
estimates = np.zeros((safe_len(X), self.n_variants - 1, n_outputs))
for treatment_variant in range(1, self.n_variants):
estimates_variant = self.predict_treatment(
X,
Expand Down Expand Up @@ -365,7 +366,7 @@ def average_treatment_effect(
raise ValueError(
"The nuisance models need to be fitted before computing the treatment effect."
)
gamma_matrix = np.zeros((len(X), self.n_variants - 1))
gamma_matrix = np.zeros((safe_len(X), self.n_variants - 1))
for treatment_variant in range(1, self.n_variants):
gamma_matrix[:, treatment_variant - 1] = self._pseudo_outcome(
X=X,
Expand All @@ -375,7 +376,7 @@ def average_treatment_effect(
is_oos=is_oos,
)
treatment_effect = gamma_matrix.mean(axis=0)
standard_error = gamma_matrix.std(axis=0) / np.sqrt(len(X))
standard_error = gamma_matrix.std(axis=0) / np.sqrt(safe_len(X))
return treatment_effect, standard_error

def _pseudo_outcome(
Expand Down
4 changes: 2 additions & 2 deletions metalearners/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import shap

from metalearners._typing import Matrix, _ScikitModel
from metalearners._utils import simplify_output_2d
from metalearners._utils import safe_len, simplify_output_2d
from metalearners.metalearner import Params


Expand Down Expand Up @@ -59,7 +59,7 @@ def from_estimates(
The ``cate_estimates`` should be the raw outcome of a MetaLearner with 3 dimensions
and should not be simplified.
"""
if len(X) != len(cate_estimates) or len(X) == 0:
if safe_len(X) != len(cate_estimates) or safe_len(X) == 0:
raise ValueError(
"X and cate_estimates should contain the same number of observations "
"and not be empty."
Expand Down
5 changes: 3 additions & 2 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ONNX_PROBABILITIES_OUTPUTS,
default_metric,
index_matrix,
safe_len,
validate_model_and_predict_method,
validate_number_positive,
)
Expand Down Expand Up @@ -120,7 +121,7 @@ def _filter_x_columns(X: Matrix, feature_set: Features) -> Matrix:
if feature_set is None:
X_filtered = X
elif len(feature_set) == 0:
X_filtered = np.ones((len(X), 1))
X_filtered = np.ones((safe_len(X), 1))
else:
if isinstance(X, pd.DataFrame):
X_filtered = X[list(feature_set)]
Expand Down Expand Up @@ -1347,7 +1348,7 @@ def predict_conditional_average_outcomes(
"typically set during fitting, is None."
)
# TODO: Consider multiprocessing
n_obs = len(X)
n_obs = safe_len(X)
nuisance_tensors = self._nuisance_tensors(n_obs)
conditional_average_outcomes_list = nuisance_tensors[VARIANT_OUTCOME_MODEL]

Expand Down
9 changes: 5 additions & 4 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
get_predict_proba,
index_matrix,
infer_input_dict,
safe_len,
validate_all_vectors_same_index,
validate_valid_treatment_variant_not_control,
warning_experimental_feature,
Expand Down Expand Up @@ -277,7 +278,7 @@ def predict(
oos_method: OosMethod = OVERALL,
) -> np.ndarray:
n_outputs = 2 if self.is_classification else 1
tau_hat = np.zeros((len(X), self.n_variants - 1, n_outputs))
tau_hat = np.zeros((safe_len(X), self.n_variants - 1, n_outputs))

if is_oos:

Expand All @@ -298,7 +299,7 @@ def predict(
variant_estimates = np.stack(
[-variant_estimates, variant_estimates], axis=-1
)
variant_estimates = variant_estimates.reshape(len(X), n_outputs)
variant_estimates = variant_estimates.reshape(safe_len(X), n_outputs)
tau_hat[:, treatment_variant - 1, :] = variant_estimates

return tau_hat
Expand Down Expand Up @@ -486,7 +487,7 @@ def _pseudo_outcome_and_weights(
constant ``epsilon`` to the denominator in order to avoid numerical problems.
"""
if mask is None:
mask = np.ones(len(X), dtype=bool)
mask = np.ones(safe_len(X), dtype=bool)

validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants)

Expand Down Expand Up @@ -560,7 +561,7 @@ def predict_conditional_average_outcomes(
where :math:`K` is the number of treatment variants.
"""
n_obs = len(X)
n_obs = safe_len(X)

cate_estimates = self.predict(
X=X,
Expand Down
3 changes: 2 additions & 1 deletion metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from metalearners._utils import (
convert_treatment,
get_one,
safe_len,
supports_categoricals,
)
from metalearners.cross_fit_estimator import OVERALL, CrossFitEstimator
Expand Down Expand Up @@ -231,7 +232,7 @@ def evaluate(
def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
n_obs = len(X)
n_obs = safe_len(X)
conditional_average_outcomes_list = []

for treatment_variant in range(self.n_variants):
Expand Down
5 changes: 4 additions & 1 deletion metalearners/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import Self

from metalearners._typing import Matrix, Vector
from metalearners._utils import safe_len
from metalearners.drlearner import DRLearner
from metalearners.metalearner import MetaLearner
from metalearners.rlearner import RLearner
Expand Down Expand Up @@ -104,4 +105,6 @@ def predict(self, X: Matrix) -> np.ndarray[Any, Any]:
return np.argmax(self.predict_proba(X), axis=1)

def predict_proba(self, X: pd.DataFrame) -> np.ndarray[Any, Any]:
return np.full((len(X), 2), [1 - self.propensity_score, self.propensity_score])
return np.full(
(safe_len(X), 2), [1 - self.propensity_score, self.propensity_score]
)
7 changes: 4 additions & 3 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
index_matrix,
infer_input_dict,
infer_probabilities_output,
safe_len,
validate_valid_treatment_variant_not_control,
warning_experimental_feature,
)
Expand Down Expand Up @@ -231,7 +232,7 @@ def predict(
"typically set during fitting, is None."
)
n_outputs = 2 if self.is_classification else 1
tau_hat = np.zeros((len(X), self.n_variants - 1, n_outputs))
tau_hat = np.zeros((safe_len(X), self.n_variants - 1, n_outputs))
# Propensity score model is always a classifier so we can't use MEDIAN
propensity_score_oos = OVERALL if oos_method == MEDIAN else oos_method
propensity_score = self.predict_nuisance(
Expand Down Expand Up @@ -266,8 +267,8 @@ def predict(
oos_method=oos_method,
)
else:
tau_hat_treatment = np.zeros(len(X))
tau_hat_control = np.zeros(len(X))
tau_hat_treatment = np.zeros(safe_len(X))
tau_hat_control = np.zeros(safe_len(X))

tau_hat_treatment[non_treatment_variant_indices] = (
self.predict_treatment(
Expand Down

0 comments on commit c1351d0

Please sign in to comment.