diff --git a/fedeca/strategies/fed_kaplan.py b/fedeca/strategies/fed_kaplan.py index ad3e1f3b..f385b430 100644 --- a/fedeca/strategies/fed_kaplan.py +++ b/fedeca/strategies/fed_kaplan.py @@ -1,4 +1,5 @@ """Compute federated Kaplan-Meier estimates.""" +import pickle as pk from typing import List, Optional, Union import numpy as np @@ -54,6 +55,11 @@ def __init__( self._treated_col = treated_col self._propensity_model = propensity_model self._tol = tol + self.kwargs["duration_col"] = duration_col + self.kwargs["event_col"] = event_col + self.kwargs["treated_col"] = treated_col + self.kwargs["propensity_model"] = propensity_model + self.kwargs["tol"] = tol def build_compute_plan( self, @@ -216,3 +222,35 @@ def compute_agg_km_curve(self, shared_states): """ t_agg, n_agg, d_agg = aggregate_events_statistics(shared_states) return km_curve(t_agg, n_agg, d_agg) + + def save_local_state(self, path: Path): + """Save the object on the disk. + + Should be used only by the backend, to define the local_state. + + Parameters + ---------- + path : Path + Where to save the object. + """ + with open(path, "wb") as file: + pk.dump(self.statistics_result, file) + + def load_local_state(self, path: Path) -> Any: + """Load the object from the disk. + + Should be used only by the backend, to define the local_state. + + Parameters + ---------- + path : Path + Where to find the object. + + Returns + ------- + Any + Previously saved instance. + """ + with open(path, "rb") as file: + self.statistics_result = pk.load(file) + return self