Skip to content

Commit

Permalink
Add sde handling to 3 order
Browse files Browse the repository at this point in the history
  • Loading branch information
StAlKeR7779 committed Aug 6, 2024
1 parent 2e8e6cc commit e73c056
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
12 changes: 11 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ def multistep_dpm_solver_third_order_update(
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -884,6 +885,15 @@ def multistep_dpm_solver_third_order_update(
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h)**2 - 0.5)) * D2
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t

def index_for_timestep(self, timestep, schedule_timesteps=None):
Expand Down Expand Up @@ -990,7 +1000,7 @@ def step(
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
else:
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)

if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
Expand Down
20 changes: 19 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ def singlestep_dpm_solver_third_order_update(
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -830,6 +831,23 @@ def singlestep_dpm_solver_third_order_update(
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s2 * torch.exp(-h)) * sample
+ (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h)**2 - 0.5)) * D2
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t

def singlestep_dpm_solver_update(
Expand Down Expand Up @@ -891,7 +909,7 @@ def singlestep_dpm_solver_update(
elif order == 2:
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
elif order == 3:
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise)
else:
raise ValueError(f"Order must be 1, 2, 3, got {order}")

Expand Down

0 comments on commit e73c056

Please sign in to comment.