Commit ca3e2ad (1 parent: cfd7c67)

improve unit tests for BART

1 file changed (+36 −13 lines)

tests/test_swag_bart.py (+36 −13)
@@ -1,10 +1,11 @@
 import logging
+import os
 import unittest
 import tempfile
 
 import torch
 
-from transformers import AutoTokenizer, BartForConditionalGeneration
+from transformers import AutoTokenizer, BartForConditionalGeneration, GenerationConfig
 
 from swag_transformers.swag_bart import SwagBartConfig, SwagBartModel, SwagBartPreTrainedModel, \
     SwagBartForConditionalGeneration
@@ -33,45 +34,67 @@ def test_untrained(self):
         logging.debug(out)
         self.assertEqual(out.last_hidden_state.shape, (1, 2, hidden_size))
 
-    def test_pretrained_bart_generative(self):
+    def pretrained_bart_generative(self, no_cov_mat):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model = BartForConditionalGeneration.from_pretrained(self.pretrained_model_name)
         model.to(device)
         self.assertEqual(model.device.type, device)
-        swag_model = SwagBartForConditionalGeneration.from_base(model)
+        swag_model = SwagBartForConditionalGeneration.from_base(model, no_cov_mat=no_cov_mat)
         swag_model.to(device)
         self.assertEqual(swag_model.device.type, device)
-        tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name)
+        tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name, clean_up_tokenization_spaces=False)
+
+        gen_config = GenerationConfig.from_model_config(model.config)
+        logging.warning(gen_config)
+        gen_config.max_new_tokens = 10
+        logging.warning(gen_config)
 
         swag_model.swag.collect_model(model)
         swag_model.sample_parameters()
+        # has to be updated manually when using collect_model directly
+        swag_model.config.cov_mat_rank = swag_model.swag.cov_mat_rank
 
         # Test forward
-        base_out = model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]]))
-        out = swag_model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]]))
-        self.assertTrue(torch.allclose(base_out.logits, out.logits))
+        base_fwd_out = model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]]))
+        swag_fwd_out = swag_model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]]))
+        self.assertTrue(torch.allclose(base_fwd_out.logits, swag_fwd_out.logits))
 
         # Test generate
         example = "I have no BART and I must generate"
+        torch.manual_seed(123)
         batch = tokenizer(example, return_tensors="pt")
-        base_generated_ids = model.generate(batch["input_ids"])
+        base_generated_ids = model.generate(batch["input_ids"], generation_config=gen_config)
+        # max_length=20, num_beams=1, do_sample=False, early_stopping=False
         base_out = tokenizer.batch_decode(base_generated_ids, skip_special_tokens=True)
-        generated_ids = swag_model.generate(batch["input_ids"])
+        logging.warning(base_out)
+
+        generated_ids = swag_model.generate(batch["input_ids"], generation_config=gen_config)
         out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        logging.info(base_out)
-        logging.info(out)
+        logging.warning(out)
         self.assertEqual(base_out, out)
 
         # Test saving & loading
         with tempfile.TemporaryDirectory() as tempdir:
             swag_model.save_pretrained(tempdir)
+            logging.warning(os.listdir(tempdir))
+            with open(os.path.join(tempdir, 'config.json'), 'r') as fobj:
+                logging.warning(fobj.read())
             stored_model = SwagBartForConditionalGeneration.from_pretrained(tempdir).to(device)
 
-            generated_ids = stored_model.generate(batch["input_ids"])
+            stored_fwd_out = stored_model.forward(
+                input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]]))
+            self.assertTrue(torch.allclose(swag_fwd_out.logits, stored_fwd_out.logits))
+
+            generated_ids = stored_model.generate(batch["input_ids"], generation_config=gen_config)
             out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-            logging.info(out)
             self.assertEqual(base_out, out)
 
+
+    def test_pretrained_bart_generative_no_cov(self):
+        self.pretrained_bart_generative(no_cov_mat=True)
+
+    def test_pretrained_bart_generative_with_cov(self):
+        self.pretrained_bart_generative(no_cov_mat=False)
+
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
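
For reference, the core pattern the updated test exercises is collect → sample → sync config, followed by a pinned-down generation setup. Below is a minimal sketch of that flow, assuming the swag_transformers API exactly as used in the diff; the checkpoint name is illustrative only (the test reads it from self.pretrained_model_name):

import torch
from transformers import BartForConditionalGeneration, GenerationConfig
from swag_transformers.swag_bart import SwagBartForConditionalGeneration

# Illustrative checkpoint, not part of the commit.
base = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
swag_model = SwagBartForConditionalGeneration.from_base(base, no_cov_mat=False)

swag_model.swag.collect_model(base)   # record one snapshot of the base weights
swag_model.sample_parameters()        # draw a weight sample from the SWAG posterior

# As the comment added in the diff notes, calling collect_model directly
# does not sync the covariance rank into the config, so do it manually:
swag_model.config.cov_mat_rank = swag_model.swag.cov_mat_rank

# Short, deterministic generation settings so base and SWAG outputs
# can be compared reproducibly:
gen_config = GenerationConfig.from_model_config(base.config)
gen_config.max_new_tokens = 10
torch.manual_seed(123)

Running python -m unittest tests.test_swag_bart would then exercise both variants added by this commit, with and without the low-rank covariance matrix.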
