
Commit e971f28

improve test code and transformers support (#5)
* add setting random seed to examples
* add support for transformers >= 4.43
* fix tests
1 parent 627ee92 commit e971f28

6 files changed (+21, -18 lines)

examples/bert_snli.py (+4)
@@ -27,13 +27,17 @@ def main():
                         help="number of steps between collecting parameters; set to zero for per epoch updates")
     parser.add_argument("--learning-rate", type=float, default=2e-5, help="learning rate")
     parser.add_argument("--swag-modules", type=str, action='append', help="restrict SWAG to modules matching given prefix(es)")
+    parser.add_argument("--seed", type=int, default=None, help="set random seed")
     args = parser.parse_args()
 
     if args.device:
         device = args.device
     else:
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
+    if args.seed is not None:
+        transformers.set_seed(args.seed)
+
     tokenizer = transformers.AutoTokenizer.from_pretrained(args.base_model, cache_dir=args.model_cache_dir)
     model = transformers.AutoModelForSequenceClassification.from_pretrained(
         args.base_model, num_labels=3, cache_dir=args.model_cache_dir)
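
The new --seed option feeds transformers.set_seed, which seeds Python's random module, NumPy and PyTorch (including CUDA) in one call, so repeated runs of the example produce the same batches and initial states. A minimal sketch of the same call outside the script (the seed value is arbitrary):

import torch
import transformers

transformers.set_seed(42)   # seeds random, numpy, torch and torch.cuda
first = torch.rand(2)

transformers.set_seed(42)
second = torch.rand(2)

print(torch.equal(first, second))   # True: the draws are reproducible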

examples/marian_mt.py (+4)
@@ -24,13 +24,17 @@ def main():
     parser.add_argument("--collect-steps", type=int, default=100, help="number of steps between collecting parameters")
     parser.add_argument("--learning-rate", type=float, default=2e-5, help="learning rate")
     parser.add_argument("--swag-modules", type=str, action='append', help="restrict SWAG to modules matching given prefix(es)")
+    parser.add_argument("--seed", type=int, default=None, help="set random seed")
     args = parser.parse_args()
 
     if args.device:
         device = args.device
     else:
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
+    if args.seed is not None:
+        transformers.set_seed(args.seed)
+
     tokenizer = transformers.AutoTokenizer.from_pretrained(args.base_model)
     model = transformers.MarianMTModel.from_pretrained(args.base_model)
     model.to(device)

setup.py (+2, -3)
@@ -18,9 +18,8 @@
     packages=find_packages(where="src"),
     package_dir={"": "src"},
     install_requires=[
-        "transformers>=4.30",
-        "transformers[torch]>=4.30,<4.43",
-        "swa_gaussian>=0.1.8"
+        "transformers[torch]>=4.30",
+        "swa_gaussian>=0.1.9"
     ],
     extras_require={
         "test": ["datasets", "pytest", "sentencepiece"]

tests/test_swag_bart.py (+9, -13)
@@ -13,8 +13,8 @@
 
 class TestSwagBart(unittest.TestCase):
 
-    pretrained_model_name = 'Finnish-NLP/bart-small-finnish'
-    # pretrained_model_name = 'sshleifer/bart-tiny-random'
+    # pretrained_model_name = 'Finnish-NLP/bart-small-finnish'
+    pretrained_model_name = 'sshleifer/bart-tiny-random'
 
     def test_untrained(self):
         hidden_size = 240
@@ -45,12 +45,11 @@ def pretrained_bart_generative(self, no_cov_mat):
         tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name, clean_up_tokenization_spaces=False)
 
         gen_config = GenerationConfig.from_model_config(model.config)
-        logging.warning(gen_config)
         gen_config.max_new_tokens = 10
-        logging.warning(gen_config)
+        logging.debug(gen_config)
 
         swag_model.swag.collect_model(model)
-        swag_model.sample_parameters(cov=not no_cov_mat)
+        swag_model.sample_parameters(cov=not no_cov_mat, seed=1234)
         # has to be updated manually when using collect_model directly
         swag_model.config.cov_mat_rank = swag_model.swag.cov_mat_rank
 
@@ -61,29 +60,26 @@ def pretrained_bart_generative(self, no_cov_mat):
 
         # Test generate
         example = "I have no BART and I must generate"
-        torch.manual_seed(123)
         batch = tokenizer(example, return_tensors="pt")
         base_generated_ids = model.generate(batch["input_ids"], generation_config=gen_config)
-        # max_length=20, num_beams=1, do_sample=False, early_stopping=False
         base_out = tokenizer.batch_decode(base_generated_ids, skip_special_tokens=True)
-        logging.warning(base_out)
 
         generated_ids = swag_model.generate(batch["input_ids"], generation_config=gen_config)
         out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        logging.warning(out)
         self.assertEqual(base_out, out)
 
         # Test saving & loading
         with tempfile.TemporaryDirectory() as tempdir:
             swag_model.save_pretrained(tempdir)
-            logging.warning(os.listdir(tempdir))
-            with open(os.path.join(tempdir, 'config.json'), 'r') as fobj:
-                logging.warning(fobj.read())
+            logging.debug(os.listdir(tempdir))
+            with open(os.path.join(tempdir, 'config.json'), 'r', encoding='utf8') as fobj:
+                logging.debug(fobj.read())
             stored_model = SwagBartForConditionalGeneration.from_pretrained(tempdir).to(device)
 
+            stored_model.sample_parameters(cov=not no_cov_mat, seed=1234)
             stored_fwd_out = stored_model.forward(
                 input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]]))
-            self.assertTrue(torch.allclose(swag_fwd_out.logits, stored_fwd_out.logits))
+            self.assertTrue(torch.allclose(swag_fwd_out.logits, stored_fwd_out.logits, atol=1e-06))
 
             generated_ids = stored_model.generate(batch["input_ids"], generation_config=gen_config)
             out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
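
Two changes above work together: sample_parameters now takes an explicit seed=1234 so the in-memory model and the model reloaded from disk draw the same SWAG weight sample, and the logits comparison gains atol=1e-06 so that tiny numeric drift across the save/load round trip does not fail the test. A toy illustration of both points (the tensors are made up, not taken from the test):

import torch

# Same seed, same random draw: this is what seed=1234 guarantees for the two models.
torch.manual_seed(1234)
sample_one = torch.randn(3)
torch.manual_seed(1234)
sample_two = torch.randn(3)
print(torch.equal(sample_one, sample_two))   # True

# A small absolute tolerance absorbs drift that the default atol=1e-8 would reject.
a = torch.tensor([0.0, 1.0e-7, 2.0])     # "original" logits
b = torch.tensor([5.0e-7, 0.0, 2.0])     # "reloaded" logits with tiny drift
print(torch.allclose(a, b))               # False under the default tolerances
print(torch.allclose(a, b, atol=1e-06))   # True once drift below 1e-6 is accepted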

tests/test_swag_bert.py (+1, -1)
@@ -203,7 +203,7 @@ def tokenize_function(example):
             training_args,
             train_dataset=tokenized_datasets["train"],
             data_collator=data_collator,
-            tokenizer=tokenizer,
+            processing_class=tokenizer,
             callbacks=[SwagUpdateCallback(swag_model)]
         )
         trainer.train()
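
Recent transformers releases deprecate the Trainer's tokenizer argument in favour of processing_class, which is why the tests switch keywords here and in test_swag_marian.py below. A minimal sketch of the new call pattern (the tiny checkpoint and the two-example dataset are placeholders, not taken from this repository):

from torch.utils.data import Dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)


class ToyDataset(Dataset):
    """Two hand-made examples, just enough to run a single training step."""

    def __init__(self, tokenizer):
        enc = tokenizer(["first toy sentence", "second toy sentence"], truncation=True)
        self.items = [
            {"input_ids": enc["input_ids"][i],
             "attention_mask": enc["attention_mask"][i],
             "labels": i}
            for i in range(2)
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


name = "prajjwal1/bert-tiny"  # assumed small checkpoint, only to keep the sketch fast
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=2)

trainer = Trainer(
    model,
    TrainingArguments(output_dir="toy-out", max_steps=1, report_to=[]),
    train_dataset=ToyDataset(tokenizer),
    processing_class=tokenizer,   # formerly tokenizer=tokenizer
)
trainer.train()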

tests/test_swag_marian.py (+1, -1)
@@ -185,7 +185,7 @@ def tokenize_function(example):
             training_args,
             train_dataset=tokenized_datasets["train"],
             data_collator=data_collator,
-            tokenizer=tokenizer,
+            processing_class=tokenizer,
             callbacks=[SwagUpdateCallback(swag_model, collect_steps=2)]
         )
         trainer.train()
