From 61afc04f8309eef0f8eebbc819a23c1c0ba88692 Mon Sep 17 00:00:00 2001
From: Zhong Yi Wan
Date: Thu, 16 Nov 2023 09:44:25 -0800
Subject: [PATCH] Code update

PiperOrigin-RevId: 583075541
---
 swirl_dynamics/lib/diffusion/guidance.py      |  73 ++++++---
 swirl_dynamics/lib/diffusion/guidance_test.py |  14 +-
 swirl_dynamics/lib/diffusion/samplers.py      | 147 +++++++++++++-----
 swirl_dynamics/lib/diffusion/samplers_test.py |  49 ++++--
 4 files changed, 199 insertions(+), 84 deletions(-)

diff --git a/swirl_dynamics/lib/diffusion/guidance.py b/swirl_dynamics/lib/diffusion/guidance.py
index 469b580..579700b 100644
--- a/swirl_dynamics/lib/diffusion/guidance.py
+++ b/swirl_dynamics/lib/diffusion/guidance.py
@@ -12,64 +12,89 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Modules for a posteriori (post-processing) guidance."""
+"""Modules for guidance transforms for denoising functions."""
 
-from collections.abc import Callable
-from typing import Protocol
+from collections.abc import Callable, Mapping
+from typing import Any, Protocol
 
 import flax
 import jax
 import jax.numpy as jnp
 
 Array = jax.Array
-DenoiseFn = Callable[[Array, Array], Array]
+PyTree = Any
+Cond = Mapping[str, PyTree] | None
+DenoiseFn = Callable[[Array, Array, Cond], Array]
 
 
-class Guidance(Protocol):
+class Transform(Protocol):
+  """Transforms a denoising function to follow some guidance.
+
+  One may think of these transforms as instances of Python decorators,
+  specifically made for denoising functions. Each transform takes a base
+  denoising function and extends it (often using some additional data) to
+  build a new denoising function with the same interface.
+  """
 
   def __call__(
-      self, denoise_fn: DenoiseFn, guidance_input: Array | None
+      self, denoise_fn: DenoiseFn, guidance_inputs: Mapping[str, Array]
   ) -> DenoiseFn:
-    """Constructs guided denoise function."""
+    """Constructs a guided denoising function.
+
+    Args:
+      denoise_fn: The base denoising function.
+      guidance_inputs: A dictionary containing inputs used to construct the
+        guided denoising function. Note that all transforms *share the same
+        input dict*, so each transform should read different fields from this
+        dict (unless sharing is intended) to avoid name clashes.
+
+    Returns:
+      The guided denoising function.
+    """
     ...
 
 
 @flax.struct.dataclass
 class InfillFromSlices:
-  """N-dimensional infilling guided by slices.
+  """N-dimensional infilling guided by known values on slices.
 
   Example usage::
 
-    # 2D infill given every 8th pixel along both dimensions
-    # (assuming that the lead dimension is for batch)
+    # 2D infill given every 8th pixel along both dimensions (assuming that the
+    # lead dimension is for batch).
    slices = (slice(None), slice(None, None, 8), slice(None, None, 8))
-    sr_guidance = InfillFromSlices(slices)
+    sr_guidance = InfillFromSlices(slices, guide_strength=0.1)
 
-    # Post process a trained denoiser function via function composition
-    # `guidance_input` must have compatible shape s.t.
-    # `image[slices] = guidance_input` would not result in errors
-    guided_denoiser = sr_guidance(denoiser, guidance_input=jnp.array(0.0))
+    # Post-process a trained denoiser function via function composition.
+    # The `observed_slices` value must have a compatible shape such that
+    # `image[slices] = observed_slices` would not raise errors.
+    guided_denoiser = sr_guidance(denoiser, {"observed_slices": jnp.array(0.0)})
 
     # Run the guided denoiser the same way as a normal one
-    denoised = guided_denoiser(noised, sigma=jnp.array(0.1))
+    denoised = guided_denoiser(noised, sigma=jnp.array(0.1), cond=None)
 
   Attributes:
-    slices: The slices of the input to guide denoising (i.e. the rest is
-      infilled). The `guidance_input` provided when calling this method must be
-      compatible with these slices.
+    slices: The slices of the input to guide denoising (i.e. the rest is being
+      infilled).
     guide_strength: The strength of the guidance relative to the raw denoiser.
+      It will be rescaled based on the fraction of values being conditioned on.
   """
 
   slices: tuple[slice, ...]
   guide_strength: float = 0.5
 
-  def __call__(self, denoise_fn: DenoiseFn, guidance_input: Array) -> DenoiseFn:
+  def __call__(
+      self, denoise_fn: DenoiseFn, guidance_inputs: Mapping[str, Array]
+  ) -> DenoiseFn:
     """Constructs a denoise function guided by values on specified slices."""
 
-    def _guided_denoise(x: Array, sigma: Array) -> Array:
+    def _guided_denoise(x: Array, sigma: Array, cond: Cond = None) -> Array:
       def constraint(xt: Array) -> tuple[Array, Array]:
-        denoised = denoise_fn(xt, sigma)
-        return jnp.sum((denoised[self.slices] - guidance_input) ** 2), denoised
+        denoised = denoise_fn(xt, sigma, cond)
+        error = jnp.sum(
+            (denoised[self.slices] - guidance_inputs["observed_slices"]) ** 2
+        )
+        return error, denoised
 
       constraint_grad, denoised = jax.grad(constraint, has_aux=True)(x)
       # normalize wrt the fraction of values being conditioned
@@ -78,6 +103,6 @@ def constraint(xt: Array) -> tuple[Array, Array]:
       )
       guide_strength = self.guide_strength / cond_fraction
       denoised -= guide_strength * constraint_grad
-      return denoised.at[self.slices].set(guidance_input)
+      return denoised.at[self.slices].set(guidance_inputs["observed_slices"])
 
     return _guided_denoise
diff --git a/swirl_dynamics/lib/diffusion/guidance_test.py b/swirl_dynamics/lib/diffusion/guidance_test.py
index 9c13413..a20abb2 100644
--- a/swirl_dynamics/lib/diffusion/guidance_test.py
+++ b/swirl_dynamics/lib/diffusion/guidance_test.py
@@ -19,7 +19,7 @@
 from swirl_dynamics.lib.diffusion import guidance
 
 
-class GuidanceTest(parameterized.TestCase):
+class GuidanceTransformsTest(parameterized.TestCase):
 
   @parameterized.parameters(
       {"test_dim": (4, 16), "ds_ratios": (None, 8), "guide_shape": (4, 2)},
@@ -31,13 +31,17 @@ class GuidanceTest(parameterized.TestCase):
   )
   def test_super_resolution(self, test_dim, ds_ratios, guide_shape):
     superresolve = guidance.InfillFromSlices(
-        slices=tuple(slice(None, None, r) for r in ds_ratios)
+        slices=tuple(slice(None, None, r) for r in ds_ratios),
     )
-    dummy_denoiser = lambda x, sigma: jnp.ones_like(x)
+
+    def _dummy_denoiser(x, sigma, cond=None):
+      del sigma, cond
+      return jnp.ones_like(x)
+
     guided_denoiser = superresolve(
-        dummy_denoiser, guidance_input=jnp.array(0.0)
+        _dummy_denoiser, {"observed_slices": jnp.array(0.0)}
     )
-    denoised = guided_denoiser(jnp.ones(test_dim), jnp.array(0.1))
+    denoised = guided_denoiser(jnp.ones(test_dim), jnp.array(0.1), None)
     guided_elements = denoised[superresolve.slices]
     self.assertEqual(denoised.shape, test_dim)
     self.assertEqual(guided_elements.shape, guide_shape)
diff --git a/swirl_dynamics/lib/diffusion/samplers.py b/swirl_dynamics/lib/diffusion/samplers.py
index d3d353e..234d157 100644
--- a/swirl_dynamics/lib/diffusion/samplers.py
+++ b/swirl_dynamics/lib/diffusion/samplers.py
@@ -14,7 +14,7 @@
 
 """Diffusion samplers."""
 
-from collections.abc import Callable
+from collections.abc import Callable, Mapping, Sequence
 from typing import Any, Protocol
 
 import flax
@@ -27,7 +27,8 @@
 
 Array = jax.Array
 PyTree = Any
-DenoiseFn = Callable[[Array, Array], Array]
+Cond = Mapping[str, PyTree] | None
+DenoiseFn = Callable[[Array, Array, Cond], Array]
 ScoreFn = DenoiseFn
 
@@ -46,11 +47,11 @@ def denoiser2score(
 ) -> ScoreFn:
   """Converts a denoiser to the corresponding score function."""
 
-  def _score(x: Array, sigma: Array) -> Array:
+  def _score(x: Array, sigma: Array, cond: Cond = None) -> Array:
     # reference: eq. 74 in Karras et al. (https://arxiv.org/abs/2206.00364).
     scale = scheme.scale(scheme.sigma.inverse(sigma))
     x_hat = jnp.divide(x, scale)
-    target = denoise_fn(x_hat, sigma)
+    target = denoise_fn(x_hat, sigma, cond)
     return jnp.divide(target - x_hat, scale * jnp.square(sigma))
 
   return _score
@@ -125,13 +126,13 @@ class Sampler(Protocol):
   """Interface for diffusion samplers."""
 
   def generate(
-      self, rng: Array, num_samples: int, **kwargs
+      self, num_samples: int, rng: Array, **kwargs
   ) -> tuple[Array, Any]:
     """Generate a specified number of diffusion samples.
 
     Args:
-      rng: The base rng for the generation process.
       num_samples: The number of samples to generate.
+      rng: The base rng for the generation process.
       **kwargs: Additional keyword arguments.
 
     Returns:
@@ -141,6 +142,16 @@ def generate(
     ...
 
 
+def _apply_guidance_transforms(
+    denoise_fn: DenoiseFn,
+    transforms: Sequence[guidance.Transform],
+    guidance_inputs: Mapping[str, PyTree],
+) -> DenoiseFn:
+  for transform in transforms:
+    denoise_fn = transform(denoise_fn, guidance_inputs)
+  return denoise_fn
+
+
 @flax.struct.dataclass
 class OdeSampler:
   """Draw samples by solving a probabilistic flow ODE.
 
   Attributes:
     integrator: The ODE solver to use.
     scheme: The diffusion scheme which contains the scale and noise schedules
       to follow.
-    denoise_fn: The denoise function; required to work on batched states and
+    denoise_fn: The denoising function; required to work on batched states and
       noise levels.
-    guidance_fn: An optional guidance function that modifies the denoise
-      funciton in a post-process fashion.
+    guidance_transforms: An optional sequence of guidance transforms that
+      modify the denoising function in a post-process fashion.
     apply_denoise_at_end: Whether to apply the denoise function one more time
       to the terminal state.
   """
 
   input_shape: tuple[int, ...]
   integrator: ode.OdeSolver
   scheme: diffusion.Diffusion
   denoise_fn: DenoiseFn
-  guidance_fn: guidance.Guidance | None = None
+  guidance_transforms: Sequence[guidance.Transform] = ()
   apply_denoise_at_end: bool = True
 
-  def get_guided_denoise_fn(self, guidance_input: Any = None) -> DenoiseFn:
-    denoise_fn = self.denoise_fn
-    if self.guidance_fn is not None:
-      denoise_fn = self.guidance_fn(denoise_fn, guidance_input=guidance_input)
-    return denoise_fn
-
   def generate(
       self,
+      num_samples: int,
       rng: Array,
       tspan: Array,
-      num_samples: int,
-      guidance_input: Any = None,
+      cond: Cond = None,
+      guidance_inputs: Mapping[str, Any] | None = None,
   ) -> tuple[Array, dict[str, Array]]:
-    """Generate samples by solving the sampling ODE."""
+    """Generate samples by solving the sampling ODE.
+
+    Args:
+      num_samples: The number of distinct samples to generate.
+      rng: The base JAX random key used for sampling.
+      tspan: The time steps for integrating the ODE.
+      cond: The (explicit) conditioning inputs, i.e. those to be directly
+        passed through the denoiser interface. These inputs should not come
+        with a batch dimension; one will be created based on the number of
+        samples to generate (by repeating every leaf of the pytree).
+      guidance_inputs: The inputs to the (a posteriori) guidance transforms.
+        They will *not* be passed to the denoiser directly but rather used to
+        construct a new denoising function. These inputs should also not come
+        with a batch dimension; the exact shapes are handled inside the
+        guidance transforms.
+
+    Returns:
+      A tuple of generated samples and auxiliary outputs. The latter currently
+      consists of the entire ODE trajectory.
+    """
     if tspan.ndim != 1:
       raise ValueError("`tspan` must be a 1-d array.")
 
@@ -186,17 +211,24 @@ def generate(
     t0, t1 = tspan[0], tspan[-1]
     x1 = jax.random.normal(rng, x_shape)
     x1 *= self.scheme.sigma(t0) * self.scheme.scale(t0)
-    params = {"guidance_input": guidance_input}
+    if cond is not None:
+      rep_fn = lambda x: jnp.tile(x[None], (num_samples,) + (1,) * x.ndim)
+      cond = jax.tree_map(rep_fn, cond)
+    params = dict(cond=cond, guidance_inputs=guidance_inputs)
     trajectories = self.integrator(self.dynamics, x1, tspan, params)
     samples = trajectories[-1]
     if self.apply_denoise_at_end:
-      denoise_fn = self.get_guided_denoise_fn(
-          guidance_input=params["guidance_input"]
+      denoise_fn = _apply_guidance_transforms(
+          self.denoise_fn,
+          self.guidance_transforms,
+          guidance_inputs=guidance_inputs,
       )
       samples = denoise_fn(
-          jnp.divide(samples, self.scheme.scale(t1)), self.scheme.sigma(t1)
+          jnp.divide(samples, self.scheme.scale(t1)),
+          self.scheme.sigma(t1),
+          cond,
       )
 
     return samples, {"trajectories": trajectories}
@@ -217,14 +249,16 @@ def dynamics(self) -> ode.OdeDynamics:
 
     def _dynamics(x: Array, t: Array, params: PyTree) -> Array:
       assert not t.ndim, "`t` must be a scalar."
-      denoise_fn = self.get_guided_denoise_fn(
-          guidance_input=params["guidance_input"]
+      denoise_fn = _apply_guidance_transforms(
+          self.denoise_fn,
+          self.guidance_transforms,
+          guidance_inputs=params["guidance_inputs"],
      )
       s, sigma = self.scheme.scale(t), self.scheme.sigma(t)
       x_hat = jnp.divide(x, s)
       dlog_sigma_dt = dlog_dt(self.scheme.sigma)(t)
       dlog_s_dt = dlog_dt(self.scheme.scale)(t)
-      target = denoise_fn(x_hat, sigma)
+      target = denoise_fn(x_hat, sigma, params["cond"])
       return (dlog_sigma_dt + dlog_s_dt) * x - dlog_sigma_dt * s * target
 
     return _dynamics
@@ -238,23 +272,37 @@ class SdeSampler:
 
   integrator: sde.SdeSolver
   scheme: diffusion.Diffusion
   denoise_fn: DenoiseFn
-  guidance_fn: guidance.Guidance | None = None
+  guidance_transforms: Sequence[guidance.Transform] = ()
   apply_denoise_at_end: bool = True
 
-  def get_guided_denoise_fn(self, guidance_input: Any = None) -> DenoiseFn:
-    denoise_fn = self.denoise_fn
-    if self.guidance_fn is not None:
-      denoise_fn = self.guidance_fn(denoise_fn, guidance_input=guidance_input)
-    return denoise_fn
-
   def generate(
       self,
+      num_samples: int,
       rng: Array,
       tspan: Array,
-      num_samples: int,
-      guidance_input: Any = None,
+      cond: Cond = None,
+      guidance_inputs: Mapping[str, Any] | None = None,
   ) -> tuple[Array, dict[str, Array]]:
-    """Generate samples by solving an SDE."""
+    """Generate samples by solving an SDE.
+
+    Args:
+      num_samples: The number of distinct samples to generate.
+      rng: The base JAX random key used for sampling.
+      tspan: The time steps for integrating the SDE.
+      cond: The (explicit) conditioning inputs, i.e. those to be directly
+        passed through the denoiser interface. These inputs *should not come
+        with a batch dimension*; one will be created based on the number of
+        samples to generate (by repeating every leaf of the pytree).
+      guidance_inputs: The inputs to the (a posteriori) guidance transforms.
+        They will *not* be passed to the denoiser directly but rather used to
+        construct a new denoising function. Like `cond`, these inputs should
+        in principle not come with a batch dimension, but the shape handling
+        logic may be customized inside the specific guidance transforms.
+
+    Returns:
+      A tuple of generated samples and auxiliary outputs. The latter currently
+      consists of the entire SDE trajectory.
+    """
     if tspan.ndim != 1:
       raise ValueError("`tspan` must be a 1-d array.")
 
@@ -263,15 +311,26 @@ def generate(
     t0, t1 = tspan[0], tspan[-1]
     x1 = jax.random.normal(init_rng, x_shape)
     x1 *= self.scheme.sigma(t0) * self.scheme.scale(t0)
-    params = dict(drift={"guidance_input": guidance_input}, diffusion={})
+    if cond is not None:
+      rep_fn = lambda x: jnp.tile(x[None], (num_samples,) + (1,) * x.ndim)
+      cond = jax.tree_map(rep_fn, cond)
+
+    params = dict(
+        drift=dict(guidance_inputs=guidance_inputs, cond=cond), diffusion={}
+    )
     trajectories = self.integrator(self.dynamics, x1, tspan, solver_rng, params)
     samples = trajectories[-1]
+
     if self.apply_denoise_at_end:
-      denoise_fn = self.get_guided_denoise_fn(
-          guidance_input=params["drift"]["guidance_input"]
+      denoise_fn = _apply_guidance_transforms(
+          self.denoise_fn,
+          self.guidance_transforms,
+          guidance_inputs=guidance_inputs,
      )
       samples = denoise_fn(
-          jnp.divide(samples, self.scheme.scale(t1)), self.scheme.sigma(t1)
+          jnp.divide(samples, self.scheme.scale(t1)),
+          self.scheme.sigma(t1),
+          cond,
      )
 
     return samples, {"trajectories": trajectories}
@@ -298,15 +357,17 @@ def dynamics(self) -> sde.SdeDynamics:
 
     def _drift(x: Array, t: Array, params: PyTree) -> Array:
       assert not t.ndim, "`t` must be a scalar."
-      denoise_fn = self.get_guided_denoise_fn(
-          guidance_input=params["guidance_input"]
+      denoise_fn = _apply_guidance_transforms(
+          self.denoise_fn,
+          self.guidance_transforms,
+          guidance_inputs=params["guidance_inputs"],
      )
       s, sigma = self.scheme.scale(t), self.scheme.sigma(t)
       x_hat = jnp.divide(x, s)
       dlog_sigma_dt = dlog_dt(self.scheme.sigma)(t)
       dlog_s_dt = dlog_dt(self.scheme.scale)(t)
       drift = (2 * dlog_sigma_dt + dlog_s_dt) * x
-      drift -= 2 * dlog_sigma_dt * s * denoise_fn(x_hat, sigma)
+      drift -= 2 * dlog_sigma_dt * s * denoise_fn(x_hat, sigma, params["cond"])
       return drift
 
     def _diffusion(x: Array, t: Array, params: PyTree) -> Array:
diff --git a/swirl_dynamics/lib/diffusion/samplers_test.py b/swirl_dynamics/lib/diffusion/samplers_test.py
index 49a4ef2..22a4188 100644
--- a/swirl_dynamics/lib/diffusion/samplers_test.py
+++ b/swirl_dynamics/lib/diffusion/samplers_test.py
@@ -67,10 +67,10 @@ def test_edm_noise_decay(self):
     np.testing.assert_allclose(tspan, np.asarray(expected_tspan), atol=1e-6)
 
 
-class TestGuidance:
+class TestTransform:
 
-  def __call__(self, denoise_fn, guidance_input):
-    return lambda x, t: denoise_fn(x, t) + guidance_input
+  def __call__(self, denoise_fn, guidance_inputs):
+    return lambda x, t, cond: denoise_fn(x, t, cond) + guidance_inputs["const"]
 
 
 class SamplersTest(parameterized.TestCase):
@@ -93,14 +93,14 @@ def test_ode_sampler_output_shape(
         input_shape=input_shape,
         integrator=solver,
         scheme=scheme,
-        denoise_fn=lambda x, t: x,
+        denoise_fn=lambda x, t, cond: x,
         apply_denoise_at_end=apply_denoise_at_end,
     )
-    generate = jax.jit(sampler.generate, static_argnums=(2,))
+    generate = jax.jit(sampler.generate, static_argnums=0)
     samples, aux = generate(
+        num_samples=num_samples,
         rng=jax.random.PRNGKey(0),
         tspan=samplers.exponential_noise_decay(scheme, num_steps),
-        num_samples=num_samples,
     )
     self.assertEqual(samples.shape, (num_samples,) + input_shape)
     self.assertEqual(
@@ -139,11 +139,11 @@ def test_unet_denoiser(self, sampler, solver):
         scheme=scheme,
         denoise_fn=denoise_fn,
     )
-    generate = jax.jit(sampler.generate, static_argnums=(2,))
+    generate = jax.jit(sampler.generate, static_argnums=0)
     samples, aux = generate(
+        num_samples=num_samples,
         rng=jax.random.PRNGKey(0),
         tspan=samplers.exponential_noise_decay(scheme, num_steps),
-        num_samples=num_samples,
     )
     self.assertEqual(samples.shape, (num_samples,) + input_shape)
     self.assertEqual(
@@ -164,15 +164,40 @@ def test_output_shape_with_guidance(self, sampler, solver):
         input_shape=input_shape,
         integrator=solver,
         scheme=scheme,
-        denoise_fn=lambda x, t: x * t,
-        guidance_fn=TestGuidance(),
+        denoise_fn=lambda x, t, cond: x * t,
+        guidance_transforms=(TestTransform(),),
    )
-    generate = jax.jit(sampler.generate, static_argnums=(2,))
+    generate = jax.jit(sampler.generate, static_argnums=0)
     samples, _ = generate(
+        num_samples=num_samples,
         rng=jax.random.PRNGKey(0),
         tspan=samplers.exponential_noise_decay(scheme, num_steps),
+        guidance_inputs={"const": jnp.ones(input_shape)},
+    )
+    self.assertEqual(samples.shape, (num_samples,) + input_shape)
+
+  @parameterized.parameters(
+      (samplers.OdeSampler, ode.HeunsMethod()),
+      (samplers.SdeSampler, sde.EulerMaruyama()),
+  )
+  def test_output_shape_with_cond(self, sampler, solver):
+    input_shape = (5, 1)
+    num_samples = 4
+    num_steps = 8
+    sigma_schedule = diffusion.tangent_noise_schedule()
+    scheme = diffusion.Diffusion.create_variance_exploding(sigma_schedule)
+    sampler = sampler(
+        input_shape=input_shape,
+        integrator=solver,
+        scheme=scheme,
+        denoise_fn=lambda x, t, cond: x * t + cond["bias"],
+    )
+    generate = jax.jit(sampler.generate, static_argnums=0)
+    samples, _ = generate(
         num_samples=num_samples,
-        guidance_input=jnp.ones(input_shape),
+        rng=jax.random.PRNGKey(0),
+        tspan=samplers.exponential_noise_decay(scheme, num_steps),
+        cond={"bias": jnp.ones(input_shape)},
     )
     self.assertEqual(samples.shape, (num_samples,) + input_shape)