From d0b1c8a86e703fff0cef176fcee8bd44f526f38b Mon Sep 17 00:00:00 2001 From: Zhong Yi Wan Date: Wed, 21 Feb 2024 19:13:18 -0800 Subject: [PATCH] Code update PiperOrigin-RevId: 609201955 --- swirl_dynamics/lib/diffusion/samplers.py | 189 +++++++++--------- swirl_dynamics/lib/diffusion/samplers_test.py | 19 +- 2 files changed, 99 insertions(+), 109 deletions(-) diff --git a/swirl_dynamics/lib/diffusion/samplers.py b/swirl_dynamics/lib/diffusion/samplers.py index aefb335..83a91d1 100644 --- a/swirl_dynamics/lib/diffusion/samplers.py +++ b/swirl_dynamics/lib/diffusion/samplers.py @@ -139,15 +139,26 @@ class Sampler: input_shape: The tensor shape of a sample (excluding any batch dimensions). scheme: The diffusion scheme which contains the scale and noise schedules. denoise_fn: A function to remove noise from input data. Must handle batched - inputs and noise levels. + inputs, noise levels and conditions. + tspan: Full diffusion time steps for iterative denoising, decreasing from 1 + to (approximately) 0. guidance_transforms: An optional sequence of guidance transforms that modifies the denoising function in a post-process fashion. + apply_denoise_at_end: If `True`, applies the denoise function another time + to the terminal states, which are typically at a small but non-zero noise + level. + return_full_paths: If `True`, the output of `.generate()` and `.denoise()` + will contain the complete sampling paths. Otherwise only the terminal + states are returned. """ input_shape: tuple[int, ...] scheme: diffusion.Diffusion denoise_fn: DenoiseFn - guidance_transforms: Sequence[guidance.Transform] + tspan: Array + guidance_transforms: Sequence[guidance.Transform] = () + apply_denoise_at_end: bool = True + return_full_paths: bool = False def generate( self, @@ -156,13 +167,13 @@ def generate( cond: ArrayMapping | None = None, guidance_inputs: ArrayMapping | None = None, ) -> Array: - """Generates a batch of diffusion samples. + """Generates a batch of diffusion samples from scratch. Args: num_samples: The number of samples to generate in a single batch. rng: The base rng for the generation process. cond: Explicit conditioning inputs for the denoising function. These - should be provided without any batch dimensions (one should be added + should be provided **without** batch dimensions (one should be added inside this function based on `num_samples`). guidance_inputs: Inputs used to construct the guided denoising function. These also should in principle not include a batch dimension. @@ -170,6 +181,56 @@ def generate( Returns: The generated samples. """ + if self.tspan is None or self.tspan.ndim != 1: + raise ValueError("`tspan` must be a 1-d array.") + + init_rng, denoise_rng = jax.random.split(rng) + x_shape = (num_samples,) + self.input_shape + x1 = jax.random.normal(init_rng, x_shape) + x1 *= self.scheme.sigma(self.tspan[0]) * self.scheme.scale(self.tspan[0]) + + if cond is not None: + cond = jax.tree_map(lambda x: jnp.stack([x] * num_samples, axis=0), cond) + + denoised = self.denoise(x1, denoise_rng, self.tspan, cond, guidance_inputs) + + samples = denoised[-1] if self.return_full_paths else denoised + if self.apply_denoise_at_end: + denoise_fn = self.get_guided_denoise_fn(guidance_inputs=guidance_inputs) + samples = denoise_fn( + jnp.divide(samples, self.scheme.scale(self.tspan[-1])), + self.scheme.sigma(self.tspan[-1]), + cond, + ) + if self.return_full_paths: + denoised = jnp.concatenate([denoised, samples[None]], axis=0) + + return denoised if self.return_full_paths else samples + + def denoise( + self, + noisy: Array, + rng: Array, + tspan: Array, + cond: ArrayMapping | None, + guidance_inputs: ArrayMapping | None, + ) -> Array: + """Applies iterative denoising to given noisy states. + + Args: + noisy: A batch of noisy states (all at the same noise level). Can be fully + noisy or partially denoised. + rng: Base Jax rng for denoising. + tspan: A decreasing sequence of diffusion time steps within the interval + [1, 0). The first element aligns with the time step of the `noisy` + input. + cond: (Optional) Conditioning inputs for the denoise function. The batch + dimension should match that of `noisy`. + guidance_inputs: Inputs for constructing the guided denoising function. + + Returns: + The denoised output. + """ raise NotImplementedError def get_guided_denoise_fn( @@ -187,63 +248,26 @@ class OdeSampler(Sampler): """Draw samples by solving an probabilistic flow ODE. Attributes: - input_shape: The tensor shape of a sample (excluding any batch dimensions). - scheme: The diffusion scheme which contains the scale and noise schedules. - denoise_fn: A function to remove noise from input data. Must handle batched - inputs and noise levels. integrator: The ODE solver for solving the sampling ODE. - tspan: The time steps for the ODE solver (decreasing typically from 1 to 0). - guidance_transforms: An optional sequence of guidance transforms that - modifies the denoising function in a post-process fashion. - apply_denoise_at_end: Whether to apply the denoise function for another time - to the terminal state. - return_full_paths: If `True`, the output will contain the complete sampling - paths with axis 0 corresponding to diffusion times specified by `tspan`. """ - input_shape: tuple[int, ...] - scheme: diffusion.Diffusion - denoise_fn: DenoiseFn - integrator: ode.OdeSolver - tspan: Array - guidance_transforms: Sequence[guidance.Transform] - apply_denoise_at_end: bool = True - return_full_paths: bool = False + integrator: ode.OdeSolver = ode.HeunsMethod() - def generate( + def denoise( self, - num_samples: int, + noisy: Array, rng: Array, + tspan: Array, cond: ArrayMapping | None = None, guidance_inputs: ArrayMapping | None = None, ) -> Array: - """Generate a batch of samples by solving the sampling ODE.""" - if self.tspan.ndim != 1: - raise ValueError("`tspan` must be a 1-d array.") - - x_shape = (num_samples,) + self.input_shape - t0, t1 = self.tspan[0], self.tspan[-1] - x1 = jax.random.normal(rng, x_shape) - x1 *= self.scheme.sigma(t0) * self.scheme.scale(t0) - - if cond is not None: - rep_fn = lambda x: jnp.tile(x[None], (num_samples,) + (1,) * x.ndim) - cond = jax.tree_map(rep_fn, cond) - + """Applies iterative denoising to given noisy states.""" + del rng params = dict(cond=cond, guidance_inputs=guidance_inputs) - # Output of integrator must have time at axis 0. - paths = self.integrator(self.dynamics, x1, self.tspan, params) - - if self.apply_denoise_at_end: - denoise_fn = self.get_guided_denoise_fn(guidance_inputs=guidance_inputs) - final = denoise_fn( - jnp.divide(paths[-1], self.scheme.scale(t1)), - self.scheme.sigma(t1), - cond, - ) - paths = jnp.concatenate([paths, final[None]], axis=0) - - return paths if self.return_full_paths else paths[-1] + # Currently all ODE integrators return full paths. The lead axis should + # always be time. + denoised = self.integrator(self.dynamics, noisy, tspan, params) + return denoised if self.return_full_paths else denoised[-1] @property def dynamics(self) -> ode.OdeDynamics: @@ -280,67 +304,34 @@ class SdeSampler(Sampler): """Draws samples by solving an SDE. Attributes: - input_shape: The tensor shape of a sample (excluding any batch dimensions). - scheme: The diffusion scheme which contains the scale and noise schedules. - denoise_fn: A function to remove noise from input data. Must handle batched - inputs and noise levels. integrator: The SDE solver for solving the sampling SDE. - tspan: The time steps for the SDE solver (decreasing typically from 1 to - 0). - guidance_transforms: An optional sequence of guidance transforms that - modifies the denoising function in a post-process fashion. - apply_denoise_at_end: Whether to apply the denoise function for another time - to the terminal state. - return_full_paths: If `True`, the output will contain the complete sampling - paths with axis 0 corresponding to diffusion times specified by `tspan`. """ - input_shape: tuple[int, ...] - scheme: diffusion.Diffusion - denoise_fn: DenoiseFn - integrator: sde.SdeSolver - tspan: Array - guidance_transforms: Sequence[guidance.Transform] - apply_denoise_at_end: bool = True - return_full_paths: bool = False + integrator: sde.SdeSolver = sde.EulerMaruyama(iter_type="scan") - def generate( + def denoise( self, - num_samples: int, + noisy: Array, rng: Array, + tspan: Array, cond: ArrayMapping | None = None, guidance_inputs: ArrayMapping | None = None, ) -> Array: - """Generate a batch of samples by solving an SDE.""" - if self.tspan.ndim != 1: - raise ValueError("`tspan` must be a 1-d array.") - - init_rng, solver_rng = jax.random.split(rng) - x_shape = (num_samples,) + self.input_shape - t0, t1 = self.tspan[0], self.tspan[-1] - x1 = jax.random.normal(init_rng, x_shape) - x1 *= self.scheme.sigma(t0) * self.scheme.scale(t0) - - if cond is not None: - rep_fn = lambda x: jnp.tile(x[None], (num_samples,) + (1,) * x.ndim) - cond = jax.tree_map(rep_fn, cond) + """Applies iterative denoising to given noisy states.""" + if self.integrator.terminal_only and self.return_full_paths: + raise ValueError( + f"Integrator type `{type(self.integrator)}` does not support" + " returning full paths." + ) params = dict( drift=dict(guidance_inputs=guidance_inputs, cond=cond), diffusion={} ) - # Output of integrator must have time at axis 0. - paths = self.integrator(self.dynamics, x1, self.tspan, solver_rng, params) - - if self.apply_denoise_at_end: - denoise_fn = self.get_guided_denoise_fn(guidance_inputs=guidance_inputs) - final = denoise_fn( - jnp.divide(paths[-1], self.scheme.scale(t1)), - self.scheme.sigma(t1), - cond, - ) - paths = jnp.concatenate([paths, final[None]], axis=0) - - return paths if self.return_full_paths else paths[-1] + denoised = self.integrator(self.dynamics, noisy, tspan, rng, params) + # SDE solvers may return either the full paths or the terminal state only. + # If the former, the lead axis should be time. + samples = denoised if self.integrator.terminal_only else denoised[-1] + return denoised if self.return_full_paths else samples @property def dynamics(self) -> sde.SdeDynamics: diff --git a/swirl_dynamics/lib/diffusion/samplers_test.py b/swirl_dynamics/lib/diffusion/samplers_test.py index 8701cfe..a790950 100644 --- a/swirl_dynamics/lib/diffusion/samplers_test.py +++ b/swirl_dynamics/lib/diffusion/samplers_test.py @@ -78,12 +78,10 @@ class SamplersTest(parameterized.TestCase): @parameterized.parameters( (samplers.OdeSampler, ode.ExplicitEuler(), True), (samplers.OdeSampler, ode.ExplicitEuler(), False), - (samplers.SdeSampler, sde.EulerMaruyama(), True), - (samplers.SdeSampler, sde.EulerMaruyama(), False), + (samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan"), True), + (samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan"), False), ) - def test_ode_sampler_output_shape( - self, sampler, solver, apply_denoise_at_end - ): + def test_sampler_output_shape(self, sampler, solver, apply_denoise_at_end): input_shape = (5, 1) num_samples = 4 num_steps = 8 @@ -109,7 +107,7 @@ def test_ode_sampler_output_shape( @parameterized.parameters( (samplers.OdeSampler, ode.ExplicitEuler()), - (samplers.SdeSampler, sde.EulerMaruyama()), + (samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan")), ) def test_unet_denoiser(self, sampler, solver): input_shape = (64, 64, 3) @@ -151,7 +149,8 @@ def test_unet_denoiser(self, sampler, solver): @parameterized.parameters( (samplers.OdeSampler, ode.HeunsMethod()), - (samplers.SdeSampler, sde.EulerMaruyama()), + (samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan")), + (samplers.SdeSampler, sde.EulerMaruyama(iter_type="loop")), ) def test_output_shape_with_guidance(self, sampler, solver): input_shape = (5, 1) @@ -167,8 +166,8 @@ def test_output_shape_with_guidance(self, sampler, solver): denoise_fn=lambda x, t, cond: x * t, guidance_transforms=(TestTransform(),), ) - samples = jax.jit(sampler.generate, static_argnums=0)( - num_samples=num_samples, + generate_fn = jax.jit(functools.partial(sampler.generate, num_samples)) + samples = generate_fn( rng=jax.random.PRNGKey(0), guidance_inputs={"const": jnp.ones(input_shape)}, ) @@ -176,7 +175,7 @@ def test_output_shape_with_guidance(self, sampler, solver): @parameterized.parameters( (samplers.OdeSampler, ode.HeunsMethod()), - (samplers.SdeSampler, sde.EulerMaruyama()), + (samplers.SdeSampler, sde.EulerMaruyama(iter_type="scan")), ) def test_output_shape_with_cond(self, sampler, solver): input_shape = (5, 1)