@@ -124,3 +124,44 @@ matrix) and `scale=0.5` (recommended). For SWA, you should use
* `MarianPreTrainedModel` -> `SwagMarianPreTrainedModel`
* `MarianModel` -> `SwagMarianModel`
* `MarianMTModel` -> `SwagMarianMTModel`
+
+ ### Registering classes for transformers Auto methods
+
+ As the SWAG variants are not part of the Hugging Face libraries, you
+ need to register them in order for `AutoModelForXXX` methods to
+ work. For example, for `SwagBertForSequenceClassification` you need to
+ have:
+
+ ```python
+ transformers.AutoConfig.register("swag_bert", SwagBertConfig)
+ transformers.AutoModelForSequenceClassification.register(SwagBertConfig, SwagBertForSequenceClassification)
+ ```
+
+ Now you can load a saved model with:
+
+ ```python
+ model = transformers.AutoModelForSequenceClassification.from_pretrained(path)
+ ```
+
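+ The same registration pattern applies to the other Auto classes. As a
+ hedged sketch for the Marian variants listed above (the
+ `SwagMarianConfig` class and the `"swag_marian"` model type string are
+ assumptions here; check the package for the exact names):
+
+ ```python
+ # SwagMarianConfig and "swag_marian" are assumed names, not confirmed API
+ transformers.AutoConfig.register("swag_marian", SwagMarianConfig)
+ transformers.AutoModelForSeq2SeqLM.register(SwagMarianConfig, SwagMarianMTModel)
+ ```
+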
+ ### Wrapping custom models
+
+ It is easy to wrap any model class based on the `transformers`
+ library. For example, having `MyModel` with `MyModelConfig`, you can
+ define the SWAG variants as follows:
+
+ ```python
+ from swag_transformers.base import SwagConfig, SwagModel
+
+ # MyModel and MyModelConfig are your existing custom model and config classes
+ MODEL_TYPE = 'swag_mymodel'
+
+ class SwagMyModelConfig(SwagConfig):
+
+     model_type = MODEL_TYPE
+     internal_config_class = MyModelConfig
+
+ class SwagMyModel(SwagModel):
+
+     base_model_prefix = MODEL_TYPE
+     config_class = SwagMyModelConfig
+     internal_model_class = MyModel
+ ```
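+
+ If you also want the Auto methods to resolve the new classes, the
+ registration pattern from the section above applies unchanged. A
+ minimal sketch, using only the names defined here:
+
+ ```python
+ import transformers
+
+ # MODEL_TYPE matches SwagMyModelConfig.model_type defined above
+ transformers.AutoConfig.register(MODEL_TYPE, SwagMyModelConfig)
+ transformers.AutoModel.register(SwagMyModelConfig, SwagMyModel)
+ ```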