Skip to content

Commit

Permalink
Merge pull request #19 from Helsinki-NLP/fix/undo_forwardhook_sideeffects
Browse files Browse the repository at this point in the history

undo forward-hook side-effect
  • Loading branch information
TimotheeMickus authored Sep 25, 2023
2 parents 3b1678e + 8d5b19e commit e46eed2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 3 additions & 1 deletion onmt/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,9 @@ def build_task_specific_model(
def has_grad_hook(module, input, output) -> None:
for param in module.parameters(recurse=False):
if param.requires_grad:
param.has_grad = True
# NB: we're looking at whether gradient will/has been computed, which is only the
# case when the module is training.
param.has_grad = module.training

for module in nmt_model.modules():
module.register_forward_hook(has_grad_hook)
Expand Down
4 changes: 4 additions & 0 deletions onmt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,10 @@ def validate(self, valid_iter, moving_average=None, task=None):
# Set model back to training mode.
valid_model.train()

for p in self.model.parameters():
if hasattr(p, 'has_grad'):
p.has_grad = False

return stats

def _gradient_accumulation_over_lang_pairs(
Expand Down

0 comments on commit e46eed2

Please sign in to comment.