Skip to content

Commit

Permalink
Fix a bug with report_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Attieh committed Dec 9, 2024
1 parent feed7ef commit 75d2670
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion mammoth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ def train(
if device_context.is_master():
sampled_task_counts = self.task_queue_manager.sampled_task_counts
else:
sampled_task_counts = None
from collections import Counter
sampled_task_counts = Counter()
report_stats = self._maybe_report_training(
step,
train_steps,
Expand Down
2 changes: 1 addition & 1 deletion mammoth/utils/report_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def report_training(
patience,
report_stats,
multigpu=False,
sampled_task_counts=None,
sampled_task_counts={},
optimizer=None,
):
"""
Expand Down
2 changes: 1 addition & 1 deletion mammoth/utils/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, loss=0, n_words=0, n_correct=None):
self.magnitude_denom = 0
self.param_magnitudes = Counter()
self.grad_magnitudes = Counter()

@classmethod
def from_loss_logits_target(cls, loss: float, logits, target, padding_idx):
"""
Expand Down

0 comments on commit 75d2670

Please sign in to comment.