-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathUtilityBlocks.py
91 lines (62 loc) · 1.86 KB
/
UtilityBlocks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from .ProgNet import ProgInertBlock
# Two-phase toggle states used by ProgSkip.ioState: a ProgSkip first captures
# its input (IN phase), then on the next runBlock call adds the stored skip
# value to the new input (OUT phase) and resets back to IN.
IO_STATE_IN = 0
IO_STATE_OUT = 1
# Identity function; default for ProgSkip's activation and skip transforms.
I_FUNCTION = (lambda x : x)
"""
An inert ProgBlock that simply runs the input through python lambda functions.
Good for resizing or other non-learning ops.
"""
class ProgLambda(ProgInertBlock):
def __init__(self, lambdaMod):
super().__init__()
self.module = lambdaMod
def runBlock(self, x):
return self.module(x)
def runActivation(self, x):
return x
def getData(self):
data = dict()
data["type"] = "Lambda"
return data
"""
A convenience reshaping inert ProgBlock.
"""
class ProgReshape(ProgInertBlock):
def __init__(self, shape):
super().__init__()
self.sh = shape
def runBlock(self, x):
return torch.reshape(x, self.sh)
def runActivation(self, x):
return x
def getData(self):
data = dict()
data["type"] = "Reshape"
data["new_shape"] = str(self.sh)
return data
class ProgSkip(ProgInertBlock):
    """
    An inert ProgBlock implementing a residual (skip) connection.

    The block is run twice per forward pass: the first runBlock call stashes
    its input and returns it unchanged; the second call adds the (optionally
    transformed) stashed value to the new input, then resets for reuse.
    """

    def __init__(self, lambdaActivation = I_FUNCTION, lambdaSkip = I_FUNCTION):
        """
        lambdaActivation -- callable applied in runActivation (default identity).
        lambdaSkip       -- callable applied to the stashed value before it is
                            added back in (default identity).
        """
        super().__init__()
        self.skip = None
        self.ioState = IO_STATE_IN
        self.activation = lambdaActivation
        self.skipFunction = lambdaSkip

    def runBlock(self, x):
        """Capture x on the first call; add the transformed capture on the second."""
        if self.ioState != IO_STATE_IN:
            # Second call: consume the stash, reset state, and merge.
            stashed = self.skip
            self.skip = None
            self.ioState = IO_STATE_IN
            return x + self.skipFunction(stashed)
        # First call: remember the input and pass it through untouched.
        self.skip = x
        self.ioState = IO_STATE_OUT
        return x

    def runActivation(self, x):
        """Apply the configured activation callable to x."""
        return self.activation(x)

    def getData(self):
        """Return a serializable description of this block; id disambiguates instances."""
        return {"type": "Skip", "id": str(id(self))}
#===============================================================================