Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Attieh authored and Joseph Attieh committed Oct 7, 2024
1 parent b234e2e commit 70bc59d
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mammoth/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from functools import partial
from torch.nn.init import xavier_uniform_
from typing import Optional, List, Dict, Tuple
from mammoth.modules.x_transformers import TransformerWrapper
from mammoth.modules.x_tf import TransformerWrapper
from x_transformers.x_transformers import TokenEmbedding

from mammoth.distributed.components import (
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions mammoth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,11 @@ def validate(self, valid_iter, moving_average=None, task=None):
for batch, metadata, _ in valid_iter:
if stats is None:
stats = mammoth.utils.Statistics()

src, src_lengths = batch.src if isinstance(batch.src, tuple) else (batch.src, None)
decoder_input = batch.tgt[:-1]
target = batch.tgt[1:]

with torch.cuda.amp.autocast(enabled=self.optim.amp):
# F-prop through the model.
logits, decoder_output = valid_model(
Expand Down

0 comments on commit 70bc59d

Please sign in to comment.