Jean/fed kaplan (#44)
* general architecture of fedkaplan

* respecting naming conventions

* refactoring preprocessing

* passing in my head but not tested

* adding test for KM utils

* add credit

* everything works in my head

* hacking

* some refactoring

* fixing bug

* fixing various stuff

* fixing stuff

* everything passing

* test passing

* trying fixing tests

* linting

* linting

* linting

* linting

* linting

* linting

* linting fedkaplan

* trying to finally fix linting

* linting

* fixing substra stuff in FedKM

* test almost working wo weights

* linting

* linting

* now tests are not passing only because grid is not the same

* tests passing

* weights working

* removing useless comments

* tests should be passing

* removing forgotten breakpoint
jeandut authored Aug 13, 2024
1 parent cd6729d commit 440dafb
Showing 11 changed files with 1,403 additions and 169 deletions.
168 changes: 26 additions & 142 deletions fedeca/algorithms/torch_webdisco_algo.py
@@ -1,7 +1,6 @@
"""Implement webdisco algorithm with Torch."""
import copy
from copy import deepcopy
from math import sqrt
from pathlib import Path
from typing import Any, List, Optional, Union

@@ -11,7 +10,6 @@
from autograd import elementwise_grad
from autograd import numpy as anp
from lifelines.utils import StepSizer
from pandas.api.types import is_numeric_dtype
from scipy.linalg import norm
from scipy.linalg import solve as spsolve
from substrafl.algorithms.pytorch import weight_manager
@@ -21,7 +19,11 @@

from fedeca.schemas import WebDiscoAveragedStates, WebDiscoSharedState
from fedeca.utils.moments_utils import compute_uncentered_moment
from fedeca.utils.survival_utils import MockStepSizer
from fedeca.utils.survival_utils import (
    MockStepSizer,
    build_X_y_function,
    compute_X_y_and_propensity_weights_function,
)


class TorchWebDiscoAlgo(TorchAlgo):
@@ -597,124 +599,6 @@ def summary(self):
        summary = super().summary()
        return summary

    def build_X_y(self, data_from_opener, shared_state={}):
        """Build appropriate X and y (signed times) from the output of the opener.

        This function:
        1. uses the event column to inject the censorship information
           present in the duration column (given in absolute values)
           in the form of a negative sign;
        2. drops every covariate except treatment if
           self._training_strategy == "iptw";
        3. standardizes the data if self._standardize_data AND if it
           receives an outmodel;
        4. returns the (unstandardized) input to the propensity model
           Xprop if necessary, as well as the treated column, to be able
           to compute the propensity weights.

        Parameters
        ----------
        data_from_opener : pd.DataFrame
            The output of the opener.
        shared_state : dict, optional
            Outmodel containing global means and stds, by default {}.

        Returns
        -------
        tuple
            Standardized X, signed times, treated column and
            unstandardized propensity model input.
        """
        # We need y to be in the format (2*event-1)*duration
        data_from_opener["time_multiplier"] = [
            2.0 * e - 1.0 for e in data_from_opener[self._event_col].tolist()
        ]
        # No funny business irrespective of the convention used
        y = (
            np.abs(data_from_opener[self._duration_col])
            * data_from_opener["time_multiplier"]
        )
        y = y.to_numpy().astype("float64")
        data_from_opener = data_from_opener.drop(columns=["time_multiplier"])
        # Dangerous, but we need to do it: drop non-numeric columns
        # before casting to a float array
        string_columns = [
            col
            for col in data_from_opener.columns
            if not (is_numeric_dtype(data_from_opener[col]))
        ]
        data_from_opener = data_from_opener.drop(columns=string_columns)

        # We drop the targets from X
        columns_to_drop = self._target_cols
        X = data_from_opener.drop(columns=columns_to_drop)
        if self._propensity_model is not None:
            assert self._treated_col is not None
            if self._training_strategy == "iptw":
                X = X.loc[:, [self._treated_col]]
            elif self._training_strategy == "aiptw":
                if len(self._cox_fit_cols) > 0:
                    X = X.loc[:, [self._treated_col] + self._cox_fit_cols]
                else:
                    pass
        else:
            assert self._training_strategy == "webdisco"
            if len(self._cox_fit_cols) > 0:
                X = X.loc[:, self._cox_fit_cols]
            else:
                pass

        # If X is to be standardized we do it
        if self._standardize_data:
            if shared_state:
                # Careful: this shouldn't happen apart from predict
                means = shared_state["global_uncentered_moment_1"]
                vars = shared_state["global_centered_moment_2"]
                # Careful: we need to match pandas and use the unbiased estimator
                bias_correction = (shared_state["total_n_samples"]) / float(
                    shared_state["total_n_samples"] - 1
                )
                self.global_moments = {
                    "means": means,
                    "vars": vars,
                    "bias_correction": bias_correction,
                }
                stds = vars.transform(lambda x: sqrt(x * bias_correction + self._tol))
                X = X.sub(means)
                X = X.div(stds)
            else:
                X = X.sub(self.global_moments["means"])
                stds = self.global_moments["vars"].transform(
                    lambda x: sqrt(
                        x * self.global_moments["bias_correction"] + self._tol
                    )
                )
                X = X.div(stds)

        X = X.to_numpy().astype("float64")

        # If we have a propensity model we need to build X without the targets AND
        # the treated column
        if self._propensity_model is not None:
            # We do not normalize the data for the propensity model !!!
            Xprop = data_from_opener.drop(columns=columns_to_drop + [self._treated_col])
            if self._propensity_fit_cols is not None:
                Xprop = Xprop[self._propensity_fit_cols]
            Xprop = Xprop.to_numpy().astype("float64")
        else:
            Xprop = None

        # If WebDisco is used without propensity, the treated column does not exist
        if self._treated_col is not None:
            treated = (
                data_from_opener[self._treated_col]
                .to_numpy()
                .astype("float64")
                .reshape((-1, 1))
            )
        else:
            treated = None

        return (X, y, treated, Xprop)
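
For reference, the signed-time encoding and the pandas-matching standardization that this removed method performed can be sketched standalone. This is a hedged toy example: the column names "T" and "E" and all numbers are hypothetical, while the real code reads self._duration_col, self._event_col and the aggregated moments from shared_state.

import numpy as np
import pandas as pd

# Hypothetical two-patient frame: T is the duration, E the event flag
# (E=0 means censored).
df = pd.DataFrame({"T": [5.0, 3.0], "E": [1, 0]})

# Signed-time encoding: (2*event - 1) maps events to +1 and censored rows
# to -1, so the sign of y carries censorship while |y| keeps the duration.
y = (2.0 * df["E"] - 1.0) * np.abs(df["T"])  # -> [5.0, -3.0]

# Standardization matches pandas' unbiased std: the aggregated centered
# second moment is rescaled by n/(n-1) before taking the square root.
n = 100     # total_n_samples pooled across centers (hypothetical)
var = 4.0   # global centered second moment of one covariate (hypothetical)
tol = 1e-16
std = np.sqrt(var * n / (n - 1) + tol)  # ~2.010, vs the biased sqrt(4.0) = 2.0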

    def compute_X_y_and_propensity_weights(self, data_from_opener, shared_state):
        """Build appropriate X, y and weights from raw output of opener.
@@ -731,26 +615,26 @@ def compute_X_y_and_propensity_weights(self, data_from_opener, shared_state):
        Returns
        -------
        tuple
            _description_
            X input to the Cox model, y target of Cox model, weights propensity weights
        """
        X, y, treated, Xprop = self.build_X_y(data_from_opener, shared_state)
        if self._propensity_model is not None:
            assert (
                treated is not None
            ), f"""If you are using a propensity model the {self._treated_col} (Treated)
            column should be available"""
            assert np.all(
                np.in1d(np.unique(treated.astype("uint8"))[0], [0, 1])
            ), "The treated column should have all its values in set([0, 1])"
            Xprop = torch.from_numpy(Xprop)
            with torch.no_grad():
                propensity_scores = self._propensity_model(Xprop)

            propensity_scores = propensity_scores.detach().numpy()
            # We robustify the division
            weights = treated * 1.0 / np.maximum(propensity_scores, self._tol) + (
                1 - treated
            ) * 1.0 / (np.maximum(1.0 - propensity_scores, self._tol))
        else:
            weights = np.ones((X.shape[0], 1))
        X, y, treated, Xprop, self.global_moments = build_X_y_function(
            data_from_opener,
            self._event_col,
            self._duration_col,
            self._treated_col,
            self._target_cols,
            self._standardize_data,
            self._propensity_model,
            self._cox_fit_cols,
            self._propensity_fit_cols,
            self._tol,
            self._training_strategy,
            shared_state=shared_state,
            global_moments={}
            if not hasattr(self, "global_moments")
            else self.global_moments,
        )
        X, y, weights = compute_X_y_and_propensity_weights_function(
            X, y, treated, Xprop, self._propensity_model, self._tol
        )
        return X, y, weights
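
Both helpers above live in fedeca.utils.survival_utils. Judging from the inline code they replace, the weighting step presumably reduces to standard inverse-probability-of-treatment weighting (IPTW) with a clamped denominator; a minimal sketch with hypothetical scores and indicators:

import numpy as np

# Hypothetical treatment indicators and propensity scores e(x); in fedeca
# the scores come from the frozen propensity model evaluated on Xprop.
treated = np.array([[1.0], [0.0], [1.0]])
propensity_scores = np.array([[0.8], [0.3], [0.5]])
tol = 1e-16

# IPTW: 1/e(x) for treated units, 1/(1 - e(x)) for controls, with the
# division robustified exactly as in the replaced inline code.
weights = treated / np.maximum(propensity_scores, tol) + (
    1.0 - treated
) / np.maximum(1.0 - propensity_scores, tol)
# -> [[1.25], [1.4286], [2.0]]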
