Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609201955
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Feb 22, 2024
1 parent 69860f7 commit d0b1c8a
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 109 deletions.
189 changes: 90 additions & 99 deletions swirl_dynamics/lib/diffusion/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,26 @@ class Sampler:
input_shape: The tensor shape of a sample (excluding any batch dimensions).
scheme: The diffusion scheme which contains the scale and noise schedules.
denoise_fn: A function to remove noise from input data. Must handle batched
inputs and noise levels.
inputs, noise levels and conditions.
tspan: Full diffusion time steps for iterative denoising, decreasing from 1
to (approximately) 0.
guidance_transforms: An optional sequence of guidance transforms that
modifies the denoising function in a post-process fashion.
apply_denoise_at_end: If `True`, applies the denoise function another time
to the terminal states, which are typically at a small but non-zero noise
level.
return_full_paths: If `True`, the output of `.generate()` and `.denoise()`
will contain the complete sampling paths. Otherwise only the terminal
states are returned.
"""

input_shape: tuple[int, ...]
scheme: diffusion.Diffusion
denoise_fn: DenoiseFn
guidance_transforms: Sequence[guidance.Transform]
tspan: Array
guidance_transforms: Sequence[guidance.Transform] = ()
apply_denoise_at_end: bool = True
return_full_paths: bool = False

def generate(
self,
Expand All @@ -156,20 +167,70 @@ def generate(
cond: ArrayMapping | None = None,
guidance_inputs: ArrayMapping | None = None,
) -> Array:
"""Generates a batch of diffusion samples.
"""Generates a batch of diffusion samples from scratch.
Args:
num_samples: The number of samples to generate in a single batch.
rng: The base rng for the generation process.
cond: Explicit conditioning inputs for the denoising function. These
should be provided without any batch dimensions (one should be added
should be provided **without** batch dimensions (one should be added
inside this function based on `num_samples`).
guidance_inputs: Inputs used to construct the guided denoising function.
These also should in principle not include a batch dimension.
Returns:
The generated samples.
"""
if self.tspan is None or self.tspan.ndim != 1:
raise ValueError("`tspan` must be a 1-d array.")

init_rng, denoise_rng = jax.random.split(rng)
x_shape = (num_samples,) + self.input_shape
x1 = jax.random.normal(init_rng, x_shape)
x1 *= self.scheme.sigma(self.tspan[0]) * self.scheme.scale(self.tspan[0])

if cond is not None:
cond = jax.tree_map(lambda x: jnp.stack([x] * num_samples, axis=0), cond)

denoised = self.denoise(x1, denoise_rng, self.tspan, cond, guidance_inputs)

samples = denoised[-1] if self.return_full_paths else denoised
if self.apply_denoise_at_end:
denoise_fn = self.get_guided_denoise_fn(guidance_inputs=guidance_inputs)
samples = denoise_fn(
jnp.divide(samples, self.scheme.scale(self.tspan[-1])),
self.scheme.sigma(self.tspan[-1]),
cond,
)
if self.return_full_paths:
denoised = jnp.concatenate([denoised, samples[None]], axis=0)

return denoised if self.return_full_paths else samples

def denoise(
self,
noisy: Array,
rng: Array,
tspan: Array,
cond: ArrayMapping | None,
guidance_inputs: ArrayMapping | None,
) -> Array:
"""Applies iterative denoising to given noisy states.
Args:
noisy: A batch of noisy states (all at the same noise level). Can be fully
noisy or partially denoised.
rng: Base Jax rng for denoising.
tspan: A decreasing sequence of diffusion time steps within the interval
[1, 0). The first element aligns with the time step of the `noisy`
input.
cond: (Optional) Conditioning inputs for the denoise function. The batch
dimension should match that of `noisy`.
guidance_inputs: Inputs for constructing the guided denoising function.
Returns:
The denoised output.
"""
raise NotImplementedError

def get_guided_denoise_fn(
Expand All @@ -187,63 +248,26 @@ class OdeSampler(Sampler):
"""Draw samples by solving an probabilistic flow ODE.
Attributes:
input_shape: The tensor shape of a sample (excluding any batch dimensions).
scheme: The diffusion scheme which contains the scale and noise schedules.
denoise_fn: A function to remove noise from input data. Must handle batched
inputs and noise levels.
integrator: The ODE solver for solving the sampling ODE.
tspan: The time steps for the ODE solver (decreasing typically from 1 to 0).
guidance_transforms: An optional sequence of guidance transforms that
modifies the denoising function in a post-process fashion.
apply_denoise_at_end: Whether to apply the denoise function for another time
to the terminal state.
return_full_paths: If `True`, the output will contain the complete sampling
paths with axis 0 corresponding to diffusion times specified by `tspan`.
"""

input_shape: tuple[int, ...]
scheme: diffusion.Diffusion
denoise_fn: DenoiseFn
integrator: ode.OdeSolver
tspan: Array
guidance_transforms: Sequence[guidance.Transform]
apply_denoise_at_end: bool = True
return_full_paths: bool = False
integrator: ode.OdeSolver = ode.HeunsMethod()

def generate(
def denoise(
self,
num_samples: int,
noisy: Array,
rng: Array,
tspan: Array,
cond: ArrayMapping | None = None,
guidance_inputs: ArrayMapping | None = None,
) -> Array:
"""Generate a batch of samples by solving the sampling ODE."""
if self.tspan.ndim != 1:
raise ValueError("`tspan` must be a 1-d array.")

x_shape = (num_samples,) + self.input_shape
t0, t1 = self.tspan[0], self.tspan[-1]
x1 = jax.random.normal(rng, x_shape)
x1 *= self.scheme.sigma(t0) * self.scheme.scale(t0)

if cond is not None:
rep_fn = lambda x: jnp.tile(x[None], (num_samples,) + (1,) * x.ndim)
cond = jax.tree_map(rep_fn, cond)

"""Applies iterative denoising to given noisy states."""
del rng
params = dict(cond=cond, guidance_inputs=guidance_inputs)
# Output of integrator must have time at axis 0.
paths = self.integrator(self.dynamics, x1, self.tspan, params)

if self.apply_denoise_at_end:
denoise_fn = self.get_guided_denoise_fn(guidance_inputs=guidance_inputs)
final = denoise_fn(
jnp.divide(paths[-1], self.scheme.scale(t1)),
self.scheme.sigma(t1),
cond,
)
paths = jnp.concatenate([paths, final[None]], axis=0)

return paths if self.return_full_paths else paths[-1]
# Currently all ODE integrators return full paths. The lead axis should
# always be time.
denoised = self.integrator(self.dynamics, noisy, tspan, params)
return denoised if self.return_full_paths else denoised[-1]

@property
def dynamics(self) -> ode.OdeDynamics:
Expand Down Expand Up @@ -280,67 +304,34 @@ class SdeSampler(Sampler):
"""Draws samples by solving an SDE.
Attributes:
input_shape: The tensor shape of a sample (excluding any batch dimensions).
scheme: The diffusion scheme which contains the scale and noise schedules.
denoise_fn: A function to remove noise from input data. Must handle batched
inputs and noise levels.
integrator: The SDE solver for solving the sampling SDE.
tspan: The time steps for the SDE solver (decreasing typically from 1 to
0).
guidance_transforms: An optional sequence of guidance transforms that
modifies the denoising function in a post-process fashion.
apply_denoise_at_end: Whether to apply the denoise function for another time
to the terminal state.
return_full_paths: If `True`, the output will contain the complete sampling
paths with axis 0 corresponding to diffusion times specified by `tspan`.
"""

input_shape: tuple[int, ...]
scheme: diffusion.Diffusion
denoise_fn: DenoiseFn
integrator: sde.SdeSolver
tspan: Array
guidance_transforms: Sequence[guidance.Transform]
apply_denoise_at_end: bool = True
return_full_paths: bool = False
integrator: sde.SdeSolver = sde.EulerMaruyama(iter_type="scan")

def generate(
def denoise(
self,
num_samples: int,
noisy: Array,
rng: Array,
tspan: Array,
cond: ArrayMapping | None = None,
guidance_inputs: ArrayMapping | None = None,
) -> Array:
"""Generate a batch of samples by solving an SDE."""
if self.tspan.ndim != 1:
raise ValueError("`tspan` must be a 1-d array.")

init_rng, solver_rng = jax.random.split(rng)
x_shape = (num_samples,) + self.input_shape
t0, t1 = self.tspan[0], self.tspan[-1]
x1 = jax.random.normal(init_rng, x_shape)
x1 *= self.scheme.sigma(t0) * self.scheme.scale(t0)

if cond is not None:
rep_fn = lambda x: jnp.tile(x[None], (num_samples,) + (1,) * x.ndim)
cond = jax.tree_map(rep_fn, cond)
"""Applies iterative denoising to given noisy states."""
if self.integrator.terminal_only and self.return_full_paths:
raise ValueError(
f"Integrator type `{type(self.integrator)}` does not support"
" returning full paths."
)

params = dict(
drift=dict(guidance_inputs=guidance_inputs, cond=cond), diffusion={}
)
# Output of integrator must have time at axis 0.
paths = self.integrator(self.dynamics, x1, self.tspan, solver_rng, params)

if self.apply_denoise_at_end:
denoise_fn = self.get_guided_denoise_fn(guidance_inputs=guidance_inputs)
final = denoise_fn(
jnp.divide(paths[-1], self.scheme.scale(t1)),
self.scheme.sigma(t1),
cond,
)
paths = jnp.concatenate([paths, final[None]], axis=0)

return paths if self.return_full_paths else paths[-1]
denoised = self.integrator(self.dynamics, noisy, tspan, rng, params)
# SDE solvers may return either the full paths or the terminal state only.
# If the former, the lead axis should be time.
samples = denoised if self.integrator.terminal_only else denoised[-1]
return denoised if self.return_full_paths else samples

@property
def dynamics(self) -> sde.SdeDynamics:
Expand Down
19 changes: 9 additions & 10 deletions swirl_dynamics/lib/diffusion/samplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,10 @@ class SamplersTest(parameterized.TestCase):
@parameterized.parameters(
(samplers.OdeSampler, ode.ExplicitEuler(), True),
(samplers.OdeSampler, ode.ExplicitEuler(), False),
(samplers.SdeSampler, sde.EulerMaruyama(), True),
(samplers.SdeSampler, sde.EulerMaruyama(), False),
(samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan"), True),
(samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan"), False),
)
def test_ode_sampler_output_shape(
self, sampler, solver, apply_denoise_at_end
):
def test_sampler_output_shape(self, sampler, solver, apply_denoise_at_end):
input_shape = (5, 1)
num_samples = 4
num_steps = 8
Expand All @@ -109,7 +107,7 @@ def test_ode_sampler_output_shape(

@parameterized.parameters(
(samplers.OdeSampler, ode.ExplicitEuler()),
(samplers.SdeSampler, sde.EulerMaruyama()),
(samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan")),
)
def test_unet_denoiser(self, sampler, solver):
input_shape = (64, 64, 3)
Expand Down Expand Up @@ -151,7 +149,8 @@ def test_unet_denoiser(self, sampler, solver):

@parameterized.parameters(
(samplers.OdeSampler, ode.HeunsMethod()),
(samplers.SdeSampler, sde.EulerMaruyama()),
(samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan")),
(samplers.SdeSampler, sde.EulerMaruyama(iter_type="loop")),
)
def test_output_shape_with_guidance(self, sampler, solver):
input_shape = (5, 1)
Expand All @@ -167,16 +166,16 @@ def test_output_shape_with_guidance(self, sampler, solver):
denoise_fn=lambda x, t, cond: x * t,
guidance_transforms=(TestTransform(),),
)
samples = jax.jit(sampler.generate, static_argnums=0)(
num_samples=num_samples,
generate_fn = jax.jit(functools.partial(sampler.generate, num_samples))
samples = generate_fn(
rng=jax.random.PRNGKey(0),
guidance_inputs={"const": jnp.ones(input_shape)},
)
self.assertEqual(samples.shape, (num_samples,) + input_shape)

@parameterized.parameters(
(samplers.OdeSampler, ode.HeunsMethod()),
(samplers.SdeSampler, sde.EulerMaruyama()),
(samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan")),
)
def test_output_shape_with_cond(self, sampler, solver):
input_shape = (5, 1)
Expand Down

0 comments on commit d0b1c8a

Please sign in to comment.