update README
svirpioj committed Oct 9, 2024
1 parent 9f0f137 commit 261f5c7
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions README.md
@@ -35,15 +35,20 @@ For collecting the SWAG parameters, two possible schedules are supported:
* After the end of each training epoch (default, `collect_steps = 0` for `SwagUpdateCallback`)
* After each N training steps (set `collect_steps > 0` for `SwagUpdateCallback`)
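
As a rough illustration of the two schedules, the decision of when to collect parameters can be sketched as below. This is not the library's implementation: `should_collect`, `global_step`, and `steps_per_epoch` are hypothetical names, and the sketch assumes that steps are counted from 1 and that every epoch has the same number of steps.

```python
def should_collect(global_step, steps_per_epoch, collect_steps=0):
    """Return True if parameters would be collected after this step.

    Hypothetical sketch: collect_steps == 0 means "once per epoch",
    collect_steps > 0 means "every collect_steps training steps".
    """
    if collect_steps > 0:
        # Step-based schedule: collect after every N-th training step.
        return global_step % collect_steps == 0
    # Epoch-based schedule (default): collect after the last step of an epoch.
    return global_step % steps_per_epoch == 0
```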

### SWA versus SWAG
### SWA, SWAG-Diagonal, and SWAG

The library supports both SWA (stochastic weight averaging without
covariance estimation) and SWAG (stochastic weight averaging with
Gaussian covariance estimation). The method is selected by the
`no_cov_mat` attribute when initializing the model
(e.g. `SwagModel.from_base(model, no_cov_mat=True)`). The default
value `True` corresponds to SWA, and you need to explicitly set
`no_cov_mat=False` to activate SWAG.
The library supports both SWA (stochastic weight averaging) and two
variants of SWAG (SWA-Gaussian): SWAG-Diagonal that uses diagonal
covariance and "full" SWAG that does low-rank covariance matrix
estimation.

The method is selected by the `no_cov_mat` attribute when initializing
the model (e.g. `SwagModel.from_base(model, no_cov_mat=True)`). The
default value `True` works only with SWAG-Diagonal and SWA, and you
need to explicitly set `no_cov_mat=False` to activate the low-rank
covariance estimation of SWAG. Note that you can also test SWA and
SWAG-Diagonal methods when the model is trained with
`no_cov_mat=False` (see the next section).

With SWAG, the `max_num_models` option controls the maximum rank of
the covariance matrix. The rank is increased by each parameter
@@ -66,6 +71,16 @@ class provides the convenience method `get_logits()` that samples the
parameters and makes a new prediction `num_predictions` times, and
returns the logit values in a tensor.

Note that for both `sample_parameters()` and `get_logits()`, the
default keyword arguments are suitable only for SWAG-Diagonal. For
SWAG, you should use `cov=True` (required to use the covariance
matrix) and `scale=0.5` (recommended). For SWA, you should use
`cov=False` and `scale=0`. To summarize:

* SWA: `scale=0`, `cov=False`
* SWAG-Diagonal: `scale=1`, `cov=False` (defaults)
* SWAG: `scale=0.5`, `cov=True` (`no_cov_mat=False` required for the model)
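
The summary above can be captured in a small lookup helper, so that the keyword arguments are chosen by method name rather than typed by hand. This is only a sketch: `SAMPLING_KWARGS` and `sampling_kwargs` are hypothetical names that are not part of the library, and the values are taken directly from the list above.

```python
# Hypothetical helper, not part of the library: it only encodes the
# scale/cov combinations summarized above.
SAMPLING_KWARGS = {
    "swa": {"scale": 0.0, "cov": False},
    "swag-diagonal": {"scale": 1.0, "cov": False},
    "swag": {"scale": 0.5, "cov": True},
}


def sampling_kwargs(method, no_cov_mat=True):
    """Return the keyword arguments suggested for the given method.

    `no_cov_mat` should match the value the model was created with.
    """
    key = method.lower()
    if key not in SAMPLING_KWARGS:
        raise ValueError(f"unknown method: {method!r}")
    kwargs = dict(SAMPLING_KWARGS[key])
    if kwargs["cov"] and no_cov_mat:
        # Full SWAG needs the low-rank covariance matrix to be collected.
        raise ValueError("SWAG requires a model created with no_cov_mat=False")
    return kwargs
```

Assuming `swag_model` is a model created with `no_cov_mat=False`, the result could then be passed on as `swag_model.sample_parameters(**sampling_kwargs("swag", no_cov_mat=False))`, and the same keyword arguments apply to `get_logits()`.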

### Currently supported models

* BERT (bidirectional encoder)