Bugfix to --report_training_accuracy
Waino committed Sep 30, 2024
1 parent 43c71e8 commit ad9de67
Showing 2 changed files with 16 additions and 7 deletions.
6 changes: 4 additions & 2 deletions mammoth/trainer.py
@@ -276,6 +276,8 @@ def train(
                     continue
                 # logger.warning(f'Syncing {component.get_name()}') # DEBUG
                 params = component.named_parameters(self.model)
+                # gradient_sync.gradient_norm counts the number of devices that trained this component
+                # this does not normalize by the number of masked tokens
                 mammoth.distributed.externally_managed_reduce_and_rescale_grads(
                     named_parameters=params,
                     has_local_gradient=gradient_sync.has_local_gradient,
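
Note: a minimal standalone sketch (hypothetical helper, not the mammoth API) of the rescaling semantics the new comments document. Dividing the summed gradient by the count of contributing devices averages over devices; it does not weight each device by how many non-masked tokens it processed.

import torch

def rescale_by_device_count(grad_sum: torch.Tensor, gradient_norm: int) -> torch.Tensor:
    # gradient_norm plays the role described in the comments above: the number
    # of devices that produced a gradient for this component.
    return grad_sum / max(gradient_norm, 1)
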
@@ -441,9 +443,9 @@ def _gradient_accumulation(
             # update data state
             self._data_state[metadata.corpus_id] = batch.line_idx
 
-            num_tokens = batch.tgt.mask.sum()
+            num_tokens = batch.tgt.mask.sum().item()
             if self.norm_method == "tokens":
-                normalization += num_tokens.item()
+                normalization += num_tokens
             else:
                 normalization += batch.batch_size
             report_stats.n_src_words += batch.src.mask.sum().item()
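Note: a toy sketch of the corrected accumulation above. Calling .item() once up front makes num_tokens a plain Python int, so normalization stays a Python number on both branches (presumably the old code left num_tokens as an unused 0-dim tensor whenever norm_method was not "tokens").

import torch

tgt_mask = torch.tensor([[True, True, False], [True, False, False]])  # stand-in for batch.tgt.mask
num_tokens = tgt_mask.sum().item()  # 3, a plain Python int

normalization = 0
norm_method = "tokens"  # stand-in for self.norm_method
if norm_method == "tokens":
    normalization += num_tokens
else:
    normalization += 2  # stand-in for batch.batch_size
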
17 changes: 12 additions & 5 deletions mammoth/utils/statistics.py
@@ -42,11 +42,14 @@ def from_loss_logits_target(cls, loss: float, logits, target, padding_idx):
         loss, model prediction logits, and target indices.
         Note that this is heavy. Only use for validation / debug purposes.
         """
-        pred = logits.max(1)[1]
+        target = target.squeeze(-1)
+        pred = logits.max(dim=-1).indices
+        correct = pred.eq(target)
         non_padding = target.ne(padding_idx)
-        num_correct = pred.eq(target).masked_select(non_padding).sum().item()
+        correct_not_padded = correct.masked_select(non_padding)
+        num_correct = correct_not_padded.sum().item()
         num_non_padding = non_padding.sum().item()
-        cls(loss, num_non_padding, num_correct)
+        return cls(loss, num_non_padding, num_correct)
 
     @staticmethod
     def all_gather_stats(stat, max_size=4096):
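
Note: a standalone sketch (toy tensors) of the corrected accuracy computation. The new squeeze(-1) matters because a target of shape (n, 1) against a pred of shape (n,) makes pred.eq(target) broadcast to (n, n) instead of comparing elementwise, and the added return fixes the classmethod silently producing None.

import torch

padding_idx = 0
logits = torch.tensor([[0.1, 2.0, 0.3],    # argmax -> 1
                       [1.5, 0.2, 0.1],    # argmax -> 0
                       [2.5, 0.2, 0.1]])   # argmax -> 0
target = torch.tensor([[1], [0], [2]]).squeeze(-1)  # shape (3,); middle token is padding

pred = logits.max(dim=-1).indices         # tensor([1, 0, 0])
correct = pred.eq(target)                 # tensor([True, True, False])
non_padding = target.ne(padding_idx)      # tensor([True, False, True])
num_correct = correct.masked_select(non_padding).sum().item()  # 1
num_non_padding = non_padding.sum().item()                     # 2
print(100 * num_correct / num_non_padding)                     # 50.0
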
@@ -199,8 +202,12 @@ def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
"""display statistics to tensorboard"""
t = self.elapsed_time()
writer.add_scalar(prefix + "/xent", self.xent(), step)
writer.add_scalar(prefix + "/ppl", self.ppl(), step)
writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
ppl = self.ppl()
if ppl is not None:
writer.add_scalar(prefix + "/ppl", ppl, step)
acc = self.accuracy()
if acc is not None:
writer.add_scalar(prefix + "/accuracy", acc, step)
writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
# writer.add_scalar(prefix + "/lr", learning_rate, step)
if patience is not None:
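Note: a sketch of what the guards buy (assumed behaviour: ppl() and accuracy() return None when the underlying counts were never populated, e.g. with --report_training_accuracy off). writer.add_scalar would raise on a None value, so the metric is skipped instead.

from typing import Optional

def accuracy(n_correct: Optional[int], n_words: int) -> Optional[float]:
    # hypothetical mirror of Statistics.accuracy()
    if n_correct is None:
        return None
    return 100 * (n_correct / n_words)

acc = accuracy(None, 1000)
if acc is not None:
    print("train/accuracy:", acc)  # skipped here: nothing to report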
