@@ -6,11 +6,11 @@
 import torch
 
 from datasets import Dataset, DatasetDict
-from transformers import AutoModel, AutoModelForSequenceClassification, \
+from transformers import AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, \
     AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
 
 from swag_transformers.swag_bert import SwagBertConfig, SwagBertLMHeadModel, SwagBertModel, SwagBertPreTrainedModel, \
-    SwagBertForSequenceClassification
+    SwagBertForSequenceClassification, SwagBertForQuestionAnswering
 from swag_transformers.trainer_utils import SwagUpdateCallback
 
 
@@ -99,6 +99,27 @@ def test_pretrained_bert_classifier_test(self):
         logging.debug(out)
         self.assertEqual(out.logits.shape, (1, num_labels))
 
+    def test_pretrained_bert_qa_test(self):
+        model = AutoModelForQuestionAnswering.from_pretrained(self.pretrained_model_name)
+        tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name)
+        swag_model = SwagBertForQuestionAnswering.from_base(model)
+        swag_model.swag.collect_model(model)
+        swag_model.sample_parameters()
+        inputs = tokenizer(
+            "What is context?",
+            "Context is the interrelated conditions in which something exists or occurs.",
+            max_length=100,
+            truncation="only_second",
+            stride=50,
+            return_tensors="pt"
+        )
+        logging.debug(inputs)
+        num_positions = inputs['input_ids'].shape[1]
+        out = swag_model.forward(**inputs)
+        logging.debug(out)
+        self.assertEqual(out.start_logits.shape, (1, num_positions))
+        self.assertEqual(out.end_logits.shape, (1, num_positions))
+
 
 class TestSwagBertFinetune(unittest.TestCase):
 
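Not part of the commit above, but for orientation: a minimal sketch of how the new QA head could be used for SWAG-style sampled inference, averaging start/end logits over several parameter draws. The model name and sample count are placeholder assumptions; from_base, collect_model, and sample_parameters are the calls exercised by the test, and in real fine-tuning the imported SwagUpdateCallback would collect parameter snapshots during training rather than the single collect_model call shown here.

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from swag_transformers.swag_bert import SwagBertForQuestionAnswering

model_name = "bert-base-uncased"  # placeholder; the test uses self.pretrained_model_name

# Wrap a pretrained QA model with SWAG and register its weights once.
base = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
swag_model = SwagBertForQuestionAnswering.from_base(base)
swag_model.swag.collect_model(base)  # normally called repeatedly during training

inputs = tokenizer(
    "What is context?",
    "Context is the interrelated conditions in which something exists or occurs.",
    return_tensors="pt",
)

# Average logits over several sampled parameter sets; the spread between
# samples gives a rough uncertainty signal for the predicted answer span.
start_logits, end_logits = [], []
with torch.no_grad():
    for _ in range(10):  # sample count is arbitrary
        swag_model.sample_parameters()
        out = swag_model(**inputs)
        start_logits.append(out.start_logits)
        end_logits.append(out.end_logits)

start = torch.stack(start_logits).mean(dim=0).argmax(dim=-1).item()
end = torch.stack(end_logits).mean(dim=0).argmax(dim=-1).item()
answer = tokenizer.decode(inputs["input_ids"][0][start:end + 1])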