diff --git a/onmt/model_builder.py b/onmt/model_builder.py
index bb0f0988..69abb0a7 100644
--- a/onmt/model_builder.py
+++ b/onmt/model_builder.py
@@ -347,7 +347,9 @@ def build_task_specific_model(
     def has_grad_hook(module, input, output) -> None:
         for param in module.parameters(recurse=False):
             if param.requires_grad:
-                param.has_grad = True
+                # NB: we're looking at whether a gradient will/has been computed, which is
+                # only the case when the module is training.
+                param.has_grad = module.training

     for module in nmt_model.modules():
         module.register_forward_hook(has_grad_hook)
diff --git a/onmt/trainer.py b/onmt/trainer.py
index f0fa8efd..7a3dc79a 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -447,6 +447,10 @@ def validate(self, valid_iter, moving_average=None, task=None):
         # Set model back to training mode.
         valid_model.train()

+        for p in self.model.parameters():
+            if hasattr(p, 'has_grad'):
+                p.has_grad = False
+
        return stats

    def _gradient_accumulation_over_lang_pairs(
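
For context, a minimal standalone sketch (not part of the patch) of how the two changes interact, using a toy nn.Linear as a hypothetical stand-in for the real nmt_model: the forward hook marks parameters only during training-mode forward passes, and the validate() change clears any stale flags afterwards.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in for nmt_model

def has_grad_hook(module, input, output) -> None:
    for param in module.parameters(recurse=False):
        if param.requires_grad:
            # A gradient will only be computed when the module is training.
            param.has_grad = module.training

for module in model.modules():
    module.register_forward_hook(has_grad_hook)

model.train()
model(torch.randn(1, 4))
print(model.weight.has_grad)   # True: forward pass ran in training mode

model.eval()
model(torch.randn(1, 4))
print(model.weight.has_grad)   # False: eval-mode pass computes no gradient

# Equivalent of the trainer.py change: reset flags after validation so a
# stale True from an earlier training step cannot leak through.
for p in model.parameters():
    if hasattr(p, 'has_grad'):
        p.has_grad = False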