diff --git a/README.md b/README.md
index b2e65d7..d1ce50f 100644
--- a/README.md
+++ b/README.md
@@ -35,15 +35,20 @@ For collecting the SWAG parameters, two possible schedules are supported:
 * After the end of each training epoch (default, `collect_steps = 0` for `SwagUpdateCallback`)
 * After each N training steps (set `collect_steps > 0` for `SwagUpdateCallback`)
 
-### SWA versus SWAG
+### SWA, SWAG-Diagonal, and SWAG
 
-The library supports both SWA (stochastic weight averaging without
-covariance estimation) and SWAG (stochastic weight averaging with
-Gaussian covariance estimation). The method is selected by the
-`no_cov_mat` attribute when initializing the model
-(e.g. `SwagModel.from_base(model, no_cov_mat=True)`). The default
-value `True` corresponds to SWA, and you need to explicitly set
-`no_cov_mat=False` to activate SWAG.
+The library supports both SWA (stochastic weight averaging) and two
+variants of SWAG (SWA-Gaussian): SWAG-Diagonal that uses diagonal
+covariance and "full" SWAG that does low-rank covariance matrix
+estimation.
+
+The method is selected by the `no_cov_mat` attribute when initializing
+the model (e.g. `SwagModel.from_base(model, no_cov_mat=True)`). The
+default value `True` works only with SWAG-Diagonal and SWA, and you
+need to explicitly set `no_cov_mat=False` to activate the low-rank
+covariance estimation of SWAG. Note that you can also test SWA and
+SWAG-Diagonal methods when the model is trained with
+`no_cov_mat=False` (see the next section).
 
 With SWAG, the `max_num_models` option controls the maximum rank of
 the covariance matrix. The rank is increased by each parameter
@@ -66,6 +71,16 @@ class provides the convenience method `get_logits()` that samples the
 parameters and makes a new prediction `num_predictions` times, and
 returns the logit values in a tensor.
 
+Note that both for `sample_parameters()` and `get_logits()` the
+default keyword arguments are suitable only for SWAG-Diagonal. For
+SWAG, you should use `cov=True` (required to use the covariance
+matrix) and `scale=0.5` (recommended). For SWA, you should use
+`cov=False` and `scale=0`. To summarize:
+
+* SWA: `scale=0`, `cov=False`
+* SWAG-Diagonal: `scale=1`, `cov=False` (defaults)
+* SWAG: `scale=0.5`, `cov=True` (`no_cov_mat=False` required for the model)
+
 ### Currently supported models
 
 * BERT (bidirectional encoder)
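
The summary list added in the second hunk can be captured as a small lookup. The following is a minimal sketch and not part of the patch or the library: `SAMPLING_KWARGS` and `sampling_kwargs` are hypothetical names introduced here for illustration. Only the `scale`, `cov`, and `no_cov_mat` values are taken from the README text in the diff.

```python
# Hypothetical helper (not part of the library): maps each method named in
# the README summary to the sampling keyword arguments it recommends.
SAMPLING_KWARGS = {
    "swa": {"scale": 0.0, "cov": False},
    "swag-diagonal": {"scale": 1.0, "cov": False},  # the documented defaults
    "swag": {"scale": 0.5, "cov": True},
}


def sampling_kwargs(method, no_cov_mat=True):
    """Return the recommended `scale`/`cov` settings for `method`.

    `no_cov_mat` should mirror the value the model was created with.
    Full SWAG (`cov=True`) is rejected when the model was created with
    `no_cov_mat=True`, since the README states that `no_cov_mat=False`
    is required in that case.
    """
    key = method.lower()
    if key not in SAMPLING_KWARGS:
        raise ValueError(f"unknown method: {method!r}")
    kwargs = dict(SAMPLING_KWARGS[key])  # copy so callers cannot mutate the table
    if kwargs["cov"] and no_cov_mat:
        raise ValueError("full SWAG requires a model created with no_cov_mat=False")
    return kwargs
```

Assuming `sample_parameters()` and `get_logits()` accept `scale` and `cov` as keyword arguments, as the README text describes, the returned dictionary could be unpacked into either call with `**sampling_kwargs("swag", no_cov_mat=False)`.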