Skip to content

Commit

Permalink
add linting changes required by linter
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 committed Nov 1, 2023
1 parent f24ea77 commit 6ae9a9a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import numpy as np
import pandas as pd
import torch
import transformers
from exp_outputs import (
config_dllogger,
creates_output_dir,
Expand All @@ -33,14 +32,15 @@
from merlin.io import Dataset
from merlin.schema import Tags
from transf_exp_args import DataArguments, ModelArguments, TrainingArguments
from transformers import HfArgumentParser, set_seed
from transformers.trainer_utils import is_main_process

import transformers
import transformers4rec.torch as t4r
from merlin_standard_lib import Schema
from transformers import HfArgumentParser, set_seed
from transformers4rec.torch import Trainer
from transformers4rec.torch.utils.data_utils import MerlinDataLoader
from transformers4rec.torch.utils.examples_utils import wipe_memory
from transformers.trainer_utils import is_main_process

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -224,7 +224,7 @@ def mask_last_interaction(x):
logger.info(f"Recall@10 of manually masked test data = {str(recall_10)}")
output_file = os.path.join(training_args.output_dir, "eval_results_over_time.txt")
with open(output_file, "a") as writer:
writer.write(f"\n***** Recall@10 of simulated inference = {recall_10} *****\n")
writer.write(f"\n***** Recall@10 of simulated inference = {recall_10} *****\n")
# Verify that the recall@10 from train.evaluate() matches the recall@10 calculated manually
if not isinstance(input_module.masking, t4r.masking.PermutationLanguageModeling):
# TODO fix inference discrepancy for permutation language modeling
Expand Down
2 changes: 1 addition & 1 deletion transformers4rec/torch/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def forward(self, inputs, training=False, testing=False, **kwargs):
output = seq_rep + context_rep
else:
raise ValueError(
f"The aggregation {self.fusion_aggregation} is not supported,"
f"The aggregation {self.fusion_aggregation} is not supported, "
f"please select one of the following aggregations "
f"['concat', 'elementwise-mul', 'elementwise-sum']"
)
Expand Down

0 comments on commit 6ae9a9a

Please sign in to comment.