Skip to content

Commit

Permalink
add bn buffer sync in eval_hook (open-mmlab#675)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin authored Mar 4, 2021
1 parent ea0e722 commit b002291
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions mmaction/core/evaluation/eval_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import warnings
from math import inf

import torch.distributed as dist
from mmcv.runner import Hook
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader


Expand Down Expand Up @@ -285,6 +287,9 @@ class DistEvalHook(EvalHook):
processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results.
Default: False.
broadcast_bn_buffer (bool): Whether to broadcast the BatchNorm
buffers (running_mean and running_var) of rank 0 to the other
ranks before evaluation. Default: True.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
"""
Expand All @@ -296,6 +301,7 @@ def __init__(self,
by_epoch=True,
save_best='auto',
rule=None,
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
**eval_kwargs):
Expand All @@ -307,10 +313,25 @@ def __init__(self,
save_best=save_best,
rule=rule,
**eval_kwargs)
self.broadcast_bn_buffer = broadcast_bn_buffer
self.tmpdir = tmpdir
self.gpu_collect = gpu_collect

def _do_evaluate(self, runner):
"""perform evaluation and save ckpt."""
# PyTorch's DDP does not synchronize BatchNorm's buffers
# (running_mean and running_var), which may cause models on
# different ranks to perform inconsistently, so we broadcast
# the BatchNorm buffers of rank 0 to the other ranks to
# avoid this.
if self.broadcast_bn_buffer:
model = runner.model
for name, module in model.named_modules():
if isinstance(module,
_BatchNorm) and module.track_running_stats:
dist.broadcast(module.running_var, 0)
dist.broadcast(module.running_mean, 0)

if not self.evaluation_flag(runner):
return

Expand Down

0 comments on commit b002291

Please sign in to comment.