-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
basic.py
62 lines (54 loc) · 1.49 KB
/
basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
from .IInterpolant import IInterpolant
from Utils.utils import CFakeObject
class CDirectionInterpolant(IInterpolant):
    """Interpolant whose model predicts the offset from ``xt`` back to ``x0``.

    The network output ``x_hat`` is interpreted as a direction: adding it to
    the current point ``xt`` gives the estimate of ``x0``. The time input
    handed to the model is always a tensor of zeros, so ``T`` does not
    influence the prediction.
    """

    def interpolate(self, x0, x1, t, **kwargs):
        """Return the point a fraction ``t`` of the way from ``x0`` to ``x1``.

        Plain linear interpolation; extra keyword arguments are ignored.
        """
        delta = x1 - x0
        return x0 + delta * t

    def solve(self, x_hat, xt, t, **kwargs):
        """Follow the predicted direction ``x_hat`` from ``xt``.

        Returns a ``CFakeObject`` whose ``x0`` is ``xt + x_hat`` (the
        recovered endpoint) and whose ``x1`` is ``xt`` unchanged. ``t`` and
        any extra keyword arguments are ignored.
        """
        estimate = xt + x_hat
        return CFakeObject(x0=estimate, x1=xt)

    def train(self, x0, x1, T, xT=None):
        """Assemble one training sample.

        Args:
            x0: the endpoint the model should recover.
            x1: the default starting point, used when ``xT`` is None.
            T: forwarded to ``inference`` (which does not use it).
            xT: optional starting point that overrides ``x1``.

        Returns:
            A dict with ``'target'`` (``x0 - start``, chosen so that
            ``start + target == x0``), ``'x0'``, ``'x1'`` (the start point
            actually used, i.e. ``xT`` when given) and the model inputs
            returned by ``inference`` (``'xT'`` and ``'T'``).
        """
        if xT is None:
            start = x1
        else:
            start = xT
        model_inputs = self.inference(xT=start, T=T)
        sample = {'target': x0 - start, 'x0': x0, 'x1': start}
        sample.update(model_inputs)
        return sample

    def inference(self, xT, T):
        """Build the model inputs for starting point ``xT``.

        ``xT`` is passed through unchanged. ``T`` is ignored: the time input
        is a float32 tensor of zeros shaped ``(batch, 1)``, where ``batch``
        is the size of ``xT``'s first dimension.
        """
        batch_size = tf.shape(xT)[0]
        dummy_time = tf.zeros((batch_size, 1), dtype=tf.float32)  # not used
        return {'xT': xT, 'T': dummy_time}
# End of CDirectionInterpolant
class CConstantInterpolant(IInterpolant):
    """Interpolant that ignores ``t`` and ``xt`` and always targets ``x0``.

    Equivalent to single-pass restoration: the model output ``x_hat`` is
    taken directly as the estimate, and the model receives zeros in place of
    both its ``xT`` and ``T`` inputs.
    """

    def interpolate(self, x0, x1, t, **kwargs):
        """Return ``x0`` after asserting (via ``tf.assert_equal``) that it equals ``x1``.

        ``t`` and any extra keyword arguments are ignored.
        """
        tf.assert_equal(x0, x1)
        return x0

    def solve(self, x_hat, xt, t, **kwargs):
        """Use the prediction ``x_hat`` as both ``x0`` and ``x1``.

        ``xt``, ``t`` and any extra keyword arguments are ignored.
        """
        prediction = x_hat
        return CFakeObject(x0=prediction, x1=prediction)

    def train(self, x0, x1, T, xT=None):
        """Assemble one training sample whose target is ``x0`` itself.

        Args:
            x0: the target the model should output directly.
            x1: the input point; only its shape reaches the model, via
                ``inference``.
            T: forwarded to ``inference`` (which does not use it).
            xT: accepted for signature compatibility but not used.
                NOTE(review): unlike CDirectionInterpolant.train, ``xT``
                does not override ``x1`` here — confirm this is intended.

        Returns:
            A dict with ``'target'`` and ``'x0'`` (both ``x0``), ``'x1'``
            and the model inputs returned by ``inference``.
        """
        model_inputs = self.inference(xT=x1, T=T)
        sample = {'target': x0, 'x0': x0, 'x1': x1}
        sample.update(model_inputs)
        return sample

    def inference(self, xT, T):
        """Build placeholder model inputs shaped like ``xT``.

        Both inputs are zeros: ``'xT'`` is ``tf.zeros_like(xT)`` and ``'T'``
        is a float32 ``(batch, 1)`` tensor, where ``batch`` is the size of
        ``xT``'s first dimension. ``T`` itself is ignored.
        """
        batch_size = tf.shape(xT)[0]
        return {
            'xT': tf.zeros_like(xT),  # not used
            'T': tf.zeros((batch_size, 1), dtype=tf.float32),  # not used
        }
# End of CConstantInterpolant