-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathMultiBlock.py
69 lines (52 loc) · 1.7 KB
/
MultiBlock.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch.nn as nn
from .ProgNet import ProgBlock, ProgInertBlock
class PassBlock(ProgInertBlock):
    """Identity block.

    Both the block computation and the activation hand back the input
    object unchanged, so this block adds no computation of its own.
    """

    def __init__(self):
        super(PassBlock, self).__init__()

    def runBlock(self, x):
        """Return ``x`` itself (identity; no computation is applied)."""
        return x

    def runActivation(self, x):
        """Return ``x`` itself (identity activation)."""
        return x

    def getData(self):
        """Return a dict describing this block: ``{"type": "Pass"}``."""
        return {"type": "Pass"}
class MultiBlock(ProgBlock):
    """Block that runs several sub-blocks side by side, one per channel.

    Every ``run*`` method takes a sequence ``x`` with exactly one element per
    channel. Element ``i`` is routed to ``self.channels[i]``, and the
    per-channel results are returned as a list in channel order.

    Args:
        subBlocks: the per-channel sub-blocks. They are wrapped in an
            ``nn.ModuleList`` so their parameters are registered as children
            of this module.
    """

    def __init__(self, subBlocks):
        super().__init__()
        # nn.ModuleList registers the sub-blocks as child modules of this block.
        self.channels = nn.ModuleList(subBlocks)

    def runBlock(self, x):
        """Apply each channel's ``runBlock`` to its matching element of ``x``.

        Returns:
            A list with one output per channel, in channel order.

        Raises:
            ValueError: if ``x`` does not have one element per channel.
        """
        self._checkInput(x)
        # zip is safe here: _checkInput has already verified equal lengths.
        return [b.runBlock(inp) for b, inp in zip(self.channels, x)]

    def runLateral(self, j, x):
        """Apply each channel's ``runLateral(j, ...)`` to its element of ``x``.

        ``j`` is forwarded unchanged to every lateralized sub-block
        (presumably an index selecting the lateral connection; confirm
        against ``ProgBlock``). Channels whose ``isLateralized()`` returns a
        falsy value contribute ``None``, so the output stays aligned with
        ``self.channels``.

        Raises:
            ValueError: if ``x`` does not have one element per channel.
        """
        self._checkInput(x)
        return [
            b.runLateral(j, inp) if b.isLateralized() else None
            for b, inp in zip(self.channels, x)
        ]

    def runActivation(self, x):
        """Apply each channel's ``runActivation`` to its element of ``x``.

        Returns:
            A list with one activated output per channel, in channel order.

        Raises:
            ValueError: if ``x`` does not have one element per channel.
        """
        self._checkInput(x)
        return [b.runActivation(inp) for b, inp in zip(self.channels, x)]

    def getData(self):
        """Return a dict describing this block and each of its sub-blocks.

        The ``"subblocks"`` entry holds each channel's own ``getData()``
        result, in channel order.
        """
        return {
            "type": "Multi",
            "subblocks": [sb.getData() for sb in self.channels],
        }

    def _checkInput(self, x):
        """Raise ValueError unless ``x`` has exactly one element per channel.

        Inputs without a length (generators, ``None``, 0-d tensors) are
        reported with the same ValueError. Previously they leaked a bare
        TypeError from ``len()``.
        """
        try:
            size = len(x)
        except TypeError:
            size = None  # unsized input: fall through to the documented error
        if size != len(self.channels):
            errStr = "[Doric]: Input must be a python iterable with size equal to the number of channels in the MultiBlock (%d)." % len(self.channels)
            raise ValueError(errStr)
#===============================================================================