regression_network.py
"""Example for a network that uses our feature steering method.
Essentially, all of the method's magic happens in the feat_steering_loss(...)
function, which calculates the feature steering portion of the loss that is
later added to the standard maximum-likelihood loss.
Implementation-wise, it is very important to make sure that the feature steering
part of the loss is calculated in a differentiable manner. Depending on the
types of variables, this can be difficult when estimating the CMI.
"""
import torch
from torch import nn
from torch.utils.data import DataLoader
from contextual_decomposition import get_cd_1d_by_modules
from mixed_cmi_estimator import mixed_cmi_model
class RegressionNetwork(nn.Module):
    def __init__(self, input_shape: int, n_hidden_layers: int = 2, hidden_dim_size: int = 32, device: str = 'cpu'):
        """Creates a new RegressionNetwork.

        The regression network consists of an input layer, one or more hidden
        layers of the same size, and a scalar output. All hidden layers use a
        ReLU activation function.

        The linear layers are initialized with Xavier initialization for the
        weights and zero initialization for the biases.

        Args:
            input_shape (int): Number of input features of the network.
            n_hidden_layers (int, optional): Number of hidden layers. Defaults
                to 2.
            hidden_dim_size (int, optional): Size of the hidden linear layers.
                Defaults to 32.
            device (str, optional): Device used by PyTorch to store tensors for
                computation. Defaults to 'cpu'.

        Raises:
            ValueError: The RegressionNetwork has to have at least one hidden
                layer.
        """
        super().__init__()
        self.device = device

        # The network always has at least one hidden layer (input_shape -> hidden_dim_size).
        # Make sure that n_hidden_layers is valid.
        if n_hidden_layers < 1:
            raise ValueError("The network cannot have less than 1 hidden layer.")

        # Generate and initialize hidden layers.
        # Note: we only need to generate n_hidden_layers-1 hidden layers!
        # Each hidden layer must be a separate nn.Linear instance; repeating a
        # single instance via list multiplication would share its weights.
        lin_layers = [nn.Linear(in_features=hidden_dim_size, out_features=hidden_dim_size)
                      for _ in range(n_hidden_layers - 1)]
        for lin_layer in lin_layers:
            nn.init.xavier_uniform_(lin_layer.weight)
            nn.init.zeros_(lin_layer.bias)
        relus = [nn.ReLU() for _ in range(len(lin_layers))]

        # Generate and initialize first and last layer.
        input_layer = nn.Linear(input_shape, hidden_dim_size)
        output_layer = nn.Linear(hidden_dim_size, 1)
        for lin_layer in [input_layer, output_layer]:
            nn.init.xavier_uniform_(lin_layer.weight)
            nn.init.zeros_(lin_layer.bias)

        # Combine layers to a model.
        modules = [nn.Flatten(),
                   input_layer,
                   nn.ReLU(),
                   *[layer for pair in zip(lin_layers, relus) for layer in pair],
                   output_layer,
                   ]
        self.layers = nn.Sequential(*modules)
        self.to(device)

    def forward(self, x):
        return self.layers(x)

    def feat_steering_loss(self, inputs: torch.Tensor, targets: torch.Tensor, outputs: torch.Tensor, feat_steering_config: dict = None) -> torch.Tensor:
        """Returns the feature steering loss.

        This function is where all the magic of our method takes place.
        The feature steering part of the loss depends on the configuration provided
        in feat_steering_config. Here, you can specify the steering mode (none,
        loss_l1, loss_l2) and the feature attribution mode (contextual_decomposition,
        cmi). If cmi is specified, the CMI estimate is transformed as described
        in the paper. Also, you can specify which features shall be encouraged and
        discouraged (each as a list of feature indices).

        Args:
            inputs (torch.Tensor): Inputs to the network.
            targets (torch.Tensor): Targets for the specified inputs.
            outputs (torch.Tensor): Outputs generated by the network for the
                specified inputs.
            feat_steering_config (dict, optional): Configuration specifying how
                feature steering shall be performed. For more details, see above.
                Defaults to None.

        Raises:
            ValueError: Invalid norm for feature steering.
            ValueError: Invalid feature attribution mode.

        Returns:
            torch.Tensor: Differentiable feature steering loss per sample.
        """
        # Get configuration for feature steering.
        # Do not perform feature steering if it is not desired.
        if feat_steering_config is None or feat_steering_config["steering_mode"] == "none":
            return torch.tensor(0.0)
        elif not feat_steering_config["steering_mode"] in ["loss_l1", "loss_l2"]:
            raise ValueError("The feature steering mode is invalid.")
        if not feat_steering_config["attrib_mode"] in ["contextual_decomposition", "cmi"]:
            raise ValueError("The feature attribution mode is invalid.")
        feat_to_encourage, feat_to_discourage = feat_steering_config["encourage"], feat_steering_config["discourage"]

        # Feature attribution.
        if feat_steering_config["attrib_mode"] == "contextual_decomposition":
            scores_feat_to_encourage, _ = get_cd_1d_by_modules(self.layers, inputs, feat_to_encourage, device=self.device)
            scores_feat_to_discourage, _ = get_cd_1d_by_modules(self.layers, inputs, feat_to_discourage, device=self.device)
        elif feat_steering_config["attrib_mode"] == "cmi":
            # Estimate CMI.
            if len(feat_to_encourage) > 0:
                scores_feat_to_encourage = torch.stack([mixed_cmi_model(inputs[:, feat], outputs, targets, feature_is_categorical=False, target_is_categorical=False) for feat in feat_to_encourage], 0)
            else:
                scores_feat_to_encourage = torch.tensor([]).float()
            if len(feat_to_discourage) > 0:
                scores_feat_to_discourage = torch.stack([mixed_cmi_model(inputs[:, feat], outputs, targets, feature_is_categorical=False, target_is_categorical=False) for feat in feat_to_discourage], 0)
            else:
                scores_feat_to_discourage = torch.tensor([]).float()

            # Transform to [0,1] via sqrt(1 - exp(-2 * CMI)); for Gaussian
            # variables this maps a mutual information estimate to the
            # corresponding absolute correlation coefficient.
            # NOTE: Even though in theory CMI >= 0, in practice our estimates
            # can be smaller than zero. Therefore, we need to avoid passing
            # values < 0 to the sqrt.
            # In analogy to Straight-Through Estimators (STEs) we apply our
            # transformation only to inputs > 0 and use the identity transformation
            # for inputs <= 0.
            scores_feat_to_encourage[scores_feat_to_encourage > 0] = torch.sqrt(1 - torch.exp(-2 * scores_feat_to_encourage[scores_feat_to_encourage > 0]))
            scores_feat_to_discourage[scores_feat_to_discourage > 0] = torch.sqrt(1 - torch.exp(-2 * scores_feat_to_discourage[scores_feat_to_discourage > 0]))

        # Numerical stability: If there are no features to en- or discourage,
        # we can explicitly set their contribution to 0.
        if len(feat_to_encourage) == 0:
            scores_feat_to_encourage = torch.tensor(0.0)
        if len(feat_to_discourage) == 0:
            scores_feat_to_discourage = torch.tensor(0.0)

        # Feature steering.
        if feat_steering_config["attrib_mode"] == "cmi":
            # The transformed CMI estimates can be negative. Since applying the
            # L1 / L2 norm would emphasize them, we only apply the norm to
            # transformed estimates >= 0.
            # To all other values, similarly to above, the identity transformation
            # is applied (analogous to Straight-Through Estimators, keeps
            # gradients).
            # In practice, this means that for the L1 norm we perform the identity
            # transformation.
            if feat_steering_config["steering_mode"] == "loss_l2":
                scores_feat_to_encourage[scores_feat_to_encourage >= 0] = torch.square(scores_feat_to_encourage[scores_feat_to_encourage >= 0])
                scores_feat_to_discourage[scores_feat_to_discourage >= 0] = torch.square(scores_feat_to_discourage[scores_feat_to_discourage >= 0])

            return feat_steering_config["lambda"] * (torch.sum(scores_feat_to_discourage) - torch.sum(scores_feat_to_encourage)) / inputs.size()[0]  # Average over batch.

        # Apply weight factor lambda.
        if feat_steering_config["lambda"] == 0:
            return torch.tensor(0.0)
        elif feat_steering_config["steering_mode"] == "loss_l1":
            feat_steering_loss = feat_steering_config["lambda"] * (torch.sum(torch.abs(scores_feat_to_discourage)) - torch.sum(torch.abs(scores_feat_to_encourage)))
        elif feat_steering_config["steering_mode"] == "loss_l2":
            feat_steering_loss = feat_steering_config["lambda"] * (torch.sum(torch.square(scores_feat_to_discourage)) - torch.sum(torch.square(scores_feat_to_encourage)))

        return feat_steering_loss / inputs.size()[0]  # Average over batch.
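
    # The commented dictionary below is an illustrative sketch of the
    # feat_steering_config expected by feat_steering_loss(...) and loss(...).
    # The keys match the checks above; the concrete values (feature indices,
    # lambda) are arbitrary example values, not taken from the original code.
    #
    #   feat_steering_config = {
    #       "steering_mode": "loss_l2",   # one of "none", "loss_l1", "loss_l2"
    #       "attrib_mode": "cmi",         # "contextual_decomposition" or "cmi"
    #       "encourage": [0, 3],          # indices of features to encourage
    #       "discourage": [1],            # indices of features to discourage
    #       "lambda": 0.1,                # weight of the feature steering loss
    #   }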

    def loss(self, inputs: torch.Tensor, targets: torch.Tensor, outputs: torch.Tensor, feat_steering_config: dict = None) -> torch.Tensor:
        """Loss function of the network.

        The loss of the network is composed of the standard maximum-likelihood
        loss and the feature steering portion of the loss.
        feat_steering_config specifies how the feature steering portion of the
        loss shall be calculated.

        Args:
            inputs (torch.Tensor): Inputs to the network.
            targets (torch.Tensor): Targets for the specified inputs.
            outputs (torch.Tensor): Outputs generated by the network for the
                specified inputs.
            feat_steering_config (dict, optional): Configuration specifying how
                feature steering shall be performed. If None, no feature
                steering is performed. Defaults to None.

        Raises:
            ValueError: The feature steering portion of the loss is nan. In that
                case, this portion of the loss has no reasonable gradient, which
                would cause problems in later steps of the training process.

        Returns:
            torch.Tensor: Total loss per sample.
        """
        # For MSE make sure that outputs is a 1D tensor. That is, we need to
        # prevent tensors of shape torch.Size([batch_size, 1]).
        if len(outputs.size()) > 1:
            outputs = outputs.squeeze(dim=1)

        # Compute default loss.
        loss_func = nn.MSELoss()
        loss = loss_func(outputs, targets)

        # No feature steering if in evaluation mode or if explicitly disabled.
        if not self.training or feat_steering_config is None or feat_steering_config["steering_mode"] == "none":
            return loss
        else:
            feat_steering_loss = self.feat_steering_loss(inputs, targets, outputs, feat_steering_config=feat_steering_config)
            if feat_steering_loss.isnan():
                raise ValueError("The feature steering loss of your model is nan. "
                                 "Thus, no reasonable gradient can be computed.")
            return loss + feat_steering_loss

    def train(self, train_dataloader: DataLoader, feat_steering_config: dict, epochs: int = 90, learning_rate: float = 0.01):
        """Performs training of the RegressionNetwork.

        The network is trained using PyTorch's AdamW optimizer with default
        parameters except for the learning rate.

        Note: this method overrides nn.Module.train(mode), so the module stays
        in training mode (self.training == True) unless eval() is called
        explicitly.

        Args:
            train_dataloader (DataLoader): Contains the samples used for training.
            feat_steering_config (dict): Configuration specifying how feature
                steering shall be performed. If None, no feature steering is
                performed.
            epochs (int, optional): Number of epochs the network is trained.
                Defaults to 90.
            learning_rate (float, optional): Learning rate. Defaults to 0.01.

        Raises:
            ValueError: If the output of the network contains nan for a given
                input, no reasonable loss can be computed.
            ValueError: If the loss of the network is infinity / nan, no
                reasonable gradient can be computed for optimization.
        """
        optimizer = torch.optim.AdamW(self.layers.parameters(), lr=learning_rate)
        for epoch in range(epochs):
            epoch_loss = 0.0
            for inputs, targets in train_dataloader:
                # Pass data to GPU / CPU if necessary.
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                # Zero the gradients.
                optimizer.zero_grad()

                # Perform forward pass.
                outputs = self(inputs)
                if outputs.isnan().any():
                    raise ValueError("The output of the model contains nan. "
                                     "Thus, no reasonable loss can be computed!")

                # Calculate loss.
                loss = self.loss(inputs, targets, outputs, feat_steering_config=feat_steering_config)

                # Perform backward pass and modify weights accordingly.
                if loss.isinf() or loss.isnan():
                    raise ValueError("The loss of your model is inf / nan. "
                                     "Thus, no reasonable gradient can be computed!")
                loss.backward()
                optimizer.step()

                # Accumulate statistics.
                epoch_loss += loss.item()

            print(f"Loss (per sample) after epoch {epoch + 1}: {epoch_loss / len(train_dataloader)}")