@@ -22,7 +22,7 @@ See also [examples](./examples).
BERT model, sequence classification task:
1. Load the pretrained BERT model with `base_model = AutoModelForSequenceClassification.from_pretrained(name_or_path)`
- 2. Initialize the SWAG model with `swag_model = SwagBertForSequenceClassification.from_base(base_model)`
+ 2. Initialize the SWAG model with `swag_model = SwagBertForSequenceClassification.from_base(base_model, no_cov_mat=False)`
3. Initialize the SWAG callback object with `swag_callback = SwagUpdateCallback(swag_model)`
4. Initialize `transformers.Trainer` with `base_model` as the model and `swag_callback` among the callbacks.
5. Train the model (`trainer.train()`)
@@ -35,6 +35,24 @@ For collecting the SWAG parameters, two possible schedules are supported:
* After the end of each training epoch (the default; `collect_steps = 0` for `SwagUpdateCallback`)
* After every N training steps (set `collect_steps > 0` for `SwagUpdateCallback`)
+ ### SWA versus SWAG
+
+ The library supports both SWA (stochastic weight averaging without
+ covariance estimation) and SWAG (stochastic weight averaging with
+ Gaussian covariance estimation). The method is selected by the
+ `no_cov_mat` attribute when initializing the model
+ (e.g. `SwagModel.from_base(model, no_cov_mat=True)`). The default
+ value `True` corresponds to SWA; you need to explicitly set
+ `no_cov_mat=False` to activate SWAG.
+
+ With SWAG, the `max_num_models` option controls the maximum rank of
+ the covariance matrix. The rank increases by one with each parameter
+ collection step until the maximum is reached. The current rank is
+ stored in `model.swag.cov_mat_rank` and automatically copied to
+ `model.config.cov_mat_rank` when using `SwagUpdateCallback`. If you
+ call `model.swag.collect_model()` manually, you should also update the
+ configuration accordingly before saving the model.
+
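The rank bookkeeping described above can be illustrated with a dependency-free toy sketch: each collection step raises the rank by one until `max_num_models` caps it. The function name and the default cap of 20 are illustrative, not part of the library's API.

```python
# Toy illustration (no library dependencies) of how the covariance
# rank grows during SWAG parameter collection: it increases by one per
# collection step and is capped at max_num_models.
def simulate_cov_mat_rank(num_collect_steps, max_num_models=20):
    rank = 0
    for _ in range(num_collect_steps):
        rank = min(rank + 1, max_num_models)
    return rank

print(simulate_cov_mat_rank(5))   # 5: fewer collection steps than the cap
print(simulate_cov_mat_rank(50))  # 20: capped at max_num_models
```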
### Sampling model parameters
After `swag_model` is trained or fine-tuned as described above,