Skip to content

Commit 3d05260

Browse files
committed
update README
1 parent ca3e2ad commit 3d05260

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

README.md

+19-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ See also [examples](./examples).
2222
BERT model, sequence classification task:
2323

2424
1. Load pretrained Bert model by `base_model = AutoModelForSequenceClassification.from_pretrained(name_or_path)`
25-
2. Initialize SWAG model by `swag_model = SwagBertForSequenceClassification.from_base(base_model)`
25+
2. Initialize SWAG model by `swag_model = SwagBertForSequenceClassification.from_base(base_model, no_cov_mat=False)`
2626
3. Initialize SWAG callback object `swag_callback = SwagUpdateCallback(swag_model)`
2727
4. Initialize `transformers.Trainer` with the `base_model` as model and `swag_callback` in callbacks.
2828
5. Train the model (`trainer.train()`)
@@ -35,6 +35,24 @@ For collecting the SWAG parameters, two possible schedules are supported:
3535
* After the end of each training epoch (default, `collect_steps = 0` for `SwagUpdateCallback`)
3636
* After each N training steps (set `collect_steps > 0` for `SwagUpdateCallback`)
3737

38+
### SWA versus SWAG
39+
40+
The library supports both SWA (stochastic weight averaging without
41+
covariance estimation) and SWAG (stochastic weight averaging with
42+
Gaussian covariance estimation). The method is selected by the
43+
`no_cov_mat` attribute when initializing the model
44+
(e.g. `SwagModel.from_base(model, no_cov_mat=True)`). The default
45+
value `True` corresponds to SWA, and you need to explicitly set
46+
`no_cov_mat=False` to activate SWAG.
47+
48+
With SWAG, the `max_num_models` option controls the maximum rank of
49+
the covariance matrix. The rank is increased by each parameter
50+
collection step until the maximum is reached. The current rank is
51+
stored in `model.swag.cov_mat_rank` and automatically updated to
52+
`model.config.cov_mat_rank` when using `SwagUpdateCallback`. If you
53+
call `model.swag.collect_model()` manually, you should also update the
54+
configuration accordingly before saving the model.
55+
3856
### Sampling model parameters
3957

4058
After `swag_model` is trained or fine-tuned as described above,

0 commit comments

Comments
 (0)