
Commit

first trial at fixing things
jeandut committed Aug 16, 2024
1 parent 74dc50f commit 24ab75e
Showing 1 changed file with 76 additions and 17 deletions.
93 changes: 76 additions & 17 deletions fedeca/strategies/fed_smd.py
@@ -34,6 +34,7 @@ def __init__(
treated_col: str,
propensity_model: torch.nn.Module,
client_identifier: str,
use_unweighted_variance: bool = True,
tol: float = 1e-16,
):
"""Initialize FedSMD strategy.
@@ -49,6 +50,10 @@ def __init__(
propensity_model : Union[None, nn.Module], optional
    Propensity model used to compute the IPTW weights, by default None
client_identifier : str
    Name of the column identifying which client (center) each row comes from
use_unweighted_variance : bool, optional
    Whether to use the unweighted variance in the SMD denominator rather
    than the weighted, bias-corrected one, by default True
"""
super().__init__()

@@ -59,6 +64,7 @@ def __init__(
self._propensity_model = propensity_model
self._propensity_fit_cols = None
self._client_identifier = client_identifier
self._use_unweighted_variance = use_unweighted_variance
self._tol = tol
self.statistics_result = None

@@ -68,6 +74,7 @@ def __init__(
self.kwargs["treated_col"] = treated_col
self.kwargs["propensity_model"] = propensity_model
self.kwargs["client_identifier"] = client_identifier
self.kwargs["use_unweighted_variance"] = use_unweighted_variance
self.kwargs["tol"] = tol

def build_compute_plan(
@@ -176,7 +183,6 @@ def compute_local_moments_per_group(
"iptw",
)
# X contains only the treatment column (strategy == iptw)

X, _, weights = compute_X_y_and_propensity_weights_function(
X, y, treated, Xprop, self._propensity_model, self._tol
)
@@ -185,29 +191,48 @@ def compute_local_moments_per_group(
# we use Xprop, which contains all the propensity columns; these
# are the only ones we are interested in
raw_data = pd.DataFrame(Xprop, columns=propensity_cols)
- weighted_data = pd.DataFrame(
-     np.multiply(weights, Xprop), columns=propensity_cols
- )
+ weights_df = pd.DataFrame(
+     np.repeat(weights[:, None], Xprop.shape[1], axis=1),
+     columns=propensity_cols,
+ )

results = {}
for treatment in [0, 1]:
mask_treatment = treated == treatment
res_name = "treated" if treatment else "untreated"
results[res_name] = {}
# Here we pass weights
results[res_name]["weighted"] = {
f"moment{k}": compute_uncentered_moment(
-     weighted_data[mask_treatment], k
+     raw_data[mask_treatment], k, weights[mask_treatment]
)
for k in range(1, 3)
}
# Here we don't
results[res_name]["unweighted"] = {
f"moment{k}": compute_uncentered_moment(raw_data[mask_treatment], k)
for k in range(1, 3)
}
results[res_name]["unweighted"]["n_samples"] = results[res_name][
"weighted"
]["n_samples"] = (
# Here we compute aggregated effective sample size (ess)
results[res_name]["unweighted"]["n_samples"] = (
raw_data[mask_treatment].select_dtypes(include=np.number).count()
)
results[res_name]["weighted"]["n_samples"] = (
weights_df[mask_treatment].select_dtypes(include=np.number).sum()
)
if not self._use_unweighted_variance:
    # We store these sums to scale the weighted variance as in
    # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4626409/  # noqa: E501
    results[res_name]["weighted"]["weights"] = {}
    # sum of weights per column
    results[res_name]["weighted"]["weights"]["moment1_sum"] = (
        weights_df[mask_treatment].select_dtypes(include=np.number).sum()
    )
    # sum of squared weights per column
    results[res_name]["weighted"]["weights"]["moment2_sum"] = (
        weights_df[mask_treatment]
        .pow(2)
        .select_dtypes(include=np.number)
        .sum()
    )

return results

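As context for the moment1_sum / moment2_sum payload above: the reference linked in the code rescales the weighted variance by (sum w)^2 / ((sum w)^2 - sum w^2), which collapses to the familiar n / (n - 1) correction when all weights are 1. A minimal standalone sketch of that factor in plain NumPy (the function name is illustrative, not part of fedeca):

import numpy as np

def variance_scaler(weights: np.ndarray) -> float:
    # moment1_sum corresponds to the sum of weights,
    # moment2_sum to the sum of squared weights
    s1 = weights.sum()
    s2 = np.square(weights).sum()
    return s1**2 / (s1**2 - s2)

# with uniform weights the factor reduces to n / (n - 1)
w = np.ones(10)
assert np.isclose(variance_scaler(w), 10 / 9)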
@remote
@@ -231,33 +256,67 @@ def compute_smd(
def std_mean_differences(x, y):
    """Compute standardized mean differences."""
    means_x = x["global_uncentered_moment_1"]
-   # we match numpy std with ddof=0, contrary to the standardization for Cox
-   stds_x = np.sqrt(x["global_centered_moment_2"] + self._tol)
    means_y = y["global_uncentered_moment_1"]
-   # we match numpy std with ddof=0, contrary to the standardization for Cox
-   stds_y = np.sqrt(y["global_centered_moment_2"] + self._tol)
-   smd_df = means_x.subtract(means_y).div(
-       stds_x.pow(2).add(stds_y.pow(2)).div(2).pow(0.5)
-   )
+   var_x = x["global_centered_moment_2"]
+   var_y = y["global_centered_moment_2"]
+   smd_df = means_x.subtract(means_y).div(var_x.add(var_y).div(2).pow(0.5))
    return smd_df

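The nested helper above is the usual pooled-variance SMD, (mean_x - mean_y) / sqrt((var_x + var_y) / 2), applied column-wise to the aggregated moments. A quick sanity check of the same formula on raw NumPy arrays (illustrative only, not fedeca's API):

import numpy as np

def smd(x: np.ndarray, y: np.ndarray) -> float:
    # matches numpy's ddof=0 variance, as in std_mean_differences above
    return (x.mean() - y.mean()) / np.sqrt((x.var() + y.var()) / 2)

rng = np.random.default_rng(0)
x = rng.normal(loc=1.0, size=100_000)
y = rng.normal(loc=0.0, size=100_000)
print(smd(x, y))  # close to 1.0: unit-variance groups one mean apart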
# First we compute means and variances without weights
treated_raw = compute_global_moments(
[shared_state["treated"]["unweighted"] for shared_state in shared_states]
)
untreated_raw = compute_global_moments(
[shared_state["untreated"]["unweighted"] for shared_state in shared_states]
)

# We use them directly to compute the SMD before propensity weighting
smd_raw = std_mean_differences(treated_raw, untreated_raw)

# Then we compute means and variances WITH WEIGHTS
treated_weighted = compute_global_moments(
[shared_state["treated"]["weighted"] for shared_state in shared_states]
)
untreated_weighted = compute_global_moments(
[shared_state["untreated"]["weighted"] for shared_state in shared_states]
)
if not self._use_unweighted_variance:
    # helper computing the variance scaling factor
    # (sum w)^2 / ((sum w)^2 - sum w^2) from the per-client weight sums
    def compute_var_scaler(weights_per_client):
        total_weighted_n_samples = sum(
            [s["moment1_sum"].iloc[0] for s in weights_per_client]
        )
        total_squared_weights = sum(
            [s["moment2_sum"].iloc[0] for s in weights_per_client]
        )
        return total_weighted_n_samples**2 / (
            total_weighted_n_samples**2 - total_squared_weights
        )

# we compute the var scaler for the treated population
var_scaler_treated = compute_var_scaler(
[
shared_state["treated"]["weighted"]["weights"]
for shared_state in shared_states
]
)
treated_weighted["global_centered_moment_2"] *= var_scaler_treated

# we compute the var scaler for the untreated population
var_scaler_untreated = compute_var_scaler(
[
shared_state["untreated"]["weighted"]["weights"]
for shared_state in shared_states
]
)

untreated_weighted["global_centered_moment_2"] *= var_scaler_untreated

if self._use_unweighted_variance:
treated_weighted["global_centered_moment_2"] = treated_raw[
"global_centered_moment_2"
]
untreated_weighted["global_centered_moment_2"] = untreated_raw[
"global_centered_moment_2"
]

smd_weighted = std_mean_differences(treated_weighted, untreated_weighted)

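Putting the round together: each client ships per-group weighted and unweighted first and second uncentered moments, the aggregator pools them into global means and variances (rescaling the weighted variances as above when use_unweighted_variance is False), and SMDs are reported before and after IPTW. A compressed single-machine sketch of the pooling step, with a hypothetical name standing in for compute_global_moments:

import numpy as np

def pool_moments(moment1_list, moment2_list, n_list):
    # weighted average of per-client uncentered moments, then
    # var = E[X^2] - E[X]^2 (ddof=0)
    n = sum(n_list)
    m1 = sum(m * k for m, k in zip(moment1_list, n_list)) / n
    m2 = sum(m * k for m, k in zip(moment2_list, n_list)) / n
    return m1, m2 - m1**2

# two "clients" holding halves of the same sample recover the pooled result
x = np.arange(10.0)
a, b = x[:4], x[4:]
m1, var = pool_moments(
    [a.mean(), b.mean()], [(a**2).mean(), (b**2).mean()], [len(a), len(b)]
)
assert np.isclose(m1, x.mean()) and np.isclose(var, x.var())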
