From ca0a20a8f84b1988158e92331d5b071af1962e38 Mon Sep 17 00:00:00 2001 From: Amr Kayid Date: Thu, 23 May 2024 17:40:41 -0400 Subject: [PATCH] DDIM improvements (#2) * update data pipeline * fix bugs --- configs/ddim.yaml | 14 +-- fanan/config/base.py | 4 +- fanan/data/tf_data.py | 137 ++++++++++++--------------- fanan/fanan.py | 8 +- fanan/modeling/architectures/ddim.py | 34 ++++--- fanan/modeling/modules/embedding.py | 11 ++- fanan/modeling/modules/unet.py | 78 +++++++++------ fanan/utils/image_utils.py | 56 +++++++++++ 8 files changed, 205 insertions(+), 137 deletions(-) create mode 100644 fanan/utils/image_utils.py diff --git a/configs/ddim.yaml b/configs/ddim.yaml index 8c22707..f97578c 100644 --- a/configs/ddim.yaml +++ b/configs/ddim.yaml @@ -12,19 +12,21 @@ mesh: data: dataset_name: "oxford_flowers102" - batch_size: 8 - cache: False - image_size: 32 + image_size: [64, 64] num_channels: 3 + batch_size: 64 + cache: False arch: architecture_name: "ddim" - image_size: [32, 32] - feature_stages: [32, 64] + image_size: [64, 64] + feature_stages: [32, 64, 96, 128] block_depth: 2 + embedding_dim: 32 + embedding_max_frequency: 1000.0 diffusion: - diffusion_steps: 128 + diffusion_steps: 10 training: diff --git a/fanan/config/base.py b/fanan/config/base.py index 4816beb..8ca91a7 100644 --- a/fanan/config/base.py +++ b/fanan/config/base.py @@ -33,10 +33,10 @@ class DataConfig(ConfigDict): def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None: super().__init__(initial_dictionary=initial_dictionary, **kwargs) self.dataset_name: str = "mnist" + self.image_size: list[int] = [64, 64] + self.num_channels: int = 3 self.batch_size: int = 64 self.cache: bool = False - self.image_size: int = 512 - self.num_channels: int = 3 class DiffusionConfig(ConfigDict): diff --git a/fanan/data/tf_data.py b/fanan/data/tf_data.py index ec59d0c..aa59319 100644 --- a/fanan/data/tf_data.py +++ b/fanan/data/tf_data.py @@ -1,88 +1,67 @@ -import logging +from functools import 
partial from typing import Any import jax import tensorflow as tf import tensorflow_datasets as tfds - -from fanan.config.base import Config - - -def normalize_to_neg_one_to_one(img): - return img * 2 - 1 - - -def crop_and_resize(image: tf.Tensor, resolution: int = 64) -> tf.Tensor: - height, width = tf.shape(image)[0], tf.shape(image)[1] - crop_size = tf.minimum(height, width) - # image = image[ - # (height - crop) // 2 : (height + crop) // 2, - # (width - crop) // 2 : (width + crop) // 2, - # ] - image = tf.image.crop_to_bounding_box( - image=image, - offset_height=(height - crop_size) // 2, - offset_width=(width - crop_size) // 2, - target_height=crop_size, - target_width=crop_size, - ) - image = tf.image.resize( - image, - size=(resolution, resolution), - antialias=True, - method=tf.image.ResizeMethod.BICUBIC, - ) - return tf.cast(image, tf.uint8) - - -def get_dataset_iterator(config: Config, split: str = "train") -> Any: - if config.data.batch_size % jax.device_count() > 0: - raise ValueError( - f"batch size {config.data.batch_size} must be divisible by the number of devices {jax.device_count()}" +from ml_collections.config_dict import ConfigDict + +from fanan.config.base import Config, DataConfig +from fanan.utils.image_utils import process_image + + +class DefaultDataConfig(DataConfig): + def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None: + super().__init__(initial_dictionary=initial_dictionary, **kwargs) + self.dataset_name: str = "oxford_flowers102" + self.image_size: list[int] = [64, 64] + self.num_channels: int = 3 + self.batch_size: int = 64 + self.cache: bool = False + self.update(ConfigDict(initial_dictionary).copy_and_resolve_references()) + + +class Dataset: + def __init__(self, config: Config): + self._config = config + self._config.data = DefaultDataConfig(self._config.data) + self.train_iter, self.val_iter = self.get_dataset() + + def get_dataset(self) -> Any: + # train_iter = self.get_dataset_iterator(split="train") + 
# val_iter = self.get_dataset_iterator(split="test") + train_iter = self.get_dataset_iterator(split="train[:80%]+validation[:80%]+test[:80%]") + val_iter = self.get_dataset_iterator(split="train[80%:]+validation[80%:]+test[80%:]") + return train_iter, val_iter + + def get_dataset_iterator(self, split: str = "train") -> Any: + if self._config.data.batch_size % jax.device_count() > 0: + raise ValueError( + f"batch size {self._config.data.batch_size} must be divisible by the number of devices {jax.device_count()}" + ) + + batch_size = self._config.data.batch_size // jax.process_count() + + platform = jax.local_devices()[0].platform + input_dtype = ( + (tf.bfloat16 if platform == "tpu" else tf.float16) if self._config.training.half_precision else tf.float32 ) - batch_size = config.data.batch_size // jax.process_count() - - platform = jax.local_devices()[0].platform - input_dtype = (tf.bfloat16 if platform == "tpu" else tf.float16) if config.training.half_precision else tf.float32 - - dataset_builder = tfds.builder(config.data.dataset_name) - dataset_builder.download_and_prepare() - - def preprocess_fn(d: dict) -> dict[str, Any]: - image = d.get("image") - image = crop_and_resize(image=image, resolution=config.data.image_size) - # image = tf.image.flip_left_right(image) - image = tf.image.convert_image_dtype(image, input_dtype) - # return {"image": image} - return image - - # create split for current process - num_examples = dataset_builder.info.splits[split].num_examples - logging.info(f"Total {split=} examples: {num_examples=}") - split_size = num_examples // jax.process_count() - logging.info(f"Split size: {split_size=}") - start = jax.process_index() * split_size - split = f"{split}[{start}:{start + split_size}]" - - ds = dataset_builder.as_dataset(split=split) - options = tf.data.Options() - options.threading.private_threadpool_size = 48 - ds.with_options(options) - - ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) - if 
config.data.cache: - ds = ds.cache() - - ds = ds.repeat() - ds = ds.shuffle(16 * batch_size, seed=config.fanan.seed) - ds = ds.batch(batch_size, drop_remainder=True) - ds = ds.prefetch(tf.data.experimental.AUTOTUNE) - - return iter(tfds.as_numpy(ds)) + ds = tfds.load(self._config.data.dataset_name, split=split, shuffle_files=True) + ds = ds.map( + partial( + process_image, + resolution=self._config.data.image_size, + input_dtype=input_dtype, + ), + num_parallel_calls=tf.data.experimental.AUTOTUNE, + ) + if self._config.data.cache: + ds = ds.cache() + ds = ds.repeat() + ds = ds.shuffle(16 * batch_size, seed=self._config.fanan.seed) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.prefetch(tf.data.experimental.AUTOTUNE) -def get_dataset(config: Config) -> Any: - train_ds = get_dataset_iterator(config, split="train") - val_ds = get_dataset_iterator(config, split="test") - return train_ds, val_ds + return iter(tfds.as_numpy(ds)) diff --git a/fanan/fanan.py b/fanan/fanan.py index d3d5730..4496975 100644 --- a/fanan/fanan.py +++ b/fanan/fanan.py @@ -5,7 +5,7 @@ from fanan.config import Config from fanan.core.cortex import Cortex -from fanan.data.tf_data import get_dataset +from fanan.data.tf_data import Dataset from fanan.utils.parser import parse_args logging.basicConfig( @@ -32,12 +32,12 @@ def main() -> None: config = Config.read_config_from_yaml(args.config_path) logging.info(f"{config=}") - train_dl, val_dl = get_dataset(config) + dataset = Dataset(config=config) cortex = Cortex(config) cortex.train( - train_dataloader_iter=train_dl, - val_dataloader_iter=val_dl, + train_dataloader_iter=dataset.train_iter, + val_dataloader_iter=dataset.val_iter, ) diff --git a/fanan/modeling/architectures/ddim.py b/fanan/modeling/architectures/ddim.py index c2f5020..5ba00f8 100644 --- a/fanan/modeling/architectures/ddim.py +++ b/fanan/modeling/architectures/ddim.py @@ -65,8 +65,8 @@ def setup(self): embedding_max_frequency=cfg.embedding_max_frequency, ) - def 
__call__(self, images, rng, train: bool): - images = self.normalizer(images, use_running_average=not train) + def __call__(self, images, rng, is_training: bool): + images = self.normalizer(images, use_running_average=not is_training) rng_noises, rng_times = jax.random.split(rng) noises = jax.random.normal(rng_noises, images.shape, images.dtype) @@ -74,7 +74,7 @@ def __call__(self, images, rng, train: bool): noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) noisy_images = signal_rates * images + noise_rates * noises - pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, train=train) + pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, is_training=is_training) return noises, images, pred_noises, pred_images def diffusion_schedule( @@ -94,8 +94,8 @@ def diffusion_schedule( return noise_rates, signal_rates - def denoise(self, noisy_images, noise_rates, signal_rates, train: bool): - pred_noises = self.network(noisy_images, noise_rates**2) + def denoise(self, noisy_images, noise_rates, signal_rates, is_training: bool): + pred_noises = self.network(noisy_images, noise_rates**2, is_training=is_training) pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates return pred_noises, pred_images @@ -113,7 +113,7 @@ def reverse_diffusion(self, initial_noise, diffusion_steps): ones = jnp.ones((n_images, 1, 1, 1), dtype=initial_noise.dtype) diffusion_times = ones - step * step_size noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) - pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, train=False) + pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, is_training=False) next_diffusion_times = diffusion_times - step_size next_noise_rates, next_signal_rates = self.diffusion_schedule(next_diffusion_times) @@ -156,8 +156,7 @@ def initialization_input(self): image_size = self.config.data.image_size shape = ( 
self.config.data.batch_size, - image_size, - image_size, + *image_size, self.config.data.num_channels, ) return jnp.ones(shape, dtype=jnp.float32) @@ -170,7 +169,7 @@ def _create_state(self): key_init, self.initialization_input, key_diffusion, - train=True, + is_training=False, ) tx, lr_schedule = self._create_optimizer() @@ -188,12 +187,15 @@ def _create_state(self): tabulate_fn = nn.tabulate( DDIMModel(config=self._config.arch), key_init, - show_repeated=True, - compute_flops=True, - compute_vjp_flops=True, ) - print(tabulate_fn(self.initialization_input, key_diffusion, False)) + print( + tabulate_fn( + images=self.initialization_input, + rng=key_diffusion, + is_training=False, + ) + ) return state, lr_schedule @@ -213,7 +215,11 @@ def _loss(self, predictions: jnp.ndarray, targets: jnp.ndarray): def _train_step(self, state, batch, rng): def loss_fn(params): outputs, mutated_vars = state.apply_fn( - {"params": params, "batch_stats": state.batch_stats}, batch, rng, train=True, mutable=["batch_stats"] + {"params": params, "batch_stats": state.batch_stats}, + batch, + rng, + is_training=True, + mutable=["batch_stats"], ) noises, images, pred_noises, pred_images = outputs diff --git a/fanan/modeling/modules/embedding.py b/fanan/modeling/modules/embedding.py index 307a076..81d083f 100644 --- a/fanan/modeling/modules/embedding.py +++ b/fanan/modeling/modules/embedding.py @@ -13,7 +13,7 @@ def setup(self): start = jnp.log(self.embedding_min_frequency) stop = jnp.log(self.embedding_max_frequency) frequencies = jnp.exp(jnp.linspace(start, stop, half_embedding_dim)) - self.angular_speeds = 2.0 * jnp.pi * frequencies + self.angular_speeds = (2.0 * jnp.pi * frequencies).astype(self.dtype) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: positional_embeddings = jnp.concatenate( @@ -29,10 +29,17 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: class TimeEmbedding(nn.Module): time_embedding_dim: int sinusoidal_embedding_dim: int + sinusoidal_embedding_min_frequency: float 
= 1.0 + sinusoidal_embedding_max_frequency: float = 10_000.0 dtype: jnp.dtype = jnp.float32 def setup(self): - self.positional_embedding = SinusoidalPositionalEmbedding(self.sinusoidal_embedding_dim, dtype=self.dtype) + self.positional_embedding = SinusoidalPositionalEmbedding( + embedding_dim=self.sinusoidal_embedding_dim, + embedding_min_frequency=self.sinusoidal_embedding_min_frequency, + embedding_max_frequency=self.sinusoidal_embedding_max_frequency, + dtype=self.dtype, + ) self.dense1 = nn.Dense(self.time_embedding_dim, dtype=self.dtype) self.dense2 = nn.Dense(self.time_embedding_dim, dtype=self.dtype) diff --git a/fanan/modeling/modules/unet.py b/fanan/modeling/modules/unet.py index 0976165..4743f9b 100644 --- a/fanan/modeling/modules/unet.py +++ b/fanan/modeling/modules/unet.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple +from typing import Any, Tuple import flax.linen as nn import jax @@ -9,34 +9,42 @@ class UNetResidualBlock(nn.Module): output_channels_width: int - num_groups: Optional[int] = 8 dtype: Any = jnp.float32 def setup(self): - self.conv1 = nn.Conv(self.output_channels_width, kernel_size=(1, 1), name="conv1") - self.conv2 = nn.Conv(self.output_channels_width, kernel_size=(3, 3), padding="SAME", name="conv2") - self.conv3 = nn.Conv(self.output_channels_width, kernel_size=(3, 3), padding="SAME", name="conv3") - self.group_norm = nn.GroupNorm( - num_groups=self.num_groups, - epsilon=1e-5, - use_bias=False, - use_scale=False, - dtype=self.dtype, + self.conv1 = nn.Conv(features=self.output_channels_width, kernel_size=(1, 1), name="conv1") + self.bn = nn.BatchNorm(use_bias=False, use_scale=False) + self.conv2 = nn.Conv( + features=self.output_channels_width, + kernel_size=(3, 3), + padding="SAME", + name="conv2", + ) + self.conv3 = nn.Conv( + features=self.output_channels_width, + kernel_size=(3, 3), + padding="SAME", + name="conv3", ) @nn.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - input_width = x.shape[-1] - 
residual = self.conv1(x) if input_width != self.output_channels_width else x + def __call__( + self, + x: jnp.ndarray, + is_training: bool, + ) -> jnp.ndarray: + input_width = x.shape[3] + residual = x if input_width == self.output_channels_width else self.conv1(x) - x = self.group_norm(x) - x = nn.swish(x) + x = self.bn( + x, + use_running_average=not is_training, + ) x = self.conv2(x) x = nn.swish(x) x = self.conv3(x) - x = x + residual + x += residual return x @@ -54,9 +62,15 @@ def setup(self): for i in range(self.block_depth) ] - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + def __call__( + self, + x: jnp.ndarray, + skips: list[jnp.ndarray], + is_training: bool, + ) -> tuple[jnp.ndarray, list[jnp.ndarray]]: for block in self.residual_blocks: - x = block(x) + x = block(x, is_training=is_training) + skips.append(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) return x @@ -80,11 +94,16 @@ def upsample2d(self, x: jnp.ndarray, scale: int = 2) -> jnp.ndarray: x = jax.image.resize(x, shape=upsampled_shape, method="bilinear") return x - def __call__(self, x: jnp.ndarray, skip: jnp.ndarray) -> jnp.ndarray: + def __call__( + self, + x: jnp.ndarray, + skips: list[jnp.ndarray], + is_training: bool, + ) -> jnp.ndarray: x = self.upsample2d(x) - x = jnp.concatenate([x, skip], axis=-1) for block in self.residual_blocks: - x = block(x) + x = jnp.concatenate([x, skips.pop()], axis=-1) + x = block(x, is_training=is_training) return x @@ -117,13 +136,13 @@ def __call__( self, noisy_images: jnp.ndarray, noise_variances: jnp.ndarray, + is_training: bool = True, ) -> jnp.ndarray: embedding = self.sinusoidal_embedding(noise_variances) # TODO: util function for this? 
upsampled_shape = ( noisy_images.shape[0], - self.image_size[0], - self.image_size[1], + *self.image_size, self.embedding_dim, ) embedding = jax.image.resize(embedding, upsampled_shape, method="nearest") @@ -133,14 +152,13 @@ def __call__( skips = [] for block in self.down_blocks: - skips.append(x) - x = block(x) + x = block(x, skips, is_training) for block in self.residual_blocks: - x = block(x) + x = block(x, is_training) - for block, skip in zip(self.up_blocks, reversed(skips)): - x = block(x, skip) + for block in self.up_blocks: + x = block(x, skips, is_training) outputs = self.conv2(x) return outputs diff --git a/fanan/utils/image_utils.py b/fanan/utils/image_utils.py new file mode 100644 index 0000000..8dbf3ff --- /dev/null +++ b/fanan/utils/image_utils.py @@ -0,0 +1,56 @@ +from typing import Tuple, Union + +import jax +import tensorflow as tf + + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + + +def crop_and_resize(image: tf.Tensor, resolution: tuple[int, int] = (64, 64)) -> tf.Tensor: + height, width = tf.shape(image)[0], tf.shape(image)[1] + crop_size = tf.minimum(height, width) + # image = image[ + # (height - crop) // 2 : (height + crop) // 2, + # (width - crop) // 2 : (width + crop) // 2, + # ] + image = tf.image.crop_to_bounding_box( + image=image, + offset_height=(height - crop_size) // 2, + offset_width=(width - crop_size) // 2, + target_height=crop_size, + target_width=crop_size, + ) + image = tf.image.resize( + image, + size=resolution, + antialias=True, + method=tf.image.ResizeMethod.BICUBIC, + ) + return tf.clip_by_value(image / 255.0, 0.0, 1.0) + + +def process_image( + data: dict[str, tf.Tensor], + resolution: list[int], + input_dtype: tf.DType = tf.float32, +) -> tf.Tensor: + image = data.get("image") + image = crop_and_resize(image=image, resolution=resolution) + # image = normalize_to_neg_one_to_one(image) + image = tf.image.convert_image_dtype(image, input_dtype) + return image + + +def upsample2d(x, scale: Union[int, 
Tuple[int, int]], method: str = "bilinear"): + b, h, w, c = x.shape + + if isinstance(scale, int): + h_out, w_out = scale * h, scale * w + elif len(scale) == 2: + h_out, w_out = scale[0] * h, scale[1] * w + else: + raise ValueError("scale argument should be either int " "or Tuple[int, int]") + + return jax.image.resize(x, shape=(b, h_out, w_out, c), method=method)