
Commit ae2fc4f (1 parent: 639f63f)

add wrapper for BertForQuestionAnswering
File tree: 3 files changed, +31 −3 lines


README.md (+1)

@@ -88,6 +88,7 @@ matrix) and `scale=0.5` (recommended). For SWA, you should use
   * `BertModel` -> `SwagBertModel`
   * `BertLMHeadModel` -> `SwagBertLMHeadModel`
   * `BertForSequenceClassification` -> `SwagBertForSequenceClassification`
+  * `BertForQuestionAnswering` -> `SwagBertForQuestionAnswering`
 * BART (bidirectional encoder + causal decoder)
   * `BartPreTrainedModel` -> `SwagBartPreTrainedModel`
   * `BartModel` -> `SwagBartModel`
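As with the existing mappings, the new class wraps its Hugging Face counterpart. A minimal sketch of constructing the wrapper from a pretrained checkpoint, using the from_base constructor exercised in the test below (the checkpoint name is an arbitrary placeholder):

    from transformers import AutoModelForQuestionAnswering
    from swag_transformers.swag_bert import SwagBertForQuestionAnswering

    # Load a standard QA model and wrap it in its SWAG counterpart
    base = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
    swag_model = SwagBertForQuestionAnswering.from_base(base)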

src/swag_transformers/swag_bert.py (+7 −1)

@@ -3,7 +3,7 @@
 import logging
 
 from transformers import BertConfig, BertLMHeadModel, BertModel, BertPreTrainedModel, \
-    BertForSequenceClassification
+    BertForSequenceClassification, BertForQuestionAnswering
 
 from .base import SwagConfig, SwagPreTrainedModel, SwagModel, SampleLogitsMixin
 
@@ -47,3 +47,9 @@ class SwagBertLMHeadModel(SwagBertModel):
     """SWAG BERT model with LM head"""
 
     internal_model_class = BertLMHeadModel
+
+
+class SwagBertForQuestionAnswering(SwagBertModel):
+    """SWAG BERT model for question answering"""
+
+    internal_model_class = BertForQuestionAnswering
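The wrapper pattern is uniform: each task-specific class subclasses SwagBertModel and overrides only internal_model_class. For illustration, a hypothetical wrapper for a head not covered by this commit, such as token classification, would follow the same two-line shape:

    from transformers import BertForTokenClassification

    class SwagBertForTokenClassification(SwagBertModel):
        """SWAG BERT model for token classification (hypothetical, not in this commit)"""

        internal_model_class = BertForTokenClassification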

tests/test_swag_bert.py (+23 −2)

@@ -6,11 +6,11 @@
 import torch
 
 from datasets import Dataset, DatasetDict
-from transformers import AutoModel, AutoModelForSequenceClassification, \
+from transformers import AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, \
     AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments
 
 from swag_transformers.swag_bert import SwagBertConfig, SwagBertLMHeadModel, SwagBertModel, SwagBertPreTrainedModel, \
-    SwagBertForSequenceClassification
+    SwagBertForSequenceClassification, SwagBertForQuestionAnswering
 from swag_transformers.trainer_utils import SwagUpdateCallback
 
 
@@ -99,6 +99,27 @@ def test_pretrained_bert_classifier_test(self):
         logging.debug(out)
         self.assertEqual(out.logits.shape, (1, num_labels))
 
+    def test_pretrained_bert_qa_test(self):
+        model = AutoModelForQuestionAnswering.from_pretrained(self.pretrained_model_name)
+        tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_name)
+        swag_model = SwagBertForQuestionAnswering.from_base(model)
+        swag_model.swag.collect_model(model)
+        swag_model.sample_parameters()
+        inputs = tokenizer(
+            "What is context?",
+            "Context is the interrelated conditions in which something exists or occurs.",
+            max_length=100,
+            truncation="only_second",
+            stride=50,
+            return_tensors="pt"
+        )
+        logging.debug(inputs)
+        num_positions = inputs['input_ids'].shape[1]
+        out = swag_model.forward(**inputs)
+        logging.debug(out)
+        self.assertEqual(out.start_logits.shape, (1, num_positions))
+        self.assertEqual(out.end_logits.shape, (1, num_positions))
+
 
 class TestSwagBertFinetune(unittest.TestCase):
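In the test, SWAG statistics are collected manually via swag.collect_model before a single parameter sample is drawn; in actual fine-tuning the same collection would typically run through the SwagUpdateCallback imported above. A minimal sketch, assuming SwagUpdateCallback takes the SWAG wrapper as its argument and that training_args, train_dataset, and data_collator are prepared elsewhere (all hypothetical placeholders here):

    # Sketch only: args, dataset, and collator are hypothetical placeholders.
    trainer = Trainer(
        model=model,                                 # the base BertForQuestionAnswering
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        callbacks=[SwagUpdateCallback(swag_model)],  # accumulate SWAG weight statistics during training
    )
    trainer.train()

    swag_model.sample_parameters()   # draw one weight sample from the fitted distribution
    out = swag_model(**inputs)       # per-position start_logits and end_logits, as asserted above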
