|
1 | 1 | import logging
|
| 2 | +import os |
2 | 3 | import unittest
|
3 | 4 | import tempfile
|
4 | 5 |
|
5 | 6 | import torch
|
6 | 7 |
|
7 |
| -from transformers import AutoTokenizer, BartForConditionalGeneration |
| 8 | +from transformers import AutoTokenizer, BartForConditionalGeneration, GenerationConfig |
8 | 9 |
|
9 | 10 | from swag_transformers.swag_bart import SwagBartConfig, SwagBartModel, SwagBartPreTrainedModel, \
|
10 | 11 | SwagBartForConditionalGeneration
|
@@ -33,45 +34,67 @@ def test_untrained(self):
|
33 | 34 | logging.debug(out)
|
34 | 35 | self.assertEqual(out.last_hidden_state.shape, (1, 2, hidden_size))
|
35 | 36 |
|
36 |
| - def test_pretrained_bart_generative(self): |
| 37 | + def pretrained_bart_generative(self, no_cov_mat): |
37 | 38 | device = "cuda" if torch.cuda.is_available() else "cpu"
|
38 | 39 | model = BartForConditionalGeneration.from_pretrained(self.pretrained_model_name)
|
39 | 40 | model.to(device)
|
40 | 41 | self.assertEqual(model.device.type, device)
|
41 |
| - swag_model = SwagBartForConditionalGeneration.from_base(model) |
| 42 | + swag_model = SwagBartForConditionalGeneration.from_base(model, no_cov_mat=no_cov_mat) |
42 | 43 | swag_model.to(device)
|
43 | 44 | self.assertEqual(swag_model.device.type, device)
|
44 |
| - tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name) |
| 45 | + tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name, clean_up_tokenization_spaces=False) |
| 46 | + |
| 47 | + gen_config = GenerationConfig.from_model_config(model.config) |
| 48 | + logging.warning(gen_config) |
| 49 | + gen_config.max_new_tokens = 10 |
| 50 | + logging.warning(gen_config) |
45 | 51 |
|
46 | 52 | swag_model.swag.collect_model(model)
|
47 | 53 | swag_model.sample_parameters()
|
| 54 | + # has to be updated manually when using collect_model directly |
| 55 | + swag_model.config.cov_mat_rank = swag_model.swag.cov_mat_rank |
48 | 56 |
|
49 | 57 | # Test forward
|
50 |
| - base_out = model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]])) |
51 |
| - out = swag_model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]])) |
52 |
| - self.assertTrue(torch.allclose(base_out.logits, out.logits)) |
| 58 | + base_fwd_out = model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]])) |
| 59 | + swag_fwd_out = swag_model.forward(input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]])) |
| 60 | + self.assertTrue(torch.allclose(base_fwd_out.logits, swag_fwd_out.logits)) |
53 | 61 |
|
54 | 62 | # Test generate
|
55 | 63 | example = "I have no BART and I must generate"
|
| 64 | + torch.manual_seed(123) |
56 | 65 | batch = tokenizer(example, return_tensors="pt")
|
57 |
| - base_generated_ids = model.generate(batch["input_ids"]) |
| 66 | + base_generated_ids = model.generate(batch["input_ids"], generation_config=gen_config) |
| 67 | + # max_length=20, num_beams=1, do_sample=False, early_stopping=False |
58 | 68 | base_out = tokenizer.batch_decode(base_generated_ids, skip_special_tokens=True)
|
59 |
| - generated_ids = swag_model.generate(batch["input_ids"]) |
| 69 | + logging.warning(base_out) |
| 70 | + |
| 71 | + generated_ids = swag_model.generate(batch["input_ids"], generation_config=gen_config) |
60 | 72 | out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
61 |
| - logging.info(base_out) |
62 |
| - logging.info(out) |
| 73 | + logging.warning(out) |
63 | 74 | self.assertEqual(base_out, out)
|
64 | 75 |
|
65 | 76 | # Test saving & loading
|
66 | 77 | with tempfile.TemporaryDirectory() as tempdir:
|
67 | 78 | swag_model.save_pretrained(tempdir)
|
| 79 | + logging.warning(os.listdir(tempdir)) |
| 80 | + with open(os.path.join(tempdir, 'config.json'), 'r') as fobj: |
| 81 | + logging.warning(fobj.read()) |
68 | 82 | stored_model = SwagBartForConditionalGeneration.from_pretrained(tempdir).to(device)
|
69 | 83 |
|
70 |
| - generated_ids = stored_model.generate(batch["input_ids"]) |
| 84 | + stored_fwd_out = stored_model.forward( |
| 85 | + input_ids=torch.tensor([[3, 14]]), decoder_input_ids=torch.tensor([[1, 2, 4]])) |
| 86 | + self.assertTrue(torch.allclose(swag_fwd_out.logits, stored_fwd_out.logits)) |
| 87 | + |
| 88 | + generated_ids = stored_model.generate(batch["input_ids"], generation_config=gen_config) |
71 | 89 | out = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
72 |
| - logging.info(out) |
73 | 90 | self.assertEqual(base_out, out)
|
74 | 91 |
|
| 92 | + def test_pretrained_bart_generative_no_cov(self): |
| 93 | + self.pretrained_bart_generative(no_cov_mat=True) |
| 94 | + |
| 95 | + def test_pretrained_bart_generative_with_cov(self): |
| 96 | + self.pretrained_bart_generative(no_cov_mat=False) |
| 97 | + |
75 | 98 |
|
76 | 99 | if __name__ == "__main__":
|
77 | 100 | logging.basicConfig(level=logging.INFO)
|
|
0 commit comments