From 440dafb270441fba81fc8d063997732804b55240 Mon Sep 17 00:00:00 2001 From: jeandut Date: Tue, 13 Aug 2024 19:55:58 +0200 Subject: [PATCH] Jean/fed kaplan (#44) * general architecture of fedkaplan * respecting naming conventions * refactoring preprocessing * passing in my head but not tested * adding test for KM utils * add credit * everything works in my head * hacking * some refactoring * fixing bug * fixing various stuff * fixing stuff * everything passing * test passing * trying fixing tests * linting * linting * linting * linting * linting * linting * linting fedkaplan * trying to finally fix linting * linting * fixing substra stuff in FedKM * test almost working wo weights * linting * linting * now tests are not passing only because grid is not the same * tests passing * weights working * removing useless comments * tests should be passing * removing forgoteen brakpoint --- fedeca/algorithms/torch_webdisco_algo.py | 168 ++------- fedeca/strategies/fed_kaplan.py | 257 +++++++++++++ fedeca/strategies/fed_smd.py | 296 +++++++++++++++ fedeca/strategies/webdisco.py | 26 +- fedeca/tests/strategies/test_km.py | 145 ++++++++ fedeca/tests/strategies/test_smd.py | 152 ++++++++ fedeca/tests/test_km.py | 43 +++ fedeca/utils/data_utils.py | 1 - fedeca/utils/moments_utils.py | 41 ++- fedeca/utils/substrafl_utils.py | 7 + fedeca/utils/survival_utils.py | 436 +++++++++++++++++++++++ 11 files changed, 1403 insertions(+), 169 deletions(-) create mode 100644 fedeca/strategies/fed_kaplan.py create mode 100644 fedeca/strategies/fed_smd.py create mode 100644 fedeca/tests/strategies/test_km.py create mode 100644 fedeca/tests/strategies/test_smd.py create mode 100644 fedeca/tests/test_km.py diff --git a/fedeca/algorithms/torch_webdisco_algo.py b/fedeca/algorithms/torch_webdisco_algo.py index a329920f..5d34f2b1 100644 --- a/fedeca/algorithms/torch_webdisco_algo.py +++ b/fedeca/algorithms/torch_webdisco_algo.py @@ -1,7 +1,6 @@ """Implement webdisco algorithm with Torch.""" import copy from copy import deepcopy -from math import sqrt from pathlib import Path from typing import Any, List, Optional, Union @@ -11,7 +10,6 @@ from autograd import elementwise_grad from autograd import numpy as anp from lifelines.utils import StepSizer -from pandas.api.types import is_numeric_dtype from scipy.linalg import norm from scipy.linalg import solve as spsolve from substrafl.algorithms.pytorch import weight_manager @@ -21,7 +19,11 @@ from fedeca.schemas import WebDiscoAveragedStates, WebDiscoSharedState from fedeca.utils.moments_utils import compute_uncentered_moment -from fedeca.utils.survival_utils import MockStepSizer +from fedeca.utils.survival_utils import ( + MockStepSizer, + build_X_y_function, + compute_X_y_and_propensity_weights_function, +) class TorchWebDiscoAlgo(TorchAlgo): @@ -597,124 +599,6 @@ def summary(self): summary = super().summary() return summary - def build_X_y(self, data_from_opener, shared_state={}): - """Build appropriate X and y times from output of opener. - - This function 1. uses the event column to inject the censorship - information present in the duration column (given in absolute values) - in the form of a negative sign. - 2. Drop every covariate except treatment if self.strategy == "iptw". - 3. Standardize the data if self.standardize_data AND if it receives - an outmodel. - 4. Return the (unstandardized) input to the propensity model Xprop if - necessary as well as the treated column to be able to compute the - propensity weights. 
- - Parameters - ---------- - data_from_opener : pd.DataFrame - The output of the opener - shared_state : dict, optional - Outmodel containing global means and stds. - by default {} - - Returns - ------- - tuple - standardized X, signed times, treatment column and unstandardized - propensity model input - """ - # We need y to be in the format (2*event-1)*duration - data_from_opener["time_multiplier"] = [ - 2.0 * e - 1.0 for e in data_from_opener[self._event_col].tolist() - ] - # No funny business irrespective of the convention used - y = ( - np.abs(data_from_opener[self._duration_col]) - * data_from_opener["time_multiplier"] - ) - y = y.to_numpy().astype("float64") - data_from_opener = data_from_opener.drop(columns=["time_multiplier"]) - # dangerous but we need to do it - string_columns = [ - col - for col in data_from_opener.columns - if not (is_numeric_dtype(data_from_opener[col])) - ] - data_from_opener = data_from_opener.drop(columns=string_columns) - - # We drop the targets from X - columns_to_drop = self._target_cols - X = data_from_opener.drop(columns=columns_to_drop) - if self._propensity_model is not None: - assert self._treated_col is not None - if self._training_strategy == "iptw": - X = X.loc[:, [self._treated_col]] - elif self._training_strategy == "aiptw": - if len(self._cox_fit_cols) > 0: - X = X.loc[:, [self._treated_col] + self._cox_fit_cols] - else: - pass - else: - assert self._training_strategy == "webdisco" - if len(self._cox_fit_cols) > 0: - X = X.loc[:, self._cox_fit_cols] - else: - pass - - # If X is to be standardized we do it - if self._standardize_data: - if shared_state: - # Careful this shouldn't happen apart from the predict - means = shared_state["global_uncentered_moment_1"] - vars = shared_state["global_centered_moment_2"] - # Careful we need to match pandas and use unbiased estimator - bias_correction = (shared_state["total_n_samples"]) / float( - shared_state["total_n_samples"] - 1 - ) - self.global_moments = { - "means": means, - "vars": vars, - "bias_correction": bias_correction, - } - stds = vars.transform(lambda x: sqrt(x * bias_correction + self._tol)) - X = X.sub(means) - X = X.div(stds) - else: - X = X.sub(self.global_moments["means"]) - stds = self.global_moments["vars"].transform( - lambda x: sqrt( - x * self.global_moments["bias_correction"] + self._tol - ) - ) - X = X.div(stds) - - X = X.to_numpy().astype("float64") - - # If we have a propensity model we need to build X without the targets AND the - # treated column - if self._propensity_model is not None: - # We do not normalize the data for the propensity model !!! - Xprop = data_from_opener.drop(columns=columns_to_drop + [self._treated_col]) - if self._propensity_fit_cols is not None: - Xprop = Xprop[self._propensity_fit_cols] - Xprop = Xprop.to_numpy().astype("float64") - else: - Xprop = None - - # If WebDisco is used without propensity treated column does not exist - if self._treated_col is not None: - treated = ( - data_from_opener[self._treated_col] - .to_numpy() - .astype("float64") - .reshape((-1, 1)) - ) - else: - treated = None - - return (X, y, treated, Xprop) - def compute_X_y_and_propensity_weights(self, data_from_opener, shared_state): """Build appropriate X, y and weights from raw output of opener. 
@@ -731,26 +615,26 @@ def compute_X_y_and_propensity_weights(self, data_from_opener, shared_state): Returns ------- tuple - _description_ + X input to the Cox model, y target of Cox model, weights propensity weights """ - X, y, treated, Xprop = self.build_X_y(data_from_opener, shared_state) - if self._propensity_model is not None: - assert ( - treated is not None - ), f"""If you are using a propensity model the {self._treated_col} (Treated) - column should be available""" - assert np.all( - np.in1d(np.unique(treated.astype("uint8"))[0], [0, 1]) - ), "The treated column should have all its values in set([0, 1])" - Xprop = torch.from_numpy(Xprop) - with torch.no_grad(): - propensity_scores = self._propensity_model(Xprop) - - propensity_scores = propensity_scores.detach().numpy() - # We robustify the division - weights = treated * 1.0 / np.maximum(propensity_scores, self._tol) + ( - 1 - treated - ) * 1.0 / (np.maximum(1.0 - propensity_scores, self._tol)) - else: - weights = np.ones((X.shape[0], 1)) + X, y, treated, Xprop, self.global_moments = build_X_y_function( + data_from_opener, + self._event_col, + self._duration_col, + self._treated_col, + self._target_cols, + self._standardize_data, + self._propensity_model, + self._cox_fit_cols, + self._propensity_fit_cols, + self._tol, + self._training_strategy, + shared_state=shared_state, + global_moments={} + if not hasattr(self, "global_moments") + else self.global_moments, + ) + X, y, weights = compute_X_y_and_propensity_weights_function( + X, y, treated, Xprop, self._propensity_model, self._tol + ) return X, y, weights diff --git a/fedeca/strategies/fed_kaplan.py b/fedeca/strategies/fed_kaplan.py new file mode 100644 index 00000000..4c02ed90 --- /dev/null +++ b/fedeca/strategies/fed_kaplan.py @@ -0,0 +1,257 @@ +"""Compute federated Kaplan-Meier estimates.""" +import pickle as pk +from pathlib import Path +from typing import Any, List, Optional, Union + +import numpy as np +import pandas as pd +from substrafl import ComputePlanBuilder +from substrafl.evaluation_strategy import EvaluationStrategy +from substrafl.nodes import AggregationNodeProtocol, TrainDataNodeProtocol +from substrafl.remote import remote, remote_data +from torch import nn + +from fedeca.utils.survival_utils import ( + aggregate_events_statistics, + build_X_y_function, + compute_events_statistics, + compute_X_y_and_propensity_weights_function, + km_curve, +) + + +class FedKaplan(ComputePlanBuilder): + """Instantiate a federated version of Kaplan Meier estimates. + + Parameters + ---------- + ComputePlanBuilder : _type_ + _description_ + """ + + def __init__( + self, + duration_col: str, + event_col: str, + treated_col: str, + client_identifier: str, + propensity_model: Union[None, nn.Module] = None, + tol: float = 1e-16, + ): + """Implement a federated version of Kaplan Meier estimates. + + This code is an adaptation of a previous implementation by Constance Beguier. 
+ + Parameters + ---------- + treated_col : Union[None, str], optional + The column describing the treatment, by default None + propensity_model : Union[None, nn.Module], optional + _description_, by default None + """ + super().__init__() + assert not ( + (treated_col is None) and (propensity_model is not None) + ), "if propensity model is provided, treatment_col should be provided as well" + self._duration_col = duration_col + self._event_col = event_col + self._treated_col = treated_col + self._propensity_model = propensity_model + self._client_identifier = client_identifier + self._target_cols = [self._duration_col, self._event_col] + self._tol = tol + self.statistics_result = None + self.kwargs["duration_col"] = duration_col + self.kwargs["event_col"] = event_col + self.kwargs["treated_col"] = treated_col + self.kwargs["propensity_model"] = propensity_model + self.kwargs["client_identifier"] = client_identifier + self.kwargs["tol"] = tol + + def build_compute_plan( + self, + train_data_nodes: Optional[List[TrainDataNodeProtocol]], + aggregation_node: Optional[List[AggregationNodeProtocol]], + evaluation_strategy: Optional[EvaluationStrategy], + num_rounds: Optional[int], + clean_models: Optional[bool] = True, + ): + """Build the computation plan. + + Parameters + ---------- + train_data_nodes : Optional[List[TrainDataNodeProtocol]] + _description_ + aggregation_node : Optional[List[AggregationNodeProtocol]] + _description_ + evaluation_strategy : Optional[EvaluationStrategy] + _description_ + num_rounds : Optional[int] + _description_ + clean_models : Optional[bool], optional + _description_, by default True + """ + del num_rounds + del evaluation_strategy + del clean_models + shared_states = [] + for node in train_data_nodes: + # define composite tasks (do not submit yet) + # for each composite task give description of + # Algo instead of a key for an algo + _, next_shared_state = node.update_states( + self.compute_events_statistics( + node.data_sample_keys, + shared_state=None, + _algo_name="Compute Events Statistics", + ), + local_state=None, + round_idx=0, + authorized_ids=set([node.organization_id]), + aggregation_id=aggregation_node.organization_id, + clean_models=False, + ) + # keep the states in a list: one/organization + shared_states.append(next_shared_state) + + aggregation_node.update_states( + self.compute_agg_km_curve( + shared_states=shared_states, + _algo_name="Aggregate Events Statistics", + ), + round_idx=0, + authorized_ids=set( + train_data_node.organization_id for train_data_node in train_data_nodes + ), + clean_models=False, + ) + + @remote_data + def compute_events_statistics( + self, + data_from_opener: pd.DataFrame, + shared_state=None, + ): + """Compute events statistics for a subset of data. 
+ + Parameters + ---------- + datasamples : _type_ + _description_ + shared_state : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ + del shared_state + # we only use survival times + propensity_cols = [ + col + for col in data_from_opener.columns + if col + not in [ + self._duration_col, + self._event_col, + self._treated_col, + self._client_identifier, + ] + ] + + X, y, treated, Xprop, _ = build_X_y_function( + data_from_opener, + self._event_col, + self._duration_col, + self._treated_col, + self._target_cols, + False, + self._propensity_model, + None, + propensity_cols, + self._tol, + "iptw", + ) + # X contains only the treatment column (strategy == iptw) + + X, _, weights = compute_X_y_and_propensity_weights_function( + X, y, treated, Xprop, self._propensity_model, self._tol + ) + + # TODO actually use weights + # del weights + # retrieve times and events + times = np.abs(y) + events = y >= 0 + assert np.allclose(events, data_from_opener[self._event_col].values) + treated = treated.astype(bool).flatten() + + return { + "treated": compute_events_statistics( + times[treated], events[treated], weights[treated] + ), + "untreated": compute_events_statistics( + times[~treated], events[~treated], weights[~treated] + ), + } + + @remote + def compute_agg_km_curve(self, shared_states): + """Compute the aggregated Kaplan-Meier curve. + + Parameters + ---------- + shared_states : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ + treated_untreated_tnd_agg = { + "treated": aggregate_events_statistics( + [sh["treated"] for sh in shared_states] + ), + "untreated": aggregate_events_statistics( + [sh["untreated"] for sh in shared_states] + ), + } + + return { + "treated": km_curve(*treated_untreated_tnd_agg["treated"]), + "untreated": km_curve(*treated_untreated_tnd_agg["untreated"]), + } + + def save_local_state(self, path: Path): + """Save the object on the disk. + + Should be used only by the backend, to define the local_state. + + Parameters + ---------- + path : Path + Where to save the object. + """ + with open(path, "wb") as file: + pk.dump(self.statistics_result, file) + + def load_local_state(self, path: Path) -> Any: + """Load the object from the disk. + + Should be used only by the backend, to define the local_state. + + Parameters + ---------- + path : Path + Where to find the object. + + Returns + ------- + Any + Previously saved instance. + """ + with open(path, "rb") as file: + self.statistics_result = pk.load(file) + return self diff --git a/fedeca/strategies/fed_smd.py b/fedeca/strategies/fed_smd.py new file mode 100644 index 00000000..8a26aef1 --- /dev/null +++ b/fedeca/strategies/fed_smd.py @@ -0,0 +1,296 @@ +"""Compute SMD for weighted and unweighted data in FL.""" +import pickle as pk +from pathlib import Path +from typing import Any, List, Optional + +import numpy as np +import pandas as pd +import torch +from substrafl import ComputePlanBuilder +from substrafl.evaluation_strategy import EvaluationStrategy +from substrafl.nodes import AggregationNodeProtocol, TrainDataNodeProtocol +from substrafl.remote import remote, remote_data + +from fedeca.utils.moments_utils import compute_global_moments, compute_uncentered_moment +from fedeca.utils.survival_utils import ( + build_X_y_function, + compute_X_y_and_propensity_weights_function, +) + + +class FedSMD(ComputePlanBuilder): + """Compute SMD for weighted and unweighted data in FL. + + Parameters + ---------- + ComputePlanBuilder : ComputePlanBuilder + Analytics strategy. 
+ """ + + def __init__( + self, + duration_col: str, + event_col: str, + treated_col: str, + propensity_model: torch.nn.Module, + client_identifier: str, + tol: float = 1e-16, + ): + """Initialize FedSMD strategy. + + This class computes weighted SMD. + + Parameters + ---------- + treated_col : Union[None, str], optional + The column describing the treatment, by default None + duration_col : str + The column describing the duration of the event, by default None + + propensity_model : Union[None, nn.Module], optional + _description_, by default None + """ + super().__init__() + + self._duration_col = duration_col + self._event_col = event_col + self._treated_col = treated_col + self._target_cols = [self._duration_col, self._event_col] + self._propensity_model = propensity_model + self._propensity_fit_cols = None + self._client_identifier = client_identifier + self._tol = tol + self.statistics_result = None + + # Populating kwargs for reinstatiation + self.kwargs["duration_col"] = duration_col + self.kwargs["event_col"] = event_col + self.kwargs["treated_col"] = treated_col + self.kwargs["propensity_model"] = propensity_model + self.kwargs["client_identifier"] = client_identifier + self.kwargs["tol"] = tol + + def build_compute_plan( + self, + train_data_nodes: Optional[List[TrainDataNodeProtocol]], + aggregation_node: Optional[List[AggregationNodeProtocol]], + evaluation_strategy: Optional[EvaluationStrategy], + num_rounds: Optional[int], + clean_models: Optional[bool] = True, + ): + """Build the computation plan. + + Parameters + ---------- + train_data_nodes : Optional[List[TrainDataNodeProtocol]] + _description_ + aggregation_node : Optional[List[AggregationNodeProtocol]] + _description_ + evaluation_strategy : Optional[EvaluationStrategy] + _description_ + num_rounds : Optional[int] + _description_ + clean_models : Optional[bool], optional + _description_, by default True + """ + del num_rounds + del evaluation_strategy + del clean_models + shared_states = [] + for node in train_data_nodes: + # define composite tasks (do not submit yet) + # for each composite task give description of + # Algo instead of a key for an algo + _, next_shared_state = node.update_states( + self.compute_local_moments_per_group( + node.data_sample_keys, + shared_state=None, + _algo_name="Compute local moments per group", + ), + local_state=None, + round_idx=0, + authorized_ids=set([node.organization_id]), + aggregation_id=aggregation_node.organization_id, + clean_models=False, + ) + # keep the states in a list: one/organization + shared_states.append(next_shared_state) + + aggregation_node.update_states( + self.compute_smd( + shared_states=shared_states, + _algo_name="compute smd for weighted and unweighted data", + ), + round_idx=0, + authorized_ids=set( + train_data_node.organization_id for train_data_node in train_data_nodes + ), + clean_models=False, + ) + + @remote_data + def compute_local_moments_per_group( + self, + data_from_opener, + shared_state=None, + ): + """Compute events statistics for a subset of data. + + Parameters + ---------- + data_from_opener: pd.DataFrame + Data to compute statistics on. + shared_state: list[pd.DataFrame] + List of shared states. + + Returns + ------- + dict + Method output or placeholder thereof. 
+ """ + del shared_state + # we only use survival times + propensity_cols = [ + col + for col in data_from_opener.columns + if col + not in [ + self._duration_col, + self._event_col, + self._treated_col, + self._client_identifier, + ] + ] + + X, y, treated, Xprop, _ = build_X_y_function( + data_from_opener, + self._event_col, + self._duration_col, + self._treated_col, + self._target_cols, + False, + self._propensity_model, + None, + propensity_cols, + self._tol, + "iptw", + ) + # X contains only the treatment column (strategy == iptw) + + X, _, weights = compute_X_y_and_propensity_weights_function( + X, y, treated, Xprop, self._propensity_model, self._tol + ) + + # X contains only the treatment column (strategy == iptw) + # we use Xprop which contain all propensity columns, which + # are the only ones we are interested in + raw_data = pd.DataFrame(Xprop, columns=propensity_cols) + weighted_data = pd.DataFrame( + np.multiply(weights, Xprop), columns=propensity_cols + ) + results = {} + for treatment in [0, 1]: + mask_treatment = treated == treatment + res_name = "treated" if treatment else "untreated" + results[res_name] = {} + results[res_name]["weighted"] = { + f"moment{k}": compute_uncentered_moment( + weighted_data[mask_treatment], k + ) + for k in range(1, 3) + } + results[res_name]["unweighted"] = { + f"moment{k}": compute_uncentered_moment(raw_data[mask_treatment], k) + for k in range(1, 3) + } + results[res_name]["unweighted"]["n_samples"] = results[res_name][ + "weighted" + ]["n_samples"] = ( + raw_data[mask_treatment].select_dtypes(include=np.number).count() + ) + return results + + @remote + def compute_smd( + self, + shared_states, + ): + """Compute Kaplan-Meier curve for a subset of data. + + Parameters + ---------- + shared_states: list + List of shared states. + + Returns + ------- + dict + Method output or placeholder thereof. + """ + + def std_mean_differences(x, y): + """Compute standardized mean differences.""" + means_x = x["global_uncentered_moment_1"] + # we match nump std with 0 ddof contrary to standarization for Cox + stds_x = np.sqrt(x["global_centered_moment_2"] + self._tol) + + means_y = y["global_uncentered_moment_1"] + # we match nump std with 0 ddof contrary to standarization for Cox + stds_y = np.sqrt(y["global_centered_moment_2"] + self._tol) + + smd_df = means_x.subtract(means_y).div( + stds_x.pow(2).add(stds_y.pow(2)).div(2).pow(0.5) + ) + return smd_df + + treated_raw = compute_global_moments( + [shared_state["treated"]["unweighted"] for shared_state in shared_states] + ) + untreated_raw = compute_global_moments( + [shared_state["untreated"]["unweighted"] for shared_state in shared_states] + ) + + smd_raw = std_mean_differences(treated_raw, untreated_raw) + + treated_weighted = compute_global_moments( + [shared_state["treated"]["weighted"] for shared_state in shared_states] + ) + untreated_weighted = compute_global_moments( + [shared_state["untreated"]["weighted"] for shared_state in shared_states] + ) + + smd_weighted = std_mean_differences(treated_weighted, untreated_weighted) + + return {"weighted_smd": smd_weighted, "unweighted_smd": smd_raw} + + def save_local_state(self, path: Path): + """Save the object on the disk. + + Should be used only by the backend, to define the local_state. + + Parameters + ---------- + path : Path + Where to save the object. + """ + with open(path, "wb") as file: + pk.dump(self.statistics_result, file) + + def load_local_state(self, path: Path) -> Any: + """Load the object from the disk. 
+ + Should be used only by the backend, to define the local_state. + + Parameters + ---------- + path : Path + Where to find the object. + + Returns + ------- + Any + Previously saved instance. + """ + with open(path, "rb") as file: + self.statistics_result = pk.load(file) + return self diff --git a/fedeca/strategies/webdisco.py b/fedeca/strategies/webdisco.py index 4e630971..1e509e78 100644 --- a/fedeca/strategies/webdisco.py +++ b/fedeca/strategies/webdisco.py @@ -16,7 +16,7 @@ # from substrafl.schemas import WebDiscoSharedState from substrafl.strategies.strategy import Strategy -from fedeca.utils.moments_utils import aggregation_mean, compute_centered_moment +from fedeca.utils.moments_utils import compute_global_moments class StrategyName(str, Enum): @@ -633,29 +633,7 @@ def aggregate_moments(self, shared_states): Global results to be shared with train nodes via shared_state. """ # aggregate the moments. - - tot_uncentered_moments = [ - aggregation_mean( - [s[f"moment{k}"] for s in shared_states], - [s["n_samples"] for s in shared_states], - ) - for k in range(1, 2 + 1) - ] - n_samples = sum([s["n_samples"].iloc[0] for s in shared_states]) - results = { - f"global_centered_moment_{k}": compute_centered_moment( - tot_uncentered_moments[:k] - ) - for k in range(1, 2 + 1) - } - results.update( - { - f"global_uncentered_moment_{k+1}": moment - for k, moment in enumerate(tot_uncentered_moments) - } - ) - results.update({"total_n_samples": n_samples}) - return results + return compute_global_moments(shared_states) def perform_evaluation( self, diff --git a/fedeca/tests/strategies/test_km.py b/fedeca/tests/strategies/test_km.py new file mode 100644 index 00000000..075b801b --- /dev/null +++ b/fedeca/tests/strategies/test_km.py @@ -0,0 +1,145 @@ +"""Module testing substraFL moments strategy.""" +import os +import subprocess +from pathlib import Path + +import git +import numpy as np +import torch +from lifelines import KaplanMeierFitter as KMF +from substrafl.dependency import Dependency +from substrafl.experiment import execute_experiment +from substrafl.model_loading import download_aggregate_shared_state +from substrafl.nodes import AggregationNode +from torch import nn + +import fedeca +from fedeca.fedeca_core import LogisticRegressionTorch +from fedeca.strategies.fed_kaplan import FedKaplan +from fedeca.tests.common import TestTempDir +from fedeca.utils.data_utils import split_dataframe_across_clients +from fedeca.utils.survival_utils import CoxData + + +class TestKM(TestTempDir): + """Test substrafl computation of KM. 
+ + Tests the FL computation of KM is the same as in pandas-pooled version + """ + + def setUp(self, backend_type="subprocess", ndim=10) -> None: + """Set up the quantities needed for the tests.""" + # Let's generate 1000 data samples with 10 covariates + data = CoxData(seed=42, n_samples=1000, ndim=ndim, percent_ties=0.2) + self.df = data.generate_dataframe() + + # We remove the true propensity score + self.df = self.df.drop(columns=["propensity_scores"], axis=1) + + self.clients, self.train_data_nodes, _, _, _ = split_dataframe_across_clients( + self.df, + n_clients=4, + split_method="split_control_over_centers", + split_method_kwargs={"treatment_info": "treatment"}, + data_path=Path(self.test_dir) / "data", + backend_type=backend_type, + ) + kwargs_agg_node = {"organization_id": self.train_data_nodes[0].organization_id} + self.aggregation_node = AggregationNode(**kwargs_agg_node) + # Packaging the right dependencies + + fedeca_path = fedeca.__path__[0] + repo_folder = Path( + git.Repo(fedeca_path, search_parent_directories=True).working_dir + ).resolve() + wheel_folder = repo_folder / "temp" + os.makedirs(wheel_folder, exist_ok=True) + for stale_wheel in wheel_folder.glob("fedeca*.whl"): + stale_wheel.unlink() + process = subprocess.Popen( + f"python -m build --wheel --outdir {wheel_folder} {repo_folder}", + shell=True, + stdout=subprocess.PIPE, + ) + process.wait() + assert process.returncode == 0, "Failed to build the wheel" + self.wheel_path = next(wheel_folder.glob("fedeca*.whl")) + self.ds_client = self.clients[self.train_data_nodes[0].organization_id] + self.propensity_model = LogisticRegressionTorch(ndim=ndim) + + self.propensity_model.fc1.weight.data = nn.parameter.Parameter( + torch.randn( + size=self.propensity_model.fc1.weight.data.shape, dtype=torch.float64 + ) + ) + self.propensity_model.fc1.bias.data = nn.parameter.Parameter( + torch.randn( + size=self.propensity_model.fc1.bias.data.shape, dtype=torch.float64 + ) + ) + + def test_end_to_end(self): + """Compare a FL and pooled computation of Moments. + + The data are the tcga ones. + """ + # Get fl_results. 
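+        # The FL estimates produced by FedKaplan are compared against a pooled
+        # baseline: lifelines' KaplanMeierFitter fitted per treatment arm with
+        # IPTW weights derived from the same propensity model.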
+ strategy = FedKaplan( + treated_col="treatment", + duration_col="time", + event_col="event", + propensity_model=self.propensity_model, + client_identifier="center", + ) + + compute_plan = execute_experiment( + client=self.ds_client, + strategy=strategy, + train_data_nodes=self.train_data_nodes, + evaluation_strategy=None, + aggregation_node=self.aggregation_node, + num_rounds=1, + experiment_folder=str(Path(self.test_dir) / "experiment_summaries"), + dependencies=Dependency( + local_installable_dependencies=[Path(self.wheel_path)] + ), + ) + + fl_results = download_aggregate_shared_state( + client=self.ds_client, + compute_plan_key=compute_plan.key, + round_idx=0, + ) + + X = self.df.drop(columns=["time", "event", "treatment"], axis=1) + + Xprop = torch.from_numpy(X.values).type(self.propensity_model.fc1.weight.dtype) + with torch.no_grad(): + self.propensity_model.eval() + propensity_scores = self.propensity_model(Xprop) + + propensity_scores = propensity_scores.detach().numpy().flatten() + weights = self.df["treatment"] * 1.0 / propensity_scores + ( + 1 - self.df["treatment"] + ) * 1.0 / (1.0 - propensity_scores) + + treatments = [1, 0] + # TODO test with weights + kms = [ + KMF().fit( + durations=self.df.loc[self.df["treatment"] == t]["time"], + event_observed=self.df.loc[self.df["treatment"] == t]["event"], + weights=weights.loc[self.df["treatment"] == t], + ) + for t in treatments + ] + s_gts = [kmf.survival_function_["KM_estimate"].to_numpy() for kmf in kms] + grid_gts = [kmf.survival_function_.index.to_numpy() for kmf in kms] + + fl_grid_treated, fl_s_treated, _ = fl_results["treated"] + fl_grid_untreated, fl_s_untreated, _ = fl_results["untreated"] + + assert np.allclose(fl_grid_treated, grid_gts[0], rtol=1e-2) + assert np.allclose(fl_s_treated, s_gts[0], rtol=1e-2) + assert np.allclose(fl_grid_untreated, grid_gts[1], rtol=1e-2) + assert np.allclose(fl_s_untreated, s_gts[1], rtol=1e-2) diff --git a/fedeca/tests/strategies/test_smd.py b/fedeca/tests/strategies/test_smd.py new file mode 100644 index 00000000..1b9c7c0e --- /dev/null +++ b/fedeca/tests/strategies/test_smd.py @@ -0,0 +1,152 @@ +"""Module testing substraFL moments strategy.""" +import os +import subprocess +from pathlib import Path + +import git +import numpy as np +import pandas as pd +import torch +from substrafl.dependency import Dependency +from substrafl.experiment import execute_experiment +from substrafl.model_loading import download_aggregate_shared_state +from substrafl.nodes import AggregationNode +from torch import nn + +import fedeca +from fedeca.fedeca_core import LogisticRegressionTorch +from fedeca.metrics.metrics import standardized_mean_diff +from fedeca.strategies.fed_smd import FedSMD +from fedeca.tests.common import TestTempDir +from fedeca.utils.data_utils import split_dataframe_across_clients +from fedeca.utils.survival_utils import CoxData + + +class TestSMD(TestTempDir): + """Test substrafl computation of SMD. 
+ + Tests the FL computation of SMD is the same as in pandas-pooled version + """ + + def setUp(self, backend_type="subprocess", ndim=10) -> None: + """Set up the quantities needed for the tests.""" + # Let's generate 1000 data samples with 10 covariates + data = CoxData(seed=42, n_samples=1000, ndim=ndim) + self.df = data.generate_dataframe() + + # We remove the true propensity score + self.df = self.df.drop(columns=["propensity_scores"], axis=1) + + self.clients, self.train_data_nodes, _, _, _ = split_dataframe_across_clients( + self.df, + n_clients=4, + split_method="split_control_over_centers", + split_method_kwargs={"treatment_info": "treatment"}, + data_path=Path(self.test_dir) / "data", + backend_type=backend_type, + ) + kwargs_agg_node = {"organization_id": self.train_data_nodes[0].organization_id} + self.aggregation_node = AggregationNode(**kwargs_agg_node) + # Packaging the right dependencies + + fedeca_path = fedeca.__path__[0] + repo_folder = Path( + git.Repo(fedeca_path, search_parent_directories=True).working_dir + ).resolve() + wheel_folder = repo_folder / "temp" + os.makedirs(wheel_folder, exist_ok=True) + for stale_wheel in wheel_folder.glob("fedeca*.whl"): + stale_wheel.unlink() + process = subprocess.Popen( + f"python -m build --wheel --outdir {wheel_folder} {repo_folder}", + shell=True, + stdout=subprocess.PIPE, + ) + process.wait() + assert process.returncode == 0, "Failed to build the wheel" + self.wheel_path = next(wheel_folder.glob("fedeca*.whl")) + self.ds_client = self.clients[self.train_data_nodes[0].organization_id] + self.propensity_model = LogisticRegressionTorch(ndim=ndim) + + self.propensity_model.fc1.weight.data = nn.parameter.Parameter( + torch.randn( + size=self.propensity_model.fc1.weight.data.shape, dtype=torch.float64 + ) + ) + self.propensity_model.fc1.bias.data = nn.parameter.Parameter( + torch.randn( + size=self.propensity_model.fc1.bias.data.shape, dtype=torch.float64 + ) + ) + + def test_end_to_end(self): + """Compare a FL and pooled computation of Moments. + + The data are the tcga ones. + """ + # Get fl_results. 
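+        # The FL estimates produced by FedSMD are compared against a pooled
+        # baseline computed with standardized_mean_diff (fedeca.metrics.metrics)
+        # on the IPTW-weighted and unweighted covariates.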
+ strategy = FedSMD( + treated_col="treatment", + duration_col="time", + event_col="event", + propensity_model=self.propensity_model, + client_identifier="center", + ) + + compute_plan = execute_experiment( + client=self.ds_client, + strategy=strategy, + train_data_nodes=self.train_data_nodes, + evaluation_strategy=None, + aggregation_node=self.aggregation_node, + num_rounds=1, + experiment_folder=str(Path(self.test_dir) / "experiment_summaries"), + dependencies=Dependency( + local_installable_dependencies=[Path(self.wheel_path)] + ), + ) + + fl_results = download_aggregate_shared_state( + client=self.ds_client, + compute_plan_key=compute_plan.key, + round_idx=0, + ) + + assert not fl_results["weighted_smd"].equals(fl_results["unweighted_smd"]) + X = self.df.drop(columns=["time", "event", "treatment"], axis=1) + covariates = X.columns + Xprop = torch.from_numpy(X.values).type(self.propensity_model.fc1.weight.dtype) + with torch.no_grad(): + self.propensity_model.eval() + propensity_scores = self.propensity_model(Xprop) + + propensity_scores = propensity_scores.detach().numpy().flatten() + weights = self.df["treatment"] * 1.0 / propensity_scores + ( + 1 - self.df["treatment"] + ) * 1.0 / (1.0 - propensity_scores) + weights = weights.values + + X_weighted = (Xprop * weights[:, np.newaxis]).numpy() + X_weighted_df = pd.DataFrame(X_weighted, columns=covariates) + X_df = pd.DataFrame(Xprop.numpy(), columns=covariates) + + standardized_mean_diff_pooled_weighted = standardized_mean_diff( + X_weighted_df, + self.df["treatment"] == 1, + ).div(100.0) + standardized_mean_diff_pooled_unweighted = standardized_mean_diff( + X_df, + self.df["treatment"] == 1, + ).div(100.0) + + # We check equality of FL computation and pooled results + pd.testing.assert_series_equal( + standardized_mean_diff_pooled_weighted, + fl_results["weighted_smd"], + rtol=1e-2, + ) + pd.testing.assert_series_equal( + standardized_mean_diff_pooled_unweighted, + fl_results["unweighted_smd"], + rtol=1e-2, + ) diff --git a/fedeca/tests/test_km.py b/fedeca/tests/test_km.py new file mode 100644 index 00000000..3926ec2e --- /dev/null +++ b/fedeca/tests/test_km.py @@ -0,0 +1,43 @@ +"""Tests for the survival utils related to Kaplan Meier.""" +import numpy as np +from lifelines import KaplanMeierFitter + +from fedeca.utils.survival_utils import compute_events_statistics, km_curve + + +def test_compute_events_statistics(): + """Test the computation of events statistics.""" + times = np.array([0.0, 3.2, 3.2, 5.5, 5.5, 10.0]) + events = np.array([True, True, False, True, True, False]) + unique_times_gt = np.array([0.0, 3.2, 5.5, 10.0]) + num_death_at_times_gt = np.array([1, 1, 2, 0]) + num_at_risk_at_times_gt = np.array([6, 5, 3, 1]) + + # Check that the result is correct + t, n, d = compute_events_statistics(times, events) + assert np.allclose(t, unique_times_gt) + assert np.allclose(n, num_at_risk_at_times_gt) + assert np.allclose(d, num_death_at_times_gt) + # Check that the result is invariant by permutation + p = np.random.permutation(times.size) + t_p, n_p, d_p = compute_events_statistics(times[p], events[p]) + for (a, b) in zip([t_p, n_p, d_p], [t, n, d]): + assert np.allclose(a, b) + + +def test_km_curve(): + """Test the computation of the Kaplan Meier curve.""" + rng = np.random.RandomState(42) + num_samples = 100 + times = rng.randint(0, high=21, size=(num_samples,)) + events = rng.rand(num_samples) > 0.5 + # Get KM curve + t, n, d = compute_events_statistics(times, events) + grid, s, _ = km_curve(t, n, d, tmax=20) + # Compute the one 
from lifelines + kmf = KaplanMeierFitter() + kmf.fit(times, event_observed=events) + s_gt = kmf.survival_function_["KM_estimate"].to_numpy() + grid_gt = kmf.survival_function_.index.to_numpy() + assert np.allclose(grid_gt, grid) + assert np.allclose(s_gt, s) diff --git a/fedeca/utils/data_utils.py b/fedeca/utils/data_utils.py index bd2735d0..efe09ed7 100644 --- a/fedeca/utils/data_utils.py +++ b/fedeca/utils/data_utils.py @@ -193,7 +193,6 @@ def split_dataframe_across_clients( assert len(all_indices) == len(df.index) assert set(all_indices) == set(range(len(df.index))) dfs = [] - for i in range(n_clients): os.makedirs(data_path / f"center{i}", exist_ok=True) cdf = copy.deepcopy(df.iloc[clients_indices_list[i]]) diff --git a/fedeca/utils/moments_utils.py b/fedeca/utils/moments_utils.py index 138d5c50..ed47d4d0 100644 --- a/fedeca/utils/moments_utils.py +++ b/fedeca/utils/moments_utils.py @@ -98,8 +98,8 @@ def aggregation_mean(local_means: List[Any], n_local_samples: List[int]): Any Aggregated mean. Same type of the local means """ - tot_samples = np.copy(n_local_samples[0]) - tot_mean = np.copy(local_means[0]) + tot_samples = np.nan_to_num(np.copy(n_local_samples[0]), nan=0, copy=False) + tot_mean = np.nan_to_num(np.copy(local_means[0]), nan=0, copy=False) for mean, n_sample in zip(local_means[1:], n_local_samples[1:]): mean = np.nan_to_num(mean, nan=0, copy=False) tot_mean *= tot_samples / (tot_samples + n_sample) @@ -107,3 +107,40 @@ def aggregation_mean(local_means: List[Any], n_local_samples: List[int]): tot_samples += n_sample return tot_mean + + +def compute_global_moments(shared_states): + """Aggregate local moments. + + Parameters + ---------- + shared_states : list + list of outputs from compute_uncentered_moment. + + Returns + ------- + dict + The results of the aggregation with both centered and uncentered moments. 
+ """ + tot_uncentered_moments = [ + aggregation_mean( + [s[f"moment{k}"] for s in shared_states], + [s["n_samples"] for s in shared_states], + ) + for k in range(1, 2 + 1) + ] + n_samples = sum([s["n_samples"].iloc[0] for s in shared_states]) + results = { + f"global_centered_moment_{k}": compute_centered_moment( + tot_uncentered_moments[:k] + ) + for k in range(1, 2 + 1) + } + results.update( + { + f"global_uncentered_moment_{k+1}": moment + for k, moment in enumerate(tot_uncentered_moments) + } + ) + results.update({"total_n_samples": n_samples}) + return results diff --git a/fedeca/utils/substrafl_utils.py b/fedeca/utils/substrafl_utils.py index da6c4d5f..42156aeb 100644 --- a/fedeca/utils/substrafl_utils.py +++ b/fedeca/utils/substrafl_utils.py @@ -364,6 +364,13 @@ def reset_experiment(self): self.performances_strategies = [] self.train_data_nodes = None self.test_data_nodes = None + # We need to avoid persistence of DB in between runs, this is an obscure + # hack but it's working + if hasattr(self.ds_client, "_backend"): + database = self.ds_client._backend._db._db._data + if len(database.keys()) > 1: + for k in list(database.keys()): + database.pop(k) def get_outmodel_function( diff --git a/fedeca/utils/survival_utils.py b/fedeca/utils/survival_utils.py index 97aa12a0..14f7f1b3 100644 --- a/fedeca/utils/survival_utils.py +++ b/fedeca/utils/survival_utils.py @@ -2,6 +2,7 @@ from __future__ import annotations import copy +from math import sqrt from typing import Final, Literal, Optional, Protocol import numpy as np @@ -9,6 +10,7 @@ import pandas as pd import torch from numpy.typing import NDArray +from pandas.api.types import is_numeric_dtype from scipy import stats from scipy.linalg import toeplitz from sklearn.base import BaseEstimator @@ -412,6 +414,15 @@ def generate_data( if not reached: raise ValueError("This should not happen, lower percent_ties") times = times.reshape((-1)) + # 0-1 scale + times /= float(nbins) + # With this Kbins discretizer, times start always at 0. + # 0, 1, 2, 3, ... + # there should be no time at exactly 0. otherwise lifelines + # (rightfully) will act in a weird way. See the birth outer join in + # https://github.com/CamDavidsonPilon/lifelines/blob/4377caf5a6224941ee3ab34c413ad668d4173274/lifelines/utils/__init__.py#L567 + # therefore we add a small quantity to every time + times += np.random.uniform(1.0 / nbins, 1.0, size=1) else: raise ValueError("Choose a larger number of ties") @@ -1295,6 +1306,9 @@ def robust_sandwich_variance_pooled( model. The sandwich variance estimator is a robust estimator of the variance which accounts for the lack of dependence between the samples due to the introduction of weights for example. + + Parameters + ---------- X_norm : np.ndarray or torch.Tensor Input feature matrix of shape (n_samples, n_features). y : np.ndarray or torch.Tensor @@ -1307,6 +1321,11 @@ def robust_sandwich_variance_pooled( Weights associated with each sample, with shape (n_samples,) scaled_variance_matrix : np.ndarray or torch.Tensor Classical scaled variance of the Cox model estimator. + + Returns + ------- + np.ndarray + The robust sandwich variance estimator. """ n_samples, n_features = X_norm.shape @@ -1352,3 +1371,420 @@ def robust_sandwich_variance_pooled( delta_betas = score_residuals.dot(scaled_variance_matrix) tested_var = delta_betas.T.dot(delta_betas) return np.sqrt(np.diag(tested_var)) + + +def km_curve(t, n, d, tmax=None): + """Compute Kaplan-Meier (KM) curve. 
+ + This function is typically used in conjunction with + `compute_events_statistics`. Note that the variance is computed + based on Greenwood's formula, not its exponential variant (see refs). + + Parameters + ---------- + t : np.array + Array containing the unique times of events, sorted in ascending order + n : np.array + Array containing the number of individuals at risk at each + corresponding time `t` + d : np.array + Array containing the number of individuals with an event (death) at + each corresponding time `t` + tmax : int, optional + Number of grid points, Default to the number of unique events + 1. + + Returns + ------- + tuple + Tuple of length 3 containing: + - `grid`: 1D array containing the time points at which the survival + curve is evaluated. It is `np.arange(0, t_max+1)` + - `s`: 1D array with the value of the survival function as obtained + by the Kaplan-Meier formula. + - `var_s`: `np.array` containing the variance of the Kaplan-Meier curve + at each point of `grid` + + Examples + -------- + .. code-block:: python + :linenos: + # Define times and events + times = np.random.randint(0, 3000, size=(10,)) + events = np.random.rand(10) > 0.5 + # Compute events statistics + t, n, d = compute_events_statistics(times, events) + # Get Kaplan-Meier curve + grid, s, var_s = km_curve(t, n, d) + + # Plot KM curve + plt.figure() + plt.plot(grid, s) + plt.fill_between(grid, s-np.sqrt(var_s), s+np.sqrt(var_s)) + + References + ---------- + https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator + https://www.math.wustl.edu/~sawyer/handouts/greenwood.pdf + """ + # We compute the grid on which we will want to plot S(t) + if tmax is None: + # Unique events + 0 ("birth") as in lifelines + grid = np.array([0] + t.tolist()) + else: + # Not sure if useful but... + grid = np.linspace(0, t.max(), tmax + 1) + # KM estimator but wo filtering terms out + q = 1.0 - d / n + cprod_q = np.cumprod(q) + + # Same for Greenwood's formula + csum_var = np.cumsum(d / (n * (n - d))) + + # Now we just need for each point of the grid to filter out terms that are + # bigger than them + # we initialize by filtering out everything + s = np.zeros(grid.shape) + var_s = np.zeros(grid.shape) + + # we need, for each element in the grid, the index of the cumprod/cumsum + # it should go to, which would by design filter out the right terms + # to respect KM's formula + mask = grid.reshape(-1, 1) - t.reshape(1, -1) >= 0 # (grid.shape, t.shape) + index_in_cum_vec = np.sum(mask, axis=1) - 1 + # We can now compute the survival function for each point in the grid + # Survival function starts at 1. + s[index_in_cum_vec < 0.0] = 1.0 + s[index_in_cum_vec >= 0.0] = cprod_q[index_in_cum_vec[index_in_cum_vec >= 0]] + + # And now similarly we derive Greenwood + var_s[index_in_cum_vec >= 0] = (s[index_in_cum_vec >= 0] ** 2) * csum_var[ + index_in_cum_vec[index_in_cum_vec >= 0] + ] + return grid, s, var_s + + +def compute_events_statistics(times, events, weights=None): + """Compute unique times, number of individuals at risk at these times, etc. + + Also computes number of events at these times based on the raw list of individual + times and events, for a survival framework. + + The method is vectorized with numpy to ensure fast computations. As a + side-effect, memory consumption scales as + `num_unique_times * num_individuals`. 
+ + Parameters + ---------- + times : np.array + 1D array containing the individual observed times, which are either + censored times or true event times depending on the corresponding value + of the `events` array + events : np.array + 1D array with boolean entries, such that `events[i] == True` if and + only if a true event was observed for individual `i` at time `times[i]` + weights : np.array, optional + Weights from a propensity model. + + Returns + ------- + tuple + Tuple of length 3 containing, in this order: + - `unique_times`: `np.array` containing the unique times of events, + in ascending order + - `num_at_risk_at_times`: `np.array` containing the number of individuals + at risk at each corresponding `unique_times` + - `num_death_at_times`: `np.array` containing the number of individuals + with an event (death) at each corresponding `unique_times` + + Examples + -------- + .. code-block:: python + :linenos: + # Define times and events + times = np.random.randint(0, 3000, size=(10,)) + events = np.random.rand(10) > 0.5 + # Compute events statistics + t, n, d = compute_events_statistics(times, events) + """ + # NB, both censored and uncensored, otherwise impossible to aggregate exactly + unique_times = np.unique(times) + if weights is None: + weights = np.ones_like(times) + + num_death_at_times = np.sum( + weights[events].reshape(1, -1) + * ((unique_times.reshape(-1, 1) - times[events].reshape(1, -1)) == 0), + axis=1, + ) + num_at_risk_at_times = np.sum( + weights.reshape(1, -1) + * ((times.reshape(1, -1) - unique_times.reshape(-1, 1)) >= 0), + axis=1, + ) + return unique_times, num_at_risk_at_times, num_death_at_times + + +def aggregate_events_statistics(list_t_n_d): + """Aggregate (sums) events statistics from different centers. + + Parameters + ---------- + list_t_n_d : tuple + List of event statistics from different centers. Each entry in the list + should follow the output format of `compute_events_statistics` + + Returns + ------- + tuple + Tuple of size 3 containing, in this order: + - `t_agg`: `np.array` containing the unique times of events, + in ascending order + - `n_agg`: `np.array` containing the number of individuals + at risk at each corresponding `t_agg` + - `d_agg`: `np.array` containing the number of individuals + with an event (death) at each corresponding `t_agg` + """ + # Step 1: get the unique times + unique_times_list = [t for (t, _, _) in list_t_n_d if t.size != 0] + t_agg = np.unique(np.concatenate(unique_times_list)) + # Step 2: extend to common grid + n_ext, d_ext = extend_events_to_common_grid(list_t_n_d, t_agg) + # Step 3:sum across centers + n_agg = np.sum(n_ext, axis=0) + d_agg = np.sum(d_ext, axis=0) + return t_agg, n_agg, d_agg + + +def extend_events_to_common_grid(list_t_n_d, t_common): + """Extend a list of heterogeneous times, number of people at risk on common grid. + + This method is an internal utility for `aggregate_events_statistics`. + + Parameters + ---------- + list_t_n_d : List[tuple] + List of tuples, each item in the list being an output of + `compute_events_statistics`, i.e. each tuple is of size 3 and contains + - `t`: 1D array with unique event times, sorted in ascending order + - `n`: 1D array with the number of individuals at risk for each time point + - `d`: 1D array with the number of events for each time point + t_common : np.ndarray + Common grid on which to compute the aggregated times. 
+ + Returns + ------- + tuple + Tuple with 2 values, each being a 2D array of size + `len(list_t_n_d) * t_common.size` + - `n_ext`: number of individuals at risk for each center and time + point in `t_common` + - `d_ext`: number of events for each center and each time point + in `t_common` + """ + num_clients = len(list_t_n_d) + + d_ext = np.zeros((num_clients, t_common.size)) + n_ext = np.zeros((num_clients, t_common.size)) + + for k, (t, n, d) in enumerate(list_t_n_d): + if t.size == 0: + continue + diff = t - t_common.reshape(-1, 1) + diff_abs = np.abs(diff) + # identify times which were in the client vs those not + is_in_subset = diff_abs.min(axis=1) == 0 + not_in_subset = diff_abs.min(axis=1) > 0 + # make a correspondence for those that were in + corr_in_subset = np.zeros(t_common.size, dtype=int) + corr_in_subset[is_in_subset] = diff_abs[is_in_subset].argmin(axis=1) + # for those that were not, but which are inside the grid, we can extend n + diff_relu = np.maximum(diff, 0) + has_match = (diff_relu.max(axis=1) > 0) & not_in_subset + # we rely on the fact that t is ordered! + corr_in_subset[has_match] = np.sum(diff_relu == 0, axis=1)[has_match] + # Get the deaths: 0 if no time, the value if it was inside + d_ext[k, is_in_subset] = d[corr_in_subset[is_in_subset]] + # Get the people at risk for those in subset + n_ext[k, is_in_subset] = n[corr_in_subset[is_in_subset]] + # For those not in subset but with a match, extend in a piecewise constant + # fashion + n_ext[k, has_match] = n[corr_in_subset[has_match]] + return n_ext, d_ext + + +def build_X_y_function( + data_from_opener, + event_col, + duration_col, + treated_col, + target_cols=None, + standardize_data=True, + propensity_model=None, + cox_fit_cols=None, + propensity_fit_cols=None, + tol=1e-16, + training_strategy="webdisco", + shared_state={}, + global_moments={}, +): + """Build the inputs for a propensity model and for a Cox model and y. + + Does that directly on data from opener. + + This function 1. uses the event column to inject the censorship + information present in the duration column (given in absolute values) + in the form of a negative sign producing y. + 2. Drops the raw survival columns (given in target_cols) now that y is built. + 2. Drops some covariates differently for Cox and for the propensity model + according to the training_strategy argument as well as cox and propensity + fit cols. + 4. Standardize the data if standardize_data. It will use either the shared_state + or the global moments variable depending on what is given. If both given + uses shared_state. If not given does not err but does not standardize. + 5. Return the Cox model input X as well as the (unstandardized) input to the + propensity model Xprop if necessary as well as the treated column to be able + to compute the propensity weights. + + Parameters + ---------- + data_from_opener : pd.DataFrame + The output of the opener + shared_state : dict, optional + Outmodel containing global means and stds. 
+ by default {} + + Returns + ------- + tuple + standardized X, signed times, treatment column and unstandardized + propensity model input + """ + # We need y to be in the format (2*event-1)*duration + data_from_opener["time_multiplier"] = [ + 2.0 * e - 1.0 for e in data_from_opener[event_col].tolist() + ] + # No funny business irrespective of the convention used + y = np.abs(data_from_opener[duration_col]) * data_from_opener["time_multiplier"] + y = y.to_numpy().astype("float64") + data_from_opener.drop(columns=["time_multiplier"], inplace=True) + + # TODO very dangerous, to replace by removing client_identifier + # in both cases this SHOULD NOT BE inplace + string_columns = [ + col + for col in data_from_opener.columns + if not (is_numeric_dtype(data_from_opener[col])) + ] + data_from_opener = data_from_opener.drop(columns=string_columns) + + # We drop the targets from X + if target_cols is None: + target_cols = [event_col, duration_col] + columns_to_drop = target_cols + X = data_from_opener.drop(columns=columns_to_drop) + if propensity_model is not None: + assert treated_col is not None + if training_strategy == "iptw": + X = X.loc[:, [treated_col]] + elif training_strategy == "aiptw": + if len(cox_fit_cols) > 0: + X = X.loc[:, [treated_col] + cox_fit_cols] + else: + pass + else: + assert training_strategy == "webdisco" + if len(cox_fit_cols) > 0: + X = X.loc[:, cox_fit_cols] + else: + pass + + # If X is to be standardized we do it + if standardize_data: + if shared_state: + # Careful this shouldn't happen apart from the predict + means = shared_state["global_uncentered_moment_1"] + vars = shared_state["global_centered_moment_2"] + # Careful we need to match pandas and use unbiased estimator + bias_correction = (shared_state["total_n_samples"]) / float( + shared_state["total_n_samples"] - 1 + ) + global_moments = { + "means": means, + "vars": vars, + "bias_correction": bias_correction, + } + stds = vars.transform(lambda x: sqrt(x * bias_correction + tol)) + X = X.sub(means) + X = X.div(stds) + else: + X = X.sub(global_moments["means"]) + stds = global_moments["vars"].transform( + lambda x: sqrt(x * global_moments["bias_correction"] + tol) + ) + X = X.div(stds) + + X = X.to_numpy().astype("float64") + + # If we have a propensity model we need to build X without the targets AND the + # treated column + if propensity_model is not None: + # We do not normalize the data for the propensity model !!! + Xprop = data_from_opener.drop(columns=columns_to_drop + [treated_col]) + if propensity_fit_cols is not None: + Xprop = Xprop[propensity_fit_cols] + Xprop = Xprop.to_numpy().astype("float64") + else: + Xprop = None + + # If WebDisco is used without propensity treated column does not exist + if treated_col is not None: + treated = ( + data_from_opener[treated_col].to_numpy().astype("float64").reshape((-1, 1)) + ) + else: + treated = None + + return (X, y, treated, Xprop, global_moments) + + +def compute_X_y_and_propensity_weights_function( + X, y, treated, Xprop, propensity_model, tol=1e-16 +): + """Build appropriate X, y and weights from raw output of opener. + + Uses the helper function build_X_y and the propensity model to build the + weights. 
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Input to the Cox model.
+    y : np.ndarray
+        Signed times, where a negative sign encodes censorship.
+    treated : np.ndarray
+        Treatment allocation column, or None if no propensity model is used.
+    Xprop : np.ndarray
+        Unstandardized input to the propensity model, or None.
+    propensity_model : torch.nn.Module
+        Torch model used to compute the propensity scores, or None.
+    tol : float, optional
+        Numerical tolerance used to robustify divisions, by default 1e-16.
+
+    Returns
+    -------
+    tuple
+        X input to the Cox model, y target of the Cox model and the
+        propensity weights.
+    """
+    if propensity_model is not None:
+        assert (
+            treated is not None
+        ), """If you are using a propensity model the Treated
+        column should be available"""
+        assert np.all(
+            np.in1d(np.unique(treated.astype("uint8"))[0], [0, 1])
+        ), "The treated column should have all its values in set([0, 1])"
+        Xprop = torch.from_numpy(Xprop)
+
+        with torch.no_grad():
+            propensity_scores = propensity_model(Xprop)
+
+        propensity_scores = propensity_scores.detach().numpy()
+        # We robustify the division
+        weights = treated * 1.0 / np.maximum(propensity_scores, tol) + (
+            1 - treated
+        ) * 1.0 / (np.maximum(1.0 - propensity_scores, tol))
+    else:
+        weights = np.ones((X.shape[0], 1))
+    return X, y, weights
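
Usage note (not part of the patch): the Kaplan-Meier helpers added to fedeca/utils/survival_utils.py can be exercised on their own, outside a substrafl compute plan. The sketch below shows the plain-numpy path that FedKaplan orchestrates: per-center event statistics, aggregation of those statistics, then the Kaplan-Meier curve. The random data, the seed and the two-center split are arbitrary choices made for illustration; only the function names come from this patch. Propensity weights are omitted here, whereas FedKaplan passes IPTW weights to compute_events_statistics.

    import numpy as np

    from fedeca.utils.survival_utils import (
        aggregate_events_statistics,
        compute_events_statistics,
        km_curve,
    )

    rng = np.random.RandomState(0)
    times = rng.randint(1, 30, size=200).astype(float)
    events = rng.rand(200) > 0.4

    # Pretend the data lives in two centers (illustrative split).
    centers = [(times[:120], events[:120]), (times[120:], events[120:])]

    # Each center shares only (unique_times, n_at_risk, n_deaths).
    local_stats = [compute_events_statistics(t, e) for t, e in centers]

    # The aggregation step sums these statistics on a common time grid ...
    t_agg, n_agg, d_agg = aggregate_events_statistics(local_stats)

    # ... and derives the survival curve plus its Greenwood variance.
    grid, s, var_s = km_curve(t_agg, n_agg, d_agg)

    # Sanity check against the pooled computation.
    grid_pooled, s_pooled, _ = km_curve(*compute_events_statistics(times, events))
    assert np.allclose(grid, grid_pooled) and np.allclose(s, s_pooled)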