-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
187 lines (143 loc) · 5.85 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""SWAG wrapper base classes"""
import copy
import functools
import logging
from typing import Type
import torch
from transformers import PreTrainedModel, PretrainedConfig
from swag.posteriors.swag import SWAG
logger = logging.getLogger(__name__)
class SwagConfig(PretrainedConfig):
"""Base configuration class for SWAG models
For using this class, inherit it and define the following class
attributes:
- model_type: string
- internal_config_class: class inherited from PretrainedConfig
"""
internal_config_class: Type[PretrainedConfig] = PretrainedConfig
def __init__(
self,
internal_model_config: dict = None,
no_cov_mat: bool = True,
max_num_models: int = 20,
var_clamp: float = 1e-30,
**kwargs
):
super().__init__()
if internal_model_config:
self.internal_model_config = internal_model_config
else:
internal_config = self.internal_config_class(**kwargs)
self.internal_model_config = internal_config.to_dict()
self.no_cov_mat = no_cov_mat
self.max_num_models = max_num_models
self.var_clamp = var_clamp
@classmethod
def from_config(cls, base_config: PretrainedConfig, **kwargs):
"""Initialize from existing PretrainedConfig"""
config = cls(**kwargs)
config.internal_model_config = base_config.to_dict()
return config
def update_internal_config(self, base_config: PretrainedConfig):
"""Update internal config from base_config"""
self.internal_model_config = base_config.to_dict()
# Copy some things to the top level
if base_config.problem_type is not None:
self.problem_type = base_config.problem_type
class SwagPreTrainedModel(PreTrainedModel):
"""Base class for SWAG models wrapping PreTrainedModel
For using this class, inherit it and define the following class
attributes:
- base_model_prefix: string
- config_class: class inherited from PretrainedConfig
- internal_model_class: class inherited from PreTrainedModel
"""
config_class: Type[SwagConfig] = SwagConfig
internal_model_class: Type[PreTrainedModel] = PreTrainedModel
def __init__(self, config):
super().__init__(config)
self.swag = SWAG(
base=self.new_base_model,
no_cov_mat=config.no_cov_mat,
max_num_models=config.max_num_models,
var_clamp=config.var_clamp,
config=config.internal_config_class(**config.internal_model_config)
)
self.post_init()
@classmethod
def new_base_model(cls, *args, **kwargs):
"""Return new model of the base class
Any arguments are passed to the base class constructor.
"""
model = cls.internal_model_class(*args, **kwargs)
model.tie_weights()
return model
def _init_weights(self, module):
self.swag.base._init_weights(module)
def sample_parameters(self):
"""Sample new model parameters"""
self.swag.sample()
class SwagModel(SwagPreTrainedModel):
"""Base class for SWAG models
For using this class, inherit it and define the following class
attributes:
- base_model_prefix: string
- config_class: class inherited from PretrainedConfig
- internal_model_class: class inherited from PreTrainedModel
"""
def __init__(self, config, base_model=None):
super().__init__(config)
if base_model:
self.swag = SWAG(
base=functools.partial(self._base_model_copy, base_model),
no_cov_mat=config.no_cov_mat,
max_num_models=config.max_num_models,
var_clamp=config.var_clamp
)
self.prepare_inputs_for_generation = self.swag.base.prepare_inputs_for_generation
self.generate = self.swag.base.generate
@staticmethod
def _base_model_copy(model, *args, **kwargs):
"""Return deep copy of the model ignoring other arguments"""
# Has to be copied, otherwise SWAG would initialize parameters
# of the original model to zero
model = copy.deepcopy(model)
model.tie_weights()
return model
@classmethod
def from_base(cls, base_model: PreTrainedModel, **kwargs):
"""Initialize from existing PreTrainedModel"""
config = cls.config_class.from_config(base_model.config, **kwargs)
swag_model = cls(config, base_model=base_model)
return swag_model
def forward(self, *args, **kwargs):
"""Call forward pass from the base model"""
return self.swag.forward(*args, **kwargs)
@classmethod
def can_generate(cls) -> bool:
return cls.internal_model_class.can_generate()
def prepare_inputs_for_generation(self, *args, **kwargs):
return self.swag.base.prepare_inputs_for_generation(*args, **kwargs)
def generate(self, *args, **kwargs):
return self.swag.base.generate(*args, **kwargs)
class SampleLogitsMixin:
"""Mixin class for classification models providing get_logits() method using SWAG"""
def get_logits(
self, *args, num_predictions=None, scale=1.0, cov=True, block=False, **kwargs
):
"""Sample model parameters num_predictions times and get logits for the input
Results in a tensor of size batch_size x num_predictions x output_size.
"""
if num_predictions is None:
sample = False
num_predictions = 1
else:
sample = True
logits = []
for _ in range(num_predictions):
if sample:
self.sample_parameters(scale=scale, cov=cov, block=block)
out = self.forward(*args, **kwargs)
logits.append(out.logits)
logits = torch.permute(torch.stack(logits), (1, 0, 2)) # [batch_size, num_predictions, output_size]
return logits