diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py
index 9836fe36df..baba638bad 100644
--- a/mmaction/core/evaluation/eval_hooks.py
+++ b/mmaction/core/evaluation/eval_hooks.py
@@ -3,7 +3,9 @@
 import warnings
 from math import inf
 
+import torch.distributed as dist
 from mmcv.runner import Hook
+from torch.nn.modules.batchnorm import _BatchNorm
 from torch.utils.data import DataLoader
 
 
@@ -285,6 +287,9 @@ class DistEvalHook(EvalHook):
             processes. Default: None.
         gpu_collect (bool): Whether to use gpu or cpu to collect results.
             Default: False.
+        broadcast_bn_buffer (bool): Whether to broadcast the buffers
+            (running_mean and running_var) of rank 0 to the other ranks
+            before evaluation. Default: True.
         **eval_kwargs: Evaluation arguments fed into the evaluate function
             of the dataset.
     """
@@ -296,6 +301,7 @@ def __init__(self,
                  by_epoch=True,
                  save_best='auto',
                  rule=None,
+                 broadcast_bn_buffer=True,
                  tmpdir=None,
                  gpu_collect=False,
                  **eval_kwargs):
@@ -307,10 +313,25 @@ def __init__(self,
             save_best=save_best,
             rule=rule,
             **eval_kwargs)
+        self.broadcast_bn_buffer = broadcast_bn_buffer
         self.tmpdir = tmpdir
         self.gpu_collect = gpu_collect
 
     def _do_evaluate(self, runner):
+        """Perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffers (running_mean and
+        # running_var) is not supported by PyTorch's DDP, which may
+        # cause inconsistent model performance across ranks, so we
+        # broadcast the buffers of rank 0 to the other ranks before
+        # evaluation to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
         if not self.evaluation_flag(runner):
             return
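
For reference, here is a minimal standalone sketch of the broadcast step this patch adds. It assumes a process group has already been initialized (e.g. via `dist.init_process_group`) and that `model` is the (possibly DDP-wrapped) network; the helper name `broadcast_bn_buffers` and the `src_rank` parameter are illustrative, not part of the patch itself.

```python
import torch.distributed as dist
from torch.nn.modules.batchnorm import _BatchNorm


def broadcast_bn_buffers(model, src_rank=0):
    """Copy BatchNorm running statistics from ``src_rank`` to all ranks.

    DDP synchronizes gradients but not these buffers, so running_mean
    and running_var can drift apart across processes during training;
    broadcasting rank 0's values makes every rank evaluate the same
    model. Hypothetical helper, mirroring the logic in the diff above.
    """
    for _, module in model.named_modules():
        if isinstance(module, _BatchNorm) and module.track_running_stats:
            dist.broadcast(module.running_var, src_rank)
            dist.broadcast(module.running_mean, src_rank)
```

Such a helper would be called on every rank before the distributed evaluation loop, which is exactly where the patch places the equivalent inline code: at the top of `_do_evaluate`, guarded by the new `broadcast_bn_buffer` flag.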