* `MarianPreTrainedModel` -> `SwagMarianPreTrainedModel`
* `MarianModel` -> `SwagMarianModel`
* `MarianMTModel` -> `SwagMarianMTModel`
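
A saved SWAG checkpoint can also be loaded directly through the
variant class, with no Auto registration needed. A minimal sketch;
the import path and the checkpoint directory are assumptions, not
part of this README:

```python
# Use the SWAG variant exactly where the base class would be used.
# The import path and checkpoint directory below are assumptions.
from swag_transformers.swag_marian import SwagMarianMTModel

model = SwagMarianMTModel.from_pretrained("path/to/swag-marian-model")
```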

### Registering classes for transformers Auto methods

As the SWAG variants are not part of the Hugging Face libraries, you
need to register them in order for the `AutoModelForXXX` methods to
work. For example, for `SwagBertForSequenceClassification` you need
to have:

```python
transformers.AutoConfig.register("swag_bert", SwagBertConfig)
transformers.AutoModelForSequenceClassification.register(
    SwagBertConfig, SwagBertForSequenceClassification)
```

Now you can load a saved model with:

```python
transformers.AutoModelForSequenceClassification.from_pretrained(path)
```
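
Note that the registration is in-memory only, so it has to be
repeated in every new Python process before the Auto methods are
called. A minimal end-to-end sketch; the `swag_transformers.swag_bert`
import path and the checkpoint directory are assumptions, not taken
from this README:

```python
import transformers

# Assumed import location for the SwagBert classes; adjust to wherever
# they are defined in your installation of swag_transformers.
from swag_transformers.swag_bert import SwagBertConfig, SwagBertForSequenceClassification

# Register once per process, before any Auto* call.
transformers.AutoConfig.register("swag_bert", SwagBertConfig)
transformers.AutoModelForSequenceClassification.register(
    SwagBertConfig, SwagBertForSequenceClassification)

# Dispatch works via the "model_type" field ("swag_bert" here) stored
# in the saved config.json.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "path/to/saved/model")
```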

### Wrapping custom models

It is easy to wrap any model class based on the `transformers`
library. For example, given `MyModel` with its configuration class
`MyModelConfig`, you can define the SWAG variants as follows:

```python
# MyModel and MyModelConfig are your own transformers-based model and
# configuration classes.
from swag_transformers.base import SwagConfig, SwagModel

MODEL_TYPE = 'swag_mymodel'

class SwagMyModelConfig(SwagConfig):

    model_type = MODEL_TYPE
    internal_config_class = MyModelConfig

class SwagMyModel(SwagModel):

    base_model_prefix = MODEL_TYPE
    config_class = SwagMyModelConfig
    internal_model_class = MyModel
```
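
To make the Auto methods from the previous section work for the
wrapped model as well, the same registration pattern applies. A
sketch; which `Auto` classes to register depends on your model:

```python
import transformers

# Mirror the SwagBert registration above for the custom wrapper;
# register additional Auto* classes (e.g. task-specific ones) as
# appropriate for your model.
transformers.AutoConfig.register(MODEL_TYPE, SwagMyModelConfig)
transformers.AutoModel.register(SwagMyModelConfig, SwagMyModel)
```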
