diff --git a/mammoth/trainer.py b/mammoth/trainer.py index c1e447a5..8dfd4efc 100644 --- a/mammoth/trainer.py +++ b/mammoth/trainer.py @@ -311,7 +311,8 @@ def train( if device_context.is_master(): sampled_task_counts = self.task_queue_manager.sampled_task_counts else: - sampled_task_counts = None + from collections import Counter + sampled_task_counts = Counter() report_stats = self._maybe_report_training( step, train_steps, diff --git a/mammoth/utils/report_manager.py b/mammoth/utils/report_manager.py index 9a7d4a91..08911124 100644 --- a/mammoth/utils/report_manager.py +++ b/mammoth/utils/report_manager.py @@ -59,7 +59,7 @@ def report_training( patience, report_stats, multigpu=False, - sampled_task_counts=None, + sampled_task_counts={}, optimizer=None, ): """ diff --git a/mammoth/utils/statistics.py b/mammoth/utils/statistics.py index 0c680360..e5d8ce6c 100644 --- a/mammoth/utils/statistics.py +++ b/mammoth/utils/statistics.py @@ -34,7 +34,7 @@ def __init__(self, loss=0, n_words=0, n_correct=None): self.magnitude_denom = 0 self.param_magnitudes = Counter() self.grad_magnitudes = Counter() - + @classmethod def from_loss_logits_target(cls, loss: float, logits, target, padding_idx): """