Commit

supposedly working if no reinstantiation
jeandut committed Jan 12, 2024
1 parent 71dc969 commit e59dff6
Showing 1 changed file with 54 additions and 81 deletions.
135 changes: 54 additions & 81 deletions fedeca/strategies/bootstraper.py
@@ -118,73 +118,49 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] =
_aggregate_all_bootstraps(agg, name)
for agg, name in zip(orig_aggregations[key], aggregations_names[key])
]
# define what will be the __init__ method of
# the to-be-dynamically-defined MergedClass
def bootstrapped_algo_init(self, *args, **kwargs):
"""Initialize the merged strategy.
Parameters
----------
self : BtstAlgo
The BtstAlgo instance.
algo : Strategy
List of the strategies used.
args: Any
extra arguments
kwargs: Any
extra keyword arguments
"""
super(self.__class__, self).__init__(
*args, **kwargs
)
self.original_algo = strategy.algo

# Dict holding the methods of the to-be-dynamically-defined MergedClass
methods_dict = dict(zip(local_functions_names["algo"], local_computations_fct["algo"]))
methods_dict.update(dict(zip(aggregations_names["algo"], aggregations_fct["algo"])))
methods_dict.update({"load_local_state_original": getattr(strategy.algo, "load_local_state")})
methods_dict.update({"save_local_state_original": getattr(strategy.algo, "save_local_state")})
methods_dict.update({"load_local_state": _load_all_bootstraps_states()})
methods_dict.update({"save_local_state": _save_all_bootstraps_states(bootstrap_seeds_list)})
methods_dict.update({"strategies": strategy.algo.strategies})
methods_dict.update({"__init__": bootstrapped_algo_init})

# dynamically define the BtstStrategy Class, which inherits from
# Strategy, and whose methods are defined by the method dict.
BtstAlgo = type("BtstAlgo", (strategy.algo.__class__,), methods_dict)
btst_algo = BtstAlgo()

# define what will be the __init__ method of
# the to-be-dynamically-defined MergedClass
def bootstrapped_strategy_init(self, *args, **kwargs):
"""Initialize the merged strategy.
Parameters
----------
self : BtstStrategy
The BtstStrategy instance.
strategy : Strategy
List of the strategies used.
args: Any
extra arguments
kwargs: Any
extra keyword arguments
"""
super(self.__class__, self).__init__(
*args, **kwargs
)
self.original_strategy = strategy

# Dict holding the methods of the to-be-dynamically-defined MergedClass
methods_dict = dict(zip(local_functions_names["strategy"], local_computations_fct["strategy"]))
methods_dict.update(dict(zip(aggregations_names["strategy"], aggregations_fct["strategy"])))
methods_dict.update({"__init__": bootstrapped_strategy_init})
# dynamically define the BtstStrategy Class, which inherits from
# Strategy, and whose methods are defined by the method dict.
BtstStrategy = type("BtstStrategy", (strategy.__class__,), methods_dict)
# return an instance of this class.
strat = BtstStrategy(algo=btst_algo)
return strat
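
For context, the block deleted above built the merged class at runtime with the three-argument form of type(name, bases, namespace), where the namespace dict becomes the class body. A minimal self-contained sketch of that pattern, using a toy Base class rather than fedeca's actual Strategy:

# Toy illustration of dynamic class creation via type(); Base, ping and
# Merged are hypothetical stand-ins, not fedeca classes.
class Base:
    def ping(self):
        return "base"

def ping(self):
    return "merged"

methods_dict = {"ping": ping}
Merged = type("Merged", (Base,), methods_dict)

assert Merged().ping() == "merged"
assert Base().ping() == "base"  # the parent class is untouched
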
# We have to overwrite the original methods at the class level
# obj_class = strategy.algo.__class__
obj = strategy.algo
for local_name in local_functions_names["algo"]:
# f = types.MethodType(_bootstrap_local_function(getattr(obj_class, local_name), local_name, bootstrap_seeds_list), obj_class)
# setattr(obj_class, local_name, f)
f = types.MethodType(_bootstrap_local_function(getattr(obj, local_name), local_name, bootstrap_seeds_list), obj)
setattr(obj, local_name, f)
for agg_name in aggregations_names["algo"]:
# f = types.MethodType(_aggregate_all_bootstraps(getattr(obj_class, agg_name), agg_name), obj_class)
# setattr(obj_class, agg_name, f)
f = types.MethodType(_aggregate_all_bootstraps(getattr(obj, agg_name), agg_name), obj)
setattr(obj, agg_name, f)

# f = types.MethodType(_save_all_bootstraps_states(getattr(obj_class, "save_local_state"), bootstrap_seeds_list), obj_class)
# setattr(obj_class, "save_local_state", f)

# f = types.MethodType(_load_all_bootstraps_states(getattr(obj_class, "load_local_state")), obj_class)
# setattr(obj_class, "load_local_state", f)

f = types.MethodType(_save_all_bootstraps_states(getattr(obj, "save_local_state"), bootstrap_seeds_list), obj)
setattr(obj, "save_local_state", f)

f = types.MethodType(_load_all_bootstraps_states(getattr(obj, "load_local_state")), obj)
setattr(obj, "load_local_state", f)



# obj_class = strategy.__class__
obj = strategy
for local_name in local_functions_names["strategy"]:
# f = types.MethodType(_bootstrap_local_function(getattr(obj_class, local_name), local_name, bootstrap_seeds_list), obj_class)
# setattr(obj_class, local_name, f)
f = types.MethodType(_bootstrap_local_function(getattr(obj, local_name), local_name, bootstrap_seeds_list), obj)
setattr(obj, local_name, f)
for agg_name in aggregations_names["strategy"]:
# f = types.MethodType(_aggregate_all_bootstraps(getattr(obj_class, agg_name), agg_name), obj_class)
# setattr(obj_class, agg_name, f)
f = types.MethodType(_aggregate_all_bootstraps(getattr(obj, agg_name), agg_name), obj)
setattr(obj, agg_name, f)


return strategy
# # Very important we have to decorate AT THE CLASS LEVEL
# # here we decorate both at the instance and at the class level
# # but for actual deployments only class-level is important
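
The replacement code above rebinds the wrapped functions directly on the strategy and algo instances with types.MethodType, leaving the classes themselves untouched (the class-level variant is kept commented out). As the commit message "supposedly working if no reinstantiation" and the commented note above suggest, such an instance-level rebinding does not survive re-instantiation of the object. A minimal sketch of the rebinding pattern, with toy names:

import types

class Algo:  # hypothetical stand-in for strategy.algo's class
    def local_step(self):
        return "original"

def _wrap(original_bound_method, name):
    # Factory closing over the original bound method, in the same
    # spirit as _bootstrap_local_function above.
    def wrapped(self):
        return f"{name}: wrapped({original_bound_method()})"
    return wrapped

obj = Algo()
f = types.MethodType(_wrap(obj.local_step, "local_step"), obj)
setattr(obj, "local_step", f)

print(obj.local_step())     # local_step: wrapped(original)
print(Algo().local_step())  # a fresh instance keeps the original method
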
@@ -258,27 +234,26 @@ def local_computation(self, datasamples, shared_state=None) -> list:
results = []
# loop over the provided local_computation steps using _skip=True.
# What is highly non-trivial is that algo has a state that is bootstrap
# dependent and we need to load the correspponding state as the main
# dependent and we need to load the corresponding state as the main
# state, so we need to have saved all states (i.e. n_bootstraps models).
# We implicitly rely on the new load_local_state wrapper to load all states in RAM.

name_decorated_function = local_computation.__name__ + "_original"
if not hasattr(self, "checkpoints_list"):
self.checkpoints_list = [None] * len(bootstrap_seeds_list)

for idx, seed in enumerate(bootstrap_seeds_list):
rng = np.random.default_rng(seed)
bootstrapped_data = datasamples.sample(datasamples.shape[0], replace=True, random_state=rng)

# Loading the correct state into the current main algo
if self.checkpoints_list[idx] is not None:
self._update_from_checkpoint(self.checkpoints_list[idx])
# We need this old state tto avoid side effects from the function
# We need this old state to avoid side effects from the function
# on the instance
old_state = copy.deepcopy(self)
if shared_state is None:
res = local_function(datasamples=bootstrapped_data, _skip=True)
res = getattr(self, name_decorated_function)(datasamples=bootstrapped_data, _skip=True)
else:
res = local_function(
res = getattr(self, name_decorated_function)(
datasamples=bootstrapped_data, shared_state=shared_state[idx], _skip=True
)
self.checkpoints_list[idx] = self._get_state_to_save()
@@ -287,7 +262,6 @@ def local_computation(self, datasamples, shared_state=None) -> list:
for att_name, att in vars(self).items():
if att != old_state.__getattribute__(att_name):
self.__setattr__(att_name, old_state.__getattribute__(att_name))

results.append(res)

return results
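
The resampling step above is standard pandas/NumPy: each bootstrap seed deterministically drives its own Generator, and the data is sampled with replacement back up to its original size. A runnable illustration on a hypothetical DataFrame:

import numpy as np
import pandas as pd

df = pd.DataFrame({"x": range(8)})  # hypothetical datasamples
bootstrap_seeds_list = [0, 1, 2]

for seed in bootstrap_seeds_list:
    # One fresh, seeded Generator per bootstrap keeps each resample
    # reproducible and independent of the global NumPy state.
    rng = np.random.default_rng(seed)
    bootstrapped_data = df.sample(df.shape[0], replace=True, random_state=rng)
    print(seed, bootstrapped_data["x"].tolist())

The deepcopy/restore dance around each call exists only to undo side effects the wrapped function may have had on self, so that one bootstrap does not start from the state mutated by the previous one.
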
@@ -355,7 +329,7 @@ def aggregation(self, shared_states=None) -> list:
return remote(aggregation)
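
My reading of the aggregation side, as a sketch rather than fedeca's actual code: since each local computation returns a list of per-bootstrap results, the wrapped aggregation presumably receives one such list per client, regroups the states by bootstrap index, and applies the original aggregation once per bootstrap.

def aggregate_all_bootstraps(original_aggregation, shared_states):
    # Hypothetical layout inferred from local_computation's return value:
    # shared_states holds one list of per-bootstrap states per client.
    n_bootstraps = len(shared_states[0])
    return [
        original_aggregation([per_client[idx] for per_client in shared_states])
        for idx in range(n_bootstraps)
    ]
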


def _load_all_bootstraps_states():
def _load_all_bootstraps_states(load_local_state):
def load_local_state(self, path: Path) -> "TorchAlgo":
"""Load the stateful arguments of this class.
Child classes do not need to override that function.
@@ -375,14 +349,14 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
checkpoints_found = [p for p in Path(tmpdirname).glob("**/bootstrap_*")]
self.checkpoints_list = [None] * len(checkpoints_found)
for idx, file in enumerate(checkpoints_found):
self.load_local_state_original(file)
load_local_state(file)
self.checkpoints_list[idx] = self._get_state_to_save()
return self

return load_local_state
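
All per-bootstrap checkpoints travel inside a single zip archive, so loading means extracting the archive and restoring each bootstrap_* file through the original loader, as the wrapper above does. A standalone sketch of the unzip-and-collect step, assuming opaque byte-string states rather than substrafl's real checkpoint format:

import tempfile
import zipfile
from pathlib import Path

def load_states_from_zip(path: Path) -> list[bytes]:
    # Extract the archive, then read back every bootstrap_* checkpoint
    # (sorted, since glob order is not guaranteed).
    with tempfile.TemporaryDirectory() as tmpdirname:
        with zipfile.ZipFile(path, "r") as f:
            f.extractall(tmpdirname)
        checkpoints_found = sorted(Path(tmpdirname).glob("**/bootstrap_*"))
        return [p.read_bytes() for p in checkpoints_found]
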


def _save_all_bootstraps_states(bootstrap_seeds_list):
def _save_all_bootstraps_states(save_local_state, bootstrap_seeds_list):
def save_local_state(self, path: Path) -> "TorchAlgo":
# We save all bootstrapped states in different subfolders
# It assumes at this point checkpoints_list has been populated
@@ -402,7 +376,7 @@ def save_local_state(self, path: Path) -> "TorchAlgo":
self._update_from_checkpoint(checkpt)
# TODO: methods implicitly use the self attribute
path_to_checkpoint = Path(tmpdirname) / f"bootstrap_{idx}"
self.save_local_state_original(path_to_checkpoint)
save_local_state(path_to_checkpoint)
paths_to_checkpoints.append(path_to_checkpoint)

with zipfile.ZipFile(path, 'w') as f:
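
The saving side is the mirror image: dump each bootstrap state to its own bootstrap_{idx} file in a temporary directory, then bundle everything into one archive at path. A self-contained sketch with the same hypothetical byte-string states:

import tempfile
import zipfile
from pathlib import Path

def save_states_as_zip(states: list[bytes], path: Path) -> None:
    # One file per bootstrap state, then a single zip at `path`,
    # mirroring the save_local_state wrapper above.
    with tempfile.TemporaryDirectory() as tmpdirname:
        paths_to_checkpoints = []
        for idx, state in enumerate(states):
            checkpoint_path = Path(tmpdirname) / f"bootstrap_{idx}"
            checkpoint_path.write_bytes(state)
            paths_to_checkpoints.append(checkpoint_path)
        with zipfile.ZipFile(path, "w") as f:
            for checkpoint_path in paths_to_checkpoints:
                # arcname drops the temp-dir prefix inside the archive
                f.write(checkpoint_path, arcname=checkpoint_path.name)
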
@@ -480,13 +454,12 @@ def __init__(self):
strategy = FedAvg(algo=TorchLogReg())

btst_strategy = make_bootstrap_strategy(strategy, n_bootstraps=10)


clients, train_data_nodes, test_data_nodes, _, _ = split_dataframe_across_clients(
df,
n_clients=2,
split_method= "split_control_over_centers",
split_method_kwargs={"treatment_info": "treatment"},
split_method= "uniform",
split_method_kwargs=None,
data_path="./data",
backend_type="subprocess",
)