Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583075541
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Nov 16, 2023
1 parent a3399f0 commit 61afc04
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 84 deletions.
73 changes: 49 additions & 24 deletions swirl_dynamics/lib/diffusion/guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,64 +12,89 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Modules for a posteriori (post-processing) guidance."""
"""Modules for guidance transforms for denoising functions."""

from collections.abc import Callable
from typing import Protocol
from collections.abc import Callable, Mapping
from typing import Any, Protocol

import flax
import jax
import jax.numpy as jnp

Array = jax.Array
DenoiseFn = Callable[[Array, Array], Array]
PyTree = Any
Cond = Mapping[str, PyTree] | None
DenoiseFn = Callable[[Array, Array, Cond], Array]


class Guidance(Protocol):
class Transform(Protocol):
"""Transforms a denoising function to follow some guidance.
One may think of these transforms as instances of Python decorators,
specifically made for denoising functions. Each transform takes a base
denoising function and extends it (often using some additional data) to build
a new denoising function with the same interface.
"""

def __call__(
self, denoise_fn: DenoiseFn, guidance_input: Array | None
self, denoise_fn: DenoiseFn, guidance_inputs: Mapping[str, Array]
) -> DenoiseFn:
"""Constructs guided denoise function."""
"""Constructs a guided denoising function.
Args:
denoise_fn: The base denoising function.
guidance_inputs: A dictionary containing inputs used to construct the
guided denoising function. Note that all transforms *share the same
input dict*, therefore all transforms should use different fields from
this dict (unless absolutely intended) to avoid potential name clashes.
Returns:
The guided denoising function.
"""
...


@flax.struct.dataclass
class InfillFromSlices:
"""N-dimensional infilling guided by slices.
"""N-dimensional infilling guided by known values on slices.
Example usage::
# 2D infill given every 8th pixel along both dimensions
# (assuming that the lead dimension is for batch)
# 2D infill given every 8th pixel along both dimensions (assuming that the
# lead dimension is for batch).
slices = tuple(slice(None), slice(None, None, 8), slice(None, None, 8))
sr_guidance = InfillFromSlices(slices)
sr_guidance = InfillFromSlices(slices, guide_strength=0.1)
# Post process a trained denoiser function via function composition
# `guidance_input` must have compatible shape s.t.
# `image[slices] = guidance_input` would not result in errors
guided_denoiser = sr_guidance(denoiser, guidance_input=jnp.array(0.0))
# Post-process a trained denoiser function via function composition.
# The `observed_slices` arg must have compatible shape such that
# `image[slices] = observed_slices` would not raise errors.
guided_denoiser = sr_guidance(denoiser, {"observed_slices": jnp.array(0.0)})
# Run guided denoiser the same way as a normal one
denoised = guided_denoiser(noised, sigma=jnp.array(0.1))
denoised = guided_denoiser(noised, sigma=jnp.array(0.1), cond=None)
Attributes:
slices: The slices of the input to guide denoising (i.e. the rest is
infilled). The `guidance_input` provided when calling this method must be
compatible with these slices.
slices: The slices of the input to guide denoising (i.e. the rest is being
infilled).
guide_strength: The strength of the guidance relative to the raw denoiser.
It will be rescaled based on the fraction of values being conditioned.
"""

slices: tuple[slice, ...]
guide_strength: float = 0.5

def __call__(self, denoise_fn: DenoiseFn, guidance_input: Array) -> DenoiseFn:
def __call__(
self, denoise_fn: DenoiseFn, guidance_inputs: Mapping[str, Array]
) -> DenoiseFn:
"""Constructs denoise function guided by values on specified slices."""

def _guided_denoise(x: Array, sigma: Array) -> Array:
def _guided_denoise(x: Array, sigma: Array, cond: Cond = None) -> Array:
def constraint(xt: Array) -> tuple[Array, Array]:
denoised = denoise_fn(xt, sigma)
return jnp.sum((denoised[self.slices] - guidance_input) ** 2), denoised
denoised = denoise_fn(xt, sigma, cond)
error = jnp.sum(
(denoised[self.slices] - guidance_inputs["observed_slices"]) ** 2
)
return error, denoised

constraint_grad, denoised = jax.grad(constraint, has_aux=True)(x)
# normalize wrt the fraction of values being conditioned
Expand All @@ -78,6 +103,6 @@ def constraint(xt: Array) -> tuple[Array, Array]:
)
guide_strength = self.guide_strength / cond_fraction
denoised -= guide_strength * constraint_grad
return denoised.at[self.slices].set(guidance_input)
return denoised.at[self.slices].set(guidance_inputs["observed_slices"])

return _guided_denoise
14 changes: 9 additions & 5 deletions swirl_dynamics/lib/diffusion/guidance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from swirl_dynamics.lib.diffusion import guidance


class GuidanceTest(parameterized.TestCase):
class GuidanceTransformsTest(parameterized.TestCase):

@parameterized.parameters(
{"test_dim": (4, 16), "ds_ratios": (None, 8), "guide_shape": (4, 2)},
Expand All @@ -31,13 +31,17 @@ class GuidanceTest(parameterized.TestCase):
)
def test_super_resolution(self, test_dim, ds_ratios, guide_shape):
superresolve = guidance.InfillFromSlices(
slices=tuple(slice(None, None, r) for r in ds_ratios)
slices=tuple(slice(None, None, r) for r in ds_ratios),
)
dummy_denoiser = lambda x, sigma: jnp.ones_like(x)

def _dummy_denoiser(x, sigma, cond=None):
del sigma, cond
return jnp.ones_like(x)

guided_denoiser = superresolve(
dummy_denoiser, guidance_input=jnp.array(0.0)
_dummy_denoiser, {"observed_slices": jnp.array(0.0)}
)
denoised = guided_denoiser(jnp.ones(test_dim), jnp.array(0.1))
denoised = guided_denoiser(jnp.ones(test_dim), jnp.array(0.1), None)
guided_elements = denoised[superresolve.slices]
self.assertEqual(denoised.shape, test_dim)
self.assertEqual(guided_elements.shape, guide_shape)
Expand Down
Loading

0 comments on commit 61afc04

Please sign in to comment.