From 70bc59dad641cb65b0372510558ecf2396c705ca Mon Sep 17 00:00:00 2001 From: Joseph Attieh Date: Mon, 7 Oct 2024 14:33:06 +0300 Subject: [PATCH] changes --- mammoth/model_builder.py | 2 +- mammoth/modules/{x_transformers.py => x_tf.py} | 0 mammoth/trainer.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) rename mammoth/modules/{x_transformers.py => x_tf.py} (100%) diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py index 02619175..97d08281 100644 --- a/mammoth/model_builder.py +++ b/mammoth/model_builder.py @@ -8,7 +8,7 @@ from functools import partial from torch.nn.init import xavier_uniform_ from typing import Optional, List, Dict, Tuple -from mammoth.modules.x_transformers import TransformerWrapper +from mammoth.modules.x_tf import TransformerWrapper from x_transformers.x_transformers import TokenEmbedding from mammoth.distributed.components import ( diff --git a/mammoth/modules/x_transformers.py b/mammoth/modules/x_tf.py similarity index 100% rename from mammoth/modules/x_transformers.py rename to mammoth/modules/x_tf.py diff --git a/mammoth/trainer.py b/mammoth/trainer.py index c8a5be9e..6991355d 100644 --- a/mammoth/trainer.py +++ b/mammoth/trainer.py @@ -382,11 +382,11 @@ def validate(self, valid_iter, moving_average=None, task=None): for batch, metadata, _ in valid_iter: if stats is None: stats = mammoth.utils.Statistics() - + src, src_lengths = batch.src if isinstance(batch.src, tuple) else (batch.src, None) decoder_input = batch.tgt[:-1] target = batch.tgt[1:] - + with torch.cuda.amp.autocast(enabled=self.optim.amp): # F-prop through the model. logits, decoder_output = valid_model(