Enable torch DataParallel for multi-GPU (#602)
* add data para

* add doc

* Update docstring

* format
rayrayraykk authored May 22, 2023
1 parent 41ef4cd commit 9f8ce0e
Showing 2 changed files with 27 additions and 0 deletions.
1 change: 1 addition & 0 deletions federatedscope/core/configs/cfg_training.py
@@ -32,6 +32,7 @@ def extend_training_cfg(cfg):

    cfg.train.local_update_steps = 1
    cfg.train.batch_or_epoch = 'batch'
    cfg.train.data_para_dids = []  # `torch.nn.DataParallel` device ids

    cfg.train.optimizer = CN(new_allowed=True)
    cfg.train.optimizer.type = 'SGD'
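The new option defaults to an empty list, which leaves data parallelism off. A minimal sketch of enabling it from Python, assuming the package's `global_cfg` entry point and two visible GPUs (both are illustrative assumptions, not part of this commit):

    from federatedscope.core.configs.config import global_cfg

    cfg = global_cfg.clone()
    # Illustrative: replicate the model across GPUs 0 and 1 via
    # torch.nn.DataParallel; the default [] disables the wrapping.
    cfg.train.data_para_dids = [0, 1]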
26 changes: 26 additions & 0 deletions federatedscope/core/trainers/torch_trainer.py
@@ -98,6 +98,8 @@ def evaluate(self, target_data_split_name="test"):
        return self.ctx.eval_metrics

    def register_default_hooks_train(self):
        self.register_hook_in_train(self._hook_on_data_parallel_init,
                                    "on_fit_start")
        self.register_hook_in_train(self._hook_on_fit_start_init,
                                    "on_fit_start")
        self.register_hook_in_train(
@@ -118,6 +120,8 @@ def register_default_hooks_train(self):
        self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end")

    def register_default_hooks_ft(self):
        self.register_hook_in_ft(self._hook_on_data_parallel_init,
                                 "on_fit_start")
        self.register_hook_in_ft(self._hook_on_fit_start_init, "on_fit_start")
        self.register_hook_in_ft(self._hook_on_fit_start_calculate_model_size,
                                 "on_fit_start")
@@ -137,6 +141,8 @@ def register_default_hooks_ft(self):

    def register_default_hooks_eval(self):
        # test/val
        self.register_hook_in_eval(self._hook_on_data_parallel_init,
                                   "on_fit_start")
        self.register_hook_in_eval(self._hook_on_fit_start_init,
                                   "on_fit_start")
        self.register_hook_in_eval(self._hook_on_epoch_start, "on_epoch_start")
@@ -147,6 +153,26 @@ def register_default_hooks_eval(self):
        self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end")

    def _hook_on_data_parallel_init(self, ctx):
        """
        Note:
          The modified attributes and corresponding operations are shown
          below; further modifications should be made to ``ctx.model`` or
          other objects accordingly:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.model``                       Wrap ``nn.Module`` in \
            ``nn.DataParallel``
            ==================================  ===========================
        """
        if isinstance(ctx.model, torch.nn.DataParallel):
            return

        if len(ctx.cfg.train.data_para_dids):
            ctx.model = \
                torch.nn.DataParallel(ctx.model,
                                      device_ids=ctx.cfg.train.data_para_dids)

    def _hook_on_fit_start_init(self, ctx):
        """
        Note:
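In all three registrations the data-parallel hook is placed ahead of `_hook_on_fit_start_init` on the same `on_fit_start` trigger, which suggests the model is wrapped before any fit-start setup touches it. A toy sketch of that ordering idea; the `register_hook`/`hooks` names and the assumption that same-trigger hooks fire in registration order are illustrative, not the trainer's actual internals:

    # Illustrative only: hooks registered under one trigger are assumed to
    # fire in registration order, so the DataParallel wrap runs first.
    hooks = {"on_fit_start": []}

    def register_hook(trigger, hook):
        hooks[trigger].append(hook)

    def data_parallel_init(ctx):  # stand-in for _hook_on_data_parallel_init
        print("wrap ctx.model in torch.nn.DataParallel")

    def fit_start_init(ctx):  # stand-in for _hook_on_fit_start_init
        print("prepare model, optimizer and statistics")

    register_hook("on_fit_start", data_parallel_init)
    register_hook("on_fit_start", fit_start_init)

    for hook in hooks["on_fit_start"]:  # prints the wrap message first
        hook(ctx=None)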

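For readers unfamiliar with the wrapper, a self-contained sketch of what the hook effectively does to the model; the toy `nn.Linear` and the two device ids are assumptions for illustration:

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 4)   # stand-in for ctx.model
    data_para_dids = [0, 1]    # assumed value of cfg.train.data_para_dids

    # Mirror the hook: wrap at most once, and only when ids are configured.
    if data_para_dids and not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model, device_ids=data_para_dids)

    # Batches fed to model(...) are split along dim 0 across the GPUs;
    # the underlying module stays reachable as model.module.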