improve FrankenSolver
It now takes a Scheduler factory instead of a Scheduler.
This lets the user recreate the Scheduler on `rebuild` if needed.

It also properly sets the device and dtype on rebuild,
and it has better typing.
catwell committed Jul 19, 2024
1 parent 299217f commit daee772
Showing 2 changed files with 58 additions and 11 deletions.
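
For context (not part of this commit's diff), here is a minimal usage sketch of the new factory-based constructor. The no-argument EulerDiscreteScheduler construction and the step count are assumptions for illustration; the import paths come from the docstring in the diff below.

# Sketch: FrankenSolver now takes a zero-argument factory that returns a
# Diffusers scheduler, instead of a scheduler instance.
from diffusers import EulerDiscreteScheduler
from refiners.foundationals.latent_diffusion.solvers import FrankenSolver

# Old call style: FrankenSolver(scheduler, num_inference_steps=30)
# New call style: wrap scheduler creation in a factory, e.g. a lambda.
solver = FrankenSolver(lambda: EulerDiscreteScheduler(), num_inference_steps=30)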
67 changes: 57 additions & 10 deletions src/refiners/foundationals/latent_diffusion/solvers/franken.py
@@ -1,10 +1,42 @@
import dataclasses
from typing import Any, cast
from typing import Any, Callable, Protocol, TypeVar

from torch import Generator, Tensor
from torch import Generator, Tensor, device as Device, dtype as DType, float32

from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing

# Should be Tensor, but some Diffusers schedulers
# are improperly typed as only accepting `int`.
SchedulerTimestepT = Any


class SchedulerOutputLike(Protocol):
@property
def prev_sample(self) -> Tensor: ...


class SchedulerLike(Protocol):
timesteps: Tensor

@property
def init_noise_sigma(self) -> Tensor | float: ...

def set_timesteps(self, num_inference_steps: int, *args: Any, **kwargs: Any) -> None: ...

def scale_model_input(self, sample: Tensor, timestep: SchedulerTimestepT) -> Tensor: ...

def step(
self,
model_output: Tensor,
timestep: SchedulerTimestepT,
sample: Tensor,
*args: Any,
**kwargs: Any,
) -> SchedulerOutputLike | tuple[Any]: ...


TFrankenSolver = TypeVar("TFrankenSolver", bound="FrankenSolver")


class FrankenSolver(Solver):
"""Lets you use Diffusers Schedulers as Refiners Solvers.
@@ -14,7 +46,7 @@ class FrankenSolver(Solver):
from refiners.foundationals.latent_diffusion.solvers import FrankenSolver
scheduler = EulerDiscreteScheduler(...)
solver = FrankenSolver(scheduler, num_inference_steps=steps)
solver = FrankenSolver(lambda: scheduler, num_inference_steps=steps)
"""

default_params = dataclasses.replace(
@@ -24,27 +56,40 @@ class FrankenSolver(Solver):

def __init__(
self,
diffusers_scheduler: Any,
get_diffusers_scheduler: Callable[[], SchedulerLike],
num_inference_steps: int,
first_inference_step: int = 0,
**kwargs: Any,
device: Device | str = "cpu",
dtype: DType = float32,
**kwargs: Any, # for typing, ignored
) -> None:
self.diffusers_scheduler = diffusers_scheduler
diffusers_scheduler.set_timesteps(num_inference_steps)
super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step)
self.get_diffusers_scheduler = get_diffusers_scheduler
self.diffusers_scheduler = self.get_diffusers_scheduler()
self.diffusers_scheduler.set_timesteps(num_inference_steps)
super().__init__(
num_inference_steps=num_inference_steps,
first_inference_step=first_inference_step,
device=device,
dtype=dtype,
)

def _generate_timesteps(self) -> Tensor:
return self.diffusers_scheduler.timesteps

def to(self: TFrankenSolver, device: Device | str | None = None, dtype: DType | None = None) -> TFrankenSolver:
return super().to(device=device, dtype=dtype) # type: ignore

def rebuild(
self,
num_inference_steps: int | None,
first_inference_step: int | None = None,
) -> "FrankenSolver":
return self.__class__(
diffusers_scheduler=self.diffusers_scheduler,
get_diffusers_scheduler=self.get_diffusers_scheduler,
num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
device=self.device,
dtype=self.dtype,
)

def scale_model_input(self, x: Tensor, step: int) -> Tensor:
@@ -54,4 +99,6 @@ def scale_model_input(self, x: Tensor, step: int) -> Tensor:

def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
timestep = self.timesteps[step]
return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample)
r = self.diffusers_scheduler.step(predicted_noise, timestep, x)
assert not isinstance(r, tuple), "scheduler returned a tuple"
return r.prev_sample
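
To illustrate why a factory matters for `rebuild` (a sketch, assuming EulerDiscreteScheduler as the wrapped scheduler and arbitrary step counts): `rebuild` passes the stored factory to the new instance, so `set_timesteps` runs on a freshly built scheduler, and device and dtype are now forwarded as well.

# Sketch of rebuild with the factory-based FrankenSolver; the scheduler
# choice and the step counts are illustrative assumptions.
from diffusers import EulerDiscreteScheduler
from refiners.foundationals.latent_diffusion.solvers import FrankenSolver

solver = FrankenSolver(
    lambda: EulerDiscreteScheduler(),  # factory is stored and called again on rebuild
    num_inference_steps=30,
    device="cpu",
)

# rebuild() constructs a new FrankenSolver with the same factory, so the new
# solver calls set_timesteps(20) on a fresh scheduler and keeps device/dtype.
solver_20 = solver.rebuild(num_inference_steps=20)
assert solver_20.num_inference_steps == 20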
2 changes: 1 addition & 1 deletion tests/foundationals/latent_diffusion/test_solvers.py
@@ -138,7 +138,7 @@ def test_franken_diffusers():
diffusers_scheduler.set_timesteps(30)

diffusers_scheduler_2 = EulerDiscreteScheduler(**params) # type: ignore
solver = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30)
solver = FrankenSolver(lambda: diffusers_scheduler_2, num_inference_steps=30)
assert equal(solver.timesteps, diffusers_scheduler.timesteps)

sample = randn(1, 4, 32, 32)
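
For reference, a hypothetical object that structurally satisfies the new SchedulerLike protocol. It is not part of the commit; it only illustrates what the protocol requires (Protocols match by structure, so no Diffusers base class is involved).

# Hypothetical minimal scheduler matching SchedulerLike, with a toy update rule.
import dataclasses
import torch
from torch import Tensor

@dataclasses.dataclass
class DummyOutput:
    prev_sample: Tensor  # matches SchedulerOutputLike.prev_sample

class DummyScheduler:
    init_noise_sigma: float = 1.0

    def __init__(self) -> None:
        self.timesteps: Tensor = torch.empty(0, dtype=torch.int64)

    def set_timesteps(self, num_inference_steps: int, *args, **kwargs) -> None:
        # Descending integer timesteps, as Diffusers schedulers produce.
        self.timesteps = torch.arange(num_inference_steps - 1, -1, -1)

    def scale_model_input(self, sample: Tensor, timestep) -> Tensor:
        return sample  # no input scaling in this toy scheduler

    def step(self, model_output: Tensor, timestep, sample: Tensor, *args, **kwargs) -> DummyOutput:
        # Toy denoising update: subtract the predicted noise.
        return DummyOutput(prev_sample=sample - model_output)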
