diff --git a/examples/t4rec_paper_experiments/t4r_paper_repro/transf_exp_main.py b/examples/t4rec_paper_experiments/t4r_paper_repro/transf_exp_main.py index 5da463b16c..c7ae93c6ce 100644 --- a/examples/t4rec_paper_experiments/t4r_paper_repro/transf_exp_main.py +++ b/examples/t4rec_paper_experiments/t4r_paper_repro/transf_exp_main.py @@ -22,7 +22,6 @@ import numpy as np import pandas as pd import torch -import transformers from exp_outputs import ( config_dllogger, creates_output_dir, @@ -33,14 +32,15 @@ from merlin.io import Dataset from merlin.schema import Tags from transf_exp_args import DataArguments, ModelArguments, TrainingArguments -from transformers import HfArgumentParser, set_seed -from transformers.trainer_utils import is_main_process +import transformers import transformers4rec.torch as t4r from merlin_standard_lib import Schema +from transformers import HfArgumentParser, set_seed from transformers4rec.torch import Trainer from transformers4rec.torch.utils.data_utils import MerlinDataLoader from transformers4rec.torch.utils.examples_utils import wipe_memory +from transformers.trainer_utils import is_main_process logger = logging.getLogger(__name__) @@ -224,7 +224,7 @@ def mask_last_interaction(x): logger.info(f"Recall@10 of manually masked test data = {str(recall_10)}") output_file = os.path.join(training_args.output_dir, "eval_results_over_time.txt") with open(output_file, "a") as writer: - writer.write(f"\n***** Recall@10 of simulated inference = {recall_10} *****\n") + writer.write(f"\n***** Recall@10 of simulated inference = {recall_10} *****\n") # Verify that the recall@10 from train.evaluate() matches the recall@10 calculated manually if not isinstance(input_module.masking, t4r.masking.PermutationLanguageModeling): # TODO fix inference discrepancy for permutation language modeling diff --git a/transformers4rec/torch/experimental.py b/transformers4rec/torch/experimental.py index 38850b6c30..4631c60b9c 100644 --- a/transformers4rec/torch/experimental.py +++ b/transformers4rec/torch/experimental.py @@ -97,7 +97,7 @@ def forward(self, inputs, training=False, testing=False, **kwargs): output = seq_rep + context_rep else: raise ValueError( - f"The aggregation {self.fusion_aggregation} is not supported," + f"The aggregation {self.fusion_aggregation} is not supported, " f"please select one of the following aggregations " f"['concat', 'elementwise-mul', 'elementwise-sum']" )