-
Notifications
You must be signed in to change notification settings - Fork 1
/
mi_networks.py
76 lines (62 loc) · 2.05 KB
/
mi_networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Module for networks used for computing MI.
"""
import numpy as np
import torch
import torch.nn as nn
class Permute(torch.nn.Module):
    """Layer that reorders the axes of its input tensor.

    Wrapping ``Tensor.permute`` in a module lets an axis reordering be
    placed inside an ``nn.Sequential`` like any other layer.
    """

    def __init__(self, *perm):
        """
        Args:
            *perm: Target axis order, forwarded unchanged to
                ``Tensor.permute`` (e.g. ``Permute(0, 2, 3, 1)``).
        """
        super().__init__()
        # ``*perm`` collects the positional arguments into a tuple.
        self.perm = perm

    def forward(self, input):
        """Reorder the axes of ``input`` according to ``self.perm``.

        Args:
            input: Tensor to permute.

        Returns:
            torch.Tensor: ``input`` with its axes rearranged.
        """
        axis_order = self.perm
        permuted = input.permute(*axis_order)
        return permuted
class MI1x1ConvNet(nn.Module):
    """Simple custom 1x1 convnet.

    Residual block built from 1x1 convolutions: a nonlinear branch
    (conv -> BatchNorm -> ReLU -> conv) is added to a linear 1x1 shortcut,
    and the sum is normalized with LayerNorm over the channel axis.
    """

    def __init__(self, n_input, n_units):
        """
        Args:
            n_input: Number of input units (input channels of the convs).
            n_units: Number of output units (output channels of the convs).
        """
        super().__init__()

        self.block_nonlinear = nn.Sequential(
            nn.Conv2d(n_input, n_units, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(n_units),
            nn.ReLU(),
            nn.Conv2d(n_units, n_units, kernel_size=1, stride=1, padding=0, bias=True),
        )

        # nn.LayerNorm(n_units) normalizes over the last dimension, so move
        # channels last, normalize, then restore the channels-first layout.
        self.block_ln = nn.Sequential(
            Permute(0, 2, 3, 1),
            nn.LayerNorm(n_units),
            Permute(0, 3, 1, 2),
        )

        self.linear_shortcut = nn.Conv2d(n_input, n_units, kernel_size=1,
                                         stride=1, padding=0, bias=False)

        # Initialize the shortcut to be like identity (if possible): small
        # uniform noise everywhere and 1 at weight[i, i, 0, 0]. This needs an
        # output channel for every input channel, hence n_units >= n_input.
        # Done under no_grad with direct indexing instead of masked_fill_
        # with a uint8 mask, which PyTorch deprecated (boolean masks only).
        if n_units >= n_input:
            with torch.no_grad():
                weight = self.linear_shortcut.weight
                weight.uniform_(-0.01, 0.01)
                diag = torch.arange(n_input)
                weight[diag, diag, 0, 0] = 1.

    def forward(self, x):
        """Apply the residual 1x1 block followed by channel-wise LayerNorm.

        Args:
            x: Input tensor with n_input channels. It must be 4-D,
                (N, n_input, H, W), because of the Conv2d layers and the
                4-axis permutes in ``block_ln``.

        Returns:
            torch.Tensor: network output with n_units channels and the same
            spatial size (1x1 convs with stride 1 and no padding).
        """
        h = self.block_ln(self.block_nonlinear(x) + self.linear_shortcut(x))
        return h