-
Notifications
You must be signed in to change notification settings - Fork 1
/
EulerNew.py
196 lines (178 loc) · 8.58 KB
/
EulerNew.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import logging
from typing import Callable, List, Optional, Tuple, Union, Dict, Any, Literal

import torch
from torch import Tensor
from diffusers import EulerDiscreteScheduler
from diffusers.utils import BaseOutput
try:
    # Try the old import path
    from diffusers.utils import randn_tensor
except ImportError:
    # If the old import path is not available, use the new import path
    from diffusers.utils.torch_utils import randn_tensor
from diffusers.configuration_utils import ConfigMixin
from diffusers.schedulers.scheduling_utils import SchedulerMixin
class Output(BaseOutput):
    """
    Container for the values produced by one call to the scheduler's `step`.

    Attributes:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample for the previous timestep (x_{t-1}). Feed this back into the model as the next
            input of the denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The denoised estimate (x_{0}) derived from the model output at the current timestep. Useful
            for previewing progress or for guidance. Defaults to `None`.
    """

    # NOTE(review): diffusers' own scheduler outputs are declared with `@dataclass`;
    # this class is not. Confirm that is intentional before changing it.
    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
class Euler(EulerDiscreteScheduler, SchedulerMixin, ConfigMixin):
    """
    Euler discrete scheduler whose update direction is smoothed with momentum.

    Each call to `step` computes the usual Euler ODE derivative, blends it with
    a running history of earlier derivatives, advances the sample along the
    blended direction, and then folds the blended direction back into the
    history.

    The behaviour is tuned through plain class attributes (override them on the
    class or on an instance before sampling):

    * `history_d`: how the history is seeded on the first step of a run.
      `0` starts with no history, `'rand_init'` seeds it with the sample passed
      to the first `step` call, `'rand_new'` seeds it with fresh Gaussian noise
      shaped like that sample.
    * `momentum`: weight of the *current* derivative in the blended direction;
      the history receives `1 - momentum`.
    * `momentum_hist`: weight of the *old* history when the history is
      updated; the blended direction receives `1 - momentum_hist`.
    * `used_history_d`: the running history itself. `None` means it has not
      been initialised yet for the current run.
    """

    history_d = 0
    momentum = 0.95
    momentum_hist = 0.75
    used_history_d = None

    def init_hist_d(self, x: Tensor) -> Union[Literal[0], Tensor]:
        """
        Seed `self.used_history_d` according to `self.history_d`.

        Args:
            x (`Tensor`): sample used as the seed (`'rand_init'`) or as the
                shape/dtype/device template for new noise (`'rand_new'`).

        Returns:
            The freshly initialised history (`0` or a tensor), which is also
            stored on `self.used_history_d`.

        Raises:
            ValueError: if `self.history_d` is not `0`, `'rand_init'` or
                `'rand_new'`.
        """
        if self.history_d == 0:
            self.used_history_d = 0
        elif self.history_d == 'rand_init':
            self.used_history_d = x
        elif self.history_d == 'rand_new':
            self.used_history_d = torch.randn_like(x)
        else:
            raise ValueError(f'unknown momentum_hist_init: {self.history_d}')
        return self.used_history_d

    def momentum_step(self, x: Tensor, d: Tensor, dt: Tensor):
        """
        Advance `x` by one Euler step along a momentum-corrected derivative.

        Args:
            x (`Tensor`): current sample.
            d (`Tensor`): derivative computed for the current step.
            dt (`Tensor`): step size (next sigma minus current sigma).

        Returns:
            The updated sample `x + momentum_d * dt`.

        Side effects:
            Stores the blended derivative on `self.momentum_d` and the updated
            history on `self.used_history_d`.
        """
        history = self.used_history_d

        # Blend the current derivative with the history. The weights are kept
        # in the `1 - (1 - w)` form so results stay bit-identical to earlier
        # versions of this scheduler.
        p = 1.0 - self.momentum
        self.momentum_d = (1.0 - p) * d + p * history

        # Euler update along the blended direction.
        x = x + self.momentum_d * dt

        # Update the history; an integer 0 means "no history yet", so the
        # blended derivative becomes the history outright.
        q = 1.0 - self.momentum_hist
        if isinstance(history, int) and history == 0:
            history = self.momentum_d
        else:
            history = (1.0 - q) * history + q * self.momentum_d
        self.used_history_d = history

        return x

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ):
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float`): current timestep in the diffusion chain. Must be one of
                `scheduler.timesteps`, not an integer index.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            s_churn (`float`): controls `gamma`, the relative sigma increase applied before the step;
                `gamma = min(s_churn / (len(sigmas) - 1), sqrt(2) - 1)`.
            s_tmin (`float`): lower sigma bound of the range in which churn is applied.
            s_tmax (`float`): upper sigma bound of the range in which churn is applied.
            s_noise (`float`): scale applied to the noise that is added when `gamma > 0`.
            generator (`torch.Generator`, optional): Random number generator.
            return_dict (`bool`): if `False`, return a plain tuple instead of an `Output` instance.

        Returns:
            `Output` or `tuple`:
                `Output` if `return_dict` is True, otherwise a one-element `tuple` whose only element is the
                sample tensor for the previous timestep.

        Raises:
            ValueError: if `timestep` is an integer / integer tensor, or if
                `self.config.prediction_type` is not a supported value.
        """
        # Lazily seed the momentum history: anything other than a tensor or an
        # int (in practice `None`) means it has not been initialised.
        if not isinstance(self.used_history_d, (torch.Tensor, int)):
            self.init_hist_d(sample)

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            # The module defines no `logger`, so resolve one here instead of
            # failing with a NameError on the warning path.
            logging.getLogger(__name__).warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        # Noise is drawn unconditionally (even when gamma == 0) so the
        # generator state advances exactly as it always has.
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        prediction_type = self.config.prediction_type
        if prediction_type == "original_sample" or prediction_type == "sample":
            pred_original_sample = model_output
        elif prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        # 3. Euler update with momentum instead of the plain `sample + derivative * dt`
        prev_sample = self.momentum_step(sample, derivative, dt)

        # Advance the step counter; once the last sigma is reached, drop the
        # history so the next sampling run is seeded afresh.
        self._step_index += 1
        if self._step_index == (len(self.sigmas) - 1):
            self.used_history_d = None

        if not return_dict:
            return (prev_sample,)

        return Output(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )