From 75304d4a703057c2913b30b87d341cfd0920652c Mon Sep 17 00:00:00 2001
From: Mickus Timothee
Date: Fri, 22 Sep 2023 17:59:43 +0300
Subject: [PATCH 1/5] undo forward-hook side-effect

---
 onmt/trainer.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/onmt/trainer.py b/onmt/trainer.py
index f0fa8efd..8d210df3 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -447,6 +447,10 @@ def validate(self, valid_iter, moving_average=None, task=None):
         # Set model back to training mode.
         valid_model.train()
 
+        for p in self.model.parameters():
+            if hasattr(p, 'has_grad'):
+                p.has_grad = False
+
         return stats
 
     def _gradient_accumulation_over_lang_pairs(

From 5763591fe28e2d5366936c47a30e483a67235d52 Mon Sep 17 00:00:00 2001
From: Mickus Timothee
Date: Fri, 22 Sep 2023 18:03:33 +0300
Subject: [PATCH 2/5] linting and comment

---
 onmt/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/onmt/trainer.py b/onmt/trainer.py
index 8d210df3..7edb9039 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -447,10 +447,11 @@ def validate(self, valid_iter, moving_average=None, task=None):
         # Set model back to training mode.
         valid_model.train()
 
+        # the forward hook `has_grad` was triggered, so we manually unset the flags to not fool the optim
         for p in self.model.parameters():
             if hasattr(p, 'has_grad'):
                 p.has_grad = False
-
+
         return stats
 
     def _gradient_accumulation_over_lang_pairs(

From 30ba781cf70dd660a6b6c96e658d948b48106ea2 Mon Sep 17 00:00:00 2001
From: Mickus Timothee
Date: Mon, 25 Sep 2023 10:42:13 +0300
Subject: [PATCH 3/5] Revert "linting and comment"

This reverts commit 5763591fe28e2d5366936c47a30e483a67235d52.
---
 onmt/trainer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/onmt/trainer.py b/onmt/trainer.py
index 7edb9039..8d210df3 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -447,11 +447,10 @@ def validate(self, valid_iter, moving_average=None, task=None):
         # Set model back to training mode.
         valid_model.train()
 
-        # the forward hook `has_grad` was triggered, so we manually unset the flags to not fool the optim
         for p in self.model.parameters():
             if hasattr(p, 'has_grad'):
                 p.has_grad = False
-
+
         return stats
 
     def _gradient_accumulation_over_lang_pairs(

From a98a2741611caf9e5753cff08314f6b5d77d0d78 Mon Sep 17 00:00:00 2001
From: Mickus Timothee
Date: Mon, 25 Sep 2023 11:19:01 +0300
Subject: [PATCH 4/5] subtler fix

---
 onmt/model_builder.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/onmt/model_builder.py b/onmt/model_builder.py
index bb0f0988..69abb0a7 100644
--- a/onmt/model_builder.py
+++ b/onmt/model_builder.py
@@ -347,7 +347,9 @@ def build_task_specific_model(
     def has_grad_hook(module, input, output) -> None:
         for param in module.parameters(recurse=False):
             if param.requires_grad:
-                param.has_grad = True
+                # NB: we're looking at whether gradient will/has been computed, which is only the
+                # case when the module is training.
+                param.has_grad = module.training
 
     for module in nmt_model.modules():
         module.register_forward_hook(has_grad_hook)

From 8d5b19e70d7e407a291ab733019fce0e5977fe11 Mon Sep 17 00:00:00 2001
From: Mickus Timothee
Date: Mon, 25 Sep 2023 11:24:01 +0300
Subject: [PATCH 5/5] i love linting

---
 onmt/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onmt/trainer.py b/onmt/trainer.py
index 8d210df3..7a3dc79a 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -450,7 +450,7 @@ def validate(self, valid_iter, moving_average=None, task=None):
         for p in self.model.parameters():
             if hasattr(p, 'has_grad'):
                 p.has_grad = False
-
+
         return stats
 
     def _gradient_accumulation_over_lang_pairs(
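
For context on what the series changes: the `has_grad_hook` registered in onmt/model_builder.py flags a parameter as soon as its module runs a forward pass, so a validation pass used to set the flag even though no gradient would ever be computed; the final `param.has_grad = module.training` fix only flags parameters during training. Below is a minimal standalone sketch of that behaviour, assuming PyTorch; the `nn.Linear` stand-in and the tensor shapes are hypothetical and not part of the patched repository.

# Minimal sketch (not from the patched repository) of the forward-hook flagging behaviour.
import torch
import torch.nn as nn


def has_grad_hook(module, input, output) -> None:
    # mirror of the patched hook: only flag parameters when a gradient
    # will actually be computed, i.e. when the module is in training mode
    for param in module.parameters(recurse=False):
        if param.requires_grad:
            param.has_grad = module.training


model = nn.Linear(4, 4)  # hypothetical stand-in for an NMT sub-module
for module in model.modules():
    module.register_forward_hook(has_grad_hook)

model.eval()                      # validation-style forward pass
with torch.no_grad():
    model(torch.randn(2, 4))
print(getattr(model.weight, 'has_grad', None))  # False: a validation pass no longer flags the param

model.train()                     # training forward pass
model(torch.randn(2, 4))
print(model.weight.has_grad)      # True: the optimizer can rely on this flag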