diff --git a/mammoth/trainer.py b/mammoth/trainer.py
index 0572385e..24d3fb00 100644
--- a/mammoth/trainer.py
+++ b/mammoth/trainer.py
@@ -276,6 +276,8 @@ def train(
                     continue
                 # logger.warning(f'Syncing {component.get_name()}')  # DEBUG
                 params = component.named_parameters(self.model)
+                # gradient_sync.gradient_norm counts the number of devices that trained this component;
+                # it does not normalize by the number of masked tokens
                 mammoth.distributed.externally_managed_reduce_and_rescale_grads(
                     named_parameters=params,
                     has_local_gradient=gradient_sync.has_local_gradient,
@@ -441,9 +443,9 @@ def _gradient_accumulation(
             # update data state
             self._data_state[metadata.corpus_id] = batch.line_idx

-            num_tokens = batch.tgt.mask.sum()
+            num_tokens = batch.tgt.mask.sum().item()
             if self.norm_method == "tokens":
-                normalization += num_tokens.item()
+                normalization += num_tokens
             else:
                 normalization += batch.batch_size
             report_stats.n_src_words += batch.src.mask.sum().item()
diff --git a/mammoth/utils/statistics.py b/mammoth/utils/statistics.py
index 02808801..ca3a0021 100644
--- a/mammoth/utils/statistics.py
+++ b/mammoth/utils/statistics.py
@@ -42,11 +42,14 @@ def from_loss_logits_target(cls, loss: float, logits, target, padding_idx):
         loss, model prediction logits, and target indices.
         Note that this is heavy. Only use for validation / debug purposes.
         """
-        pred = logits.max(1)[1]
+        target = target.squeeze(-1)
+        pred = logits.max(dim=-1).indices
+        correct = pred.eq(target)
         non_padding = target.ne(padding_idx)
-        num_correct = pred.eq(target).masked_select(non_padding).sum().item()
+        correct_not_padded = correct.masked_select(non_padding)
+        num_correct = correct_not_padded.sum().item()
         num_non_padding = non_padding.sum().item()
-        cls(loss, num_non_padding, num_correct)
+        return cls(loss, num_non_padding, num_correct)

     @staticmethod
     def all_gather_stats(stat, max_size=4096):
@@ -199,8 +202,12 @@ def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
         """display statistics to tensorboard"""
         t = self.elapsed_time()
         writer.add_scalar(prefix + "/xent", self.xent(), step)
-        writer.add_scalar(prefix + "/ppl", self.ppl(), step)
-        writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
+        ppl = self.ppl()
+        if ppl is not None:
+            writer.add_scalar(prefix + "/ppl", ppl, step)
+        acc = self.accuracy()
+        if acc is not None:
+            writer.add_scalar(prefix + "/accuracy", acc, step)
         writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
         # writer.add_scalar(prefix + "/lr", learning_rate, step)
         if patience is not None:
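
For reference, a minimal standalone sketch of the corrected accuracy bookkeeping in `from_loss_logits_target`. The `Statistics` stand-in, tensor shapes, and toy values below are illustrative assumptions, not the actual mammoth class:

```python
# Sketch only: a stand-in dataclass instead of mammoth.utils.statistics.Statistics.
from dataclasses import dataclass

import torch


@dataclass
class Statistics:
    loss: float
    n_words: int      # number of non-padding target tokens
    n_correct: int    # number of correctly predicted non-padding tokens


def from_loss_logits_target(loss, logits, target, padding_idx):
    # logits: (positions, vocab); target: (positions, 1) or (positions,)
    target = target.squeeze(-1)
    pred = logits.max(dim=-1).indices        # argmax over the vocabulary dimension
    correct = pred.eq(target)
    non_padding = target.ne(padding_idx)     # mask out padding positions
    num_correct = correct.masked_select(non_padding).sum().item()
    num_non_padding = non_padding.sum().item()
    return Statistics(loss, num_non_padding, num_correct)  # the missing `return` was the bug


if __name__ == "__main__":
    logits = torch.randn(6, 10)                             # 6 positions, vocab of 10
    target = torch.tensor([[1], [3], [0], [0], [7], [2]])   # 0 is padding in this toy example
    stats = from_loss_logits_target(1.23, logits, target, padding_idx=0)
    print(stats.n_words, stats.n_correct)
```

The guards added to `log_tensorboard` pair with this: if `ppl()` or `accuracy()` comes back as `None` (e.g. when no such statistics were accumulated), the corresponding scalar is skipped rather than passed to the writer.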