Skip to content

Commit

Permalink
fix deprecation warning and pytest issue in unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
svirpioj committed Jun 26, 2024
1 parent b586eb9 commit 071cd1a
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 12 deletions.
6 changes: 1 addition & 5 deletions examples/marian_mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ def tokenize_function(example):
inputs = [pair['de'] for pair in example['translation']]
targets = [pair['nl'] for pair in example['translation']]
model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

-# Setup the tokenizer for targets
-with tokenizer.as_target_tokenizer():
-labels = tokenizer(targets, max_length=max_target_length, truncation=True)
-
+labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs

Expand Down
3 changes: 2 additions & 1 deletion tests/test_swag_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def test_pretrained_bert_tiny_classifier_test(self):
logging.debug(out)
self.assertEqual(out.logits.shape, (1, num_labels))

-def _data_gen(self):
+@staticmethod
+def _data_gen():
yield {"text": "Hello world", "label": 0}
yield {"text": "Just some swaggering", "label": 1}
yield {"text": "Have a good day", "label": 0}
Expand Down
9 changes: 3 additions & 6 deletions tests/test_swag_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def test_pretrained_marian_tiny_test(self):
self.assertGreater(len(output), 0)
self.assertEqual(base_output, output)

-def _data_gen(self):
+@staticmethod
+def _data_gen():
yield {"source": "India and Japan prime ministers meet in Tokyo",
"target": "Die Premierminister Indiens und Japans trafen sich in Tokio."}
yield {"source": "High on the agenda are plans for greater nuclear co-operation.",
Expand Down Expand Up @@ -136,11 +137,7 @@ def tokenize_function(example):
inputs = example['source']
targets = example['target']
model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

-# Setup the tokenizer for targets
-with tokenizer.as_target_tokenizer():
-labels = tokenizer(targets, max_length=max_target_length, truncation=True)
-
+labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs

Expand Down

0 comments on commit 071cd1a

Please sign in to comment.