diff --git a/federatedscope/core/configs/cfg_training.py b/federatedscope/core/configs/cfg_training.py index 6e98c3623..baa8f915d 100644 --- a/federatedscope/core/configs/cfg_training.py +++ b/federatedscope/core/configs/cfg_training.py @@ -32,6 +32,7 @@ def extend_training_cfg(cfg): cfg.train.local_update_steps = 1 cfg.train.batch_or_epoch = 'batch' + cfg.train.data_para_dids = [] # `torch.nn.DataParallel` devices cfg.train.optimizer = CN(new_allowed=True) cfg.train.optimizer.type = 'SGD' diff --git a/federatedscope/core/trainers/torch_trainer.py b/federatedscope/core/trainers/torch_trainer.py index fd5a72c53..6ac7b98a2 100644 --- a/federatedscope/core/trainers/torch_trainer.py +++ b/federatedscope/core/trainers/torch_trainer.py @@ -98,6 +98,8 @@ def evaluate(self, target_data_split_name="test"): return self.ctx.eval_metrics def register_default_hooks_train(self): + self.register_hook_in_train(self._hook_on_data_parallel_init, + "on_fit_start") self.register_hook_in_train(self._hook_on_fit_start_init, "on_fit_start") self.register_hook_in_train( @@ -118,6 +120,8 @@ def register_default_hooks_train(self): self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end") def register_default_hooks_ft(self): + self.register_hook_in_ft(self._hook_on_data_parallel_init, + "on_fit_start") self.register_hook_in_ft(self._hook_on_fit_start_init, "on_fit_start") self.register_hook_in_ft(self._hook_on_fit_start_calculate_model_size, "on_fit_start") @@ -137,6 +141,8 @@ def register_default_hooks_ft(self): def register_default_hooks_eval(self): # test/val + self.register_hook_in_eval(self._hook_on_data_parallel_init, + "on_fit_start") self.register_hook_in_eval(self._hook_on_fit_start_init, "on_fit_start") self.register_hook_in_eval(self._hook_on_epoch_start, "on_epoch_start") @@ -147,6 +153,26 @@ def register_default_hooks_eval(self): self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end") self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end") + def _hook_on_data_parallel_init(self, ctx): + """ + Note: + The modified attributes and according operations are shown below, + further modifications should be made to `ctx.model` other object: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.model`` Wrap ``nn.Module` to \ + `nn.DataParallel` + ================================== =========================== + """ + if isinstance(ctx.model, torch.nn.DataParallel): + return + + if len(ctx.cfg.train.data_para_dids): + ctx.model = \ + torch.nn.DataParallel(ctx.model, + device_ids=ctx.cfg.train.data_para_dids) + def _hook_on_fit_start_init(self, ctx): """ Note: