From fd1735cec880f4c77380d802f38cdf47f9ec4aef Mon Sep 17 00:00:00 2001 From: JMGaljaard Date: Thu, 15 Sep 2022 21:20:49 +0200 Subject: [PATCH] Deprecate loading (default) models from file in Federated training --- fltk/core/node.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/fltk/core/node.py b/fltk/core/node.py index 9c1634ba..eea1ede9 100644 --- a/fltk/core/node.py +++ b/fltk/core/node.py @@ -4,6 +4,8 @@ import copy import os from typing import Callable, Any, Union + +import deprecate import torch from torch.distributed import rpc from fltk.datasets.federated import get_fed_dataset @@ -53,7 +55,7 @@ def _config(self, config: FedLearnerConfig): self.cuda = config.cuda self.device = self.init_device() self.distributed = config.distributed - self.set_net(self.load_default_model()) + self.net = get_net(self.config.net_name)() def init_dataloader(self, world_size: int = None): """ @@ -123,7 +125,7 @@ def set_net(self, net): :param net: """ self.net = net - self.net.to(self.device) + # self.net.to(self.device) def get_nn_parameters(self): """ @@ -131,15 +133,6 @@ def get_nn_parameters(self): """ return self.net.state_dict() - def load_default_model(self): - """ - Load a model from default model file. - This is used to ensure consistent default model behavior. - """ - model_class = get_net(self.config.net_name) - default_model_path = os.path.join(self.config.get_default_model_folder_path(), model_class.__name__ + ".model") - - return self.load_model_from_file(default_model_path) def load_model_from_file(self, model_file_path): """ @@ -160,10 +153,9 @@ def load_model_from_file(self, model_file_path): self.logger.warning(f"Could not find model: {model_file_path}") return model - def update_nn_parameters(self, new_params, is_offloaded_model = False): """ - Update the NN's parameters. + Update the NN's parameters by parameters provided by Federator. :param new_params: New weights for the neural network :type new_params: dict @@ -172,6 +164,7 @@ def update_nn_parameters(self, new_params, is_offloaded_model = False): pass # self.offloaded_net.load_state_dict(copy.deepcopy(new_params), strict=True) else: + self.logger.info("Updating parameters") self.net.load_state_dict(copy.deepcopy(new_params), strict=True) def message(self, other_node: str, method: Union[Callable, str], *args, **kwargs) -> torch.Future: # pylint: disable=no-member