-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CDDIMInterpolantSampler.py
48 lines (43 loc) · 1.36 KB
/
CDDIMInterpolantSampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import tensorflow as tf
from .CBasicInterpolantSampler import CBasicInterpolantSampler
from .CDDIMSamplingAlgorithm import CDDIMSamplingAlgorithm
class CDDIMInterpolantSampler(CBasicInterpolantSampler):
    """Interpolant sampler whose sampling algorithm is a ``CDDIMSamplingAlgorithm``.

    The constructor builds the DDIM algorithm from the given settings and hands
    it, together with ``interpolant``, to ``CBasicInterpolantSampler``. It also
    keeps its own reference to ``schedule`` because ``train`` needs the
    schedule's time conversions.
    """

    def __init__(
        self, interpolant,
        stochasticity, noiseProvider, schedule, steps, clipping, projectNoise
    ):
        """Create the DDIM algorithm and initialise the base sampler.

        Args:
            interpolant: forwarded to ``CBasicInterpolantSampler``. ``train``
                later calls ``self._interpolant.train(x0=, x1=, T=, xT=)``.
                The ``_interpolant`` attribute is presumably set by the base
                class (TODO confirm).
            stochasticity, noiseProvider, schedule, steps, clipping,
            projectNoise: forwarded unchanged to ``CDDIMSamplingAlgorithm``.
        """
        super().__init__(
            interpolant=interpolant,
            algorithm=CDDIMSamplingAlgorithm(
                stochasticity=stochasticity,
                noiseProvider=noiseProvider,
                schedule=schedule,
                steps=steps,
                clipping=clipping,
                projectNoise=projectNoise,
            ),
        )
        # Kept locally because train() needs to_discrete / parametersForT /
        # to_continuous from the schedule.
        self._schedule = schedule

    def train(self, x0, x1, T, xT=None, model=None):
        """Build one training batch through the wrapped interpolant.

        Args:
            x0: batch tensor; its leading dimension defines the batch size ``B``.
            x1: tensor with the same shape as ``x0``. When ``xT`` is given,
                ``xT`` is passed to the interpolant as ``x1`` in its place.
            T: time values of shape ``(B, 1)``, converted with
                ``schedule.to_discrete(..., lastStep=True)``.
            xT: optional tensor that replaces ``x1``; must have the same shape
                as ``x0``. It is also forwarded to the interpolant as ``xT``.
            model: unused; accepted only for interface compatibility.

        Returns:
            dict: the interpolant's training dict, extended with
            ``'alphaHat'`` (``schedule.parametersForT(T).alphaHat`` for the
            discretized ``T``, shape ``(B, 1)``) and ``'T'`` (the discretized
            ``T`` mapped back through ``schedule.to_continuous``).
        """
        B = tf.shape(x0)[0]
        tf.assert_equal(tf.shape(T), (B, 1))
        tf.assert_equal(tf.shape(x0), tf.shape(x1))
        T = self._schedule.to_discrete(T, lastStep=True)
        alpha_hat_t = self._schedule.parametersForT(T).alphaHat
        tf.assert_equal(tf.shape(alpha_hat_t), (B, 1))

        if xT is not None:
            # xT takes the place of x1 below, so hold it to the same shape
            # contract x1 was checked against above.
            tf.assert_equal(tf.shape(x0), tf.shape(xT))
            x1 = xT

        # The interpolant is driven by alpha-hat, not by the raw time value.
        trainData = self._interpolant.train(x0=x0, x1=x1, T=alpha_hat_t, xT=xT)
        return {
            **trainData,
            'alphaHat': alpha_hat_t,
            'T': self._schedule.to_continuous(T),
        }
# End of CDDIMInterpolantSampler