Skip to content

Commit

Permalink
DDIM improvements (#2)
Browse files Browse the repository at this point in the history
* update data pipeline

* fix bugs
  • Loading branch information
AmrMKayid authored May 23, 2024
1 parent 5a03ba5 commit ca0a20a
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 137 deletions.
14 changes: 8 additions & 6 deletions configs/ddim.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@ mesh:

data:
dataset_name: "oxford_flowers102"
batch_size: 8
cache: False
image_size: 32
image_size: [64, 64]
num_channels: 3
batch_size: 64
cache: False


arch:
architecture_name: "ddim"
image_size: [32, 32]
feature_stages: [32, 64]
image_size: [64, 64]
feature_stages: [32, 64, 96, 128]
block_depth: 2
embedding_dim: 32
embedding_max_frequency: 1000.0
diffusion:
diffusion_steps: 128
diffusion_steps: 10


training:
Expand Down
4 changes: 2 additions & 2 deletions fanan/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ class DataConfig(ConfigDict):
def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
super().__init__(initial_dictionary=initial_dictionary, **kwargs)
self.dataset_name: str = "mnist"
self.image_size: list[int] = [64, 64]
self.num_channels: int = 3
self.batch_size: int = 64
self.cache: bool = False
self.image_size: int = 512
self.num_channels: int = 3


class DiffusionConfig(ConfigDict):
Expand Down
137 changes: 58 additions & 79 deletions fanan/data/tf_data.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,67 @@
import logging
from functools import partial
from typing import Any

import jax
import tensorflow as tf
import tensorflow_datasets as tfds

from fanan.config.base import Config


def normalize_to_neg_one_to_one(img):
    """Rescale pixel values from the [0, 1] range to [-1, 1]."""
    doubled = img * 2
    return doubled - 1


def crop_and_resize(image: tf.Tensor, resolution: int = 64) -> tf.Tensor:
    """Center-crop an image to a square, then resize it to `resolution`.

    The crop edge is the shorter of the two spatial dimensions, taken from
    the center of the image; the crop is then bicubically resized with
    antialiasing to (resolution, resolution).
    """
    shape = tf.shape(image)
    height, width = shape[0], shape[1]
    side = tf.minimum(height, width)
    cropped = tf.image.crop_to_bounding_box(
        image=image,
        offset_height=(height - side) // 2,
        offset_width=(width - side) // 2,
        target_height=side,
        target_width=side,
    )
    resized = tf.image.resize(
        cropped,
        size=(resolution, resolution),
        antialias=True,
        method=tf.image.ResizeMethod.BICUBIC,
    )
    # NOTE(review): bicubic resampling can overshoot the input range; the
    # uint8 cast truncates rather than saturates — confirm inputs make this safe.
    return tf.cast(resized, tf.uint8)


def get_dataset_iterator(config: Config, split: str = "train") -> Any:
if config.data.batch_size % jax.device_count() > 0:
raise ValueError(
f"batch size {config.data.batch_size} must be divisible by the number of devices {jax.device_count()}"
from ml_collections.config_dict import ConfigDict

from fanan.config.base import Config, DataConfig
from fanan.utils.image_utils import process_image


class DefaultDataConfig(DataConfig):
    """Data config pre-populated with oxford_flowers102 defaults.

    Defaults are assigned first, then any keys the caller supplied in
    `initial_dictionary` are re-applied so they override the defaults.
    """

    def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
        super().__init__(initial_dictionary=initial_dictionary, **kwargs)
        # Library defaults; overridden below by caller-provided values.
        self.dataset_name: str = "oxford_flowers102"
        self.image_size: list[int] = [64, 64]
        self.num_channels: int = 3
        self.batch_size: int = 64
        self.cache: bool = False
        # Re-apply the caller's dictionary so its values win over the
        # defaults above (ConfigDict(None) is empty, so this is a no-op
        # when no dictionary was given).
        self.update(ConfigDict(initial_dictionary).copy_and_resolve_references())


class Dataset:
    def __init__(self, config: Config):
        """Normalize `config.data` to defaults and eagerly build both iterators."""
        self._config = config
        # Fill any missing data fields with library defaults before use.
        self._config.data = DefaultDataConfig(self._config.data)
        self.train_iter, self.val_iter = self.get_dataset()

def get_dataset(self) -> Any:
# train_iter = self.get_dataset_iterator(split="train")
# val_iter = self.get_dataset_iterator(split="test")
train_iter = self.get_dataset_iterator(split="train[:80%]+validation[:80%]+test[:80%]")
val_iter = self.get_dataset_iterator(split="train[80%:]+validation[80%:]+test[80%:]")
return train_iter, val_iter

def get_dataset_iterator(self, split: str = "train") -> Any:
if self._config.data.batch_size % jax.device_count() > 0:
raise ValueError(
f"batch size {self._config.data.batch_size} must be divisible by the number of devices {jax.device_count()}"
)

batch_size = self._config.data.batch_size // jax.process_count()

platform = jax.local_devices()[0].platform
input_dtype = (
(tf.bfloat16 if platform == "tpu" else tf.float16) if self._config.training.half_precision else tf.float32
)

batch_size = config.data.batch_size // jax.process_count()

platform = jax.local_devices()[0].platform
input_dtype = (tf.bfloat16 if platform == "tpu" else tf.float16) if config.training.half_precision else tf.float32

dataset_builder = tfds.builder(config.data.dataset_name)
dataset_builder.download_and_prepare()

def preprocess_fn(d: dict) -> dict[str, Any]:
image = d.get("image")
image = crop_and_resize(image=image, resolution=config.data.image_size)
# image = tf.image.flip_left_right(image)
image = tf.image.convert_image_dtype(image, input_dtype)
# return {"image": image}
return image

# create split for current process
num_examples = dataset_builder.info.splits[split].num_examples
logging.info(f"Total {split=} examples: {num_examples=}")
split_size = num_examples // jax.process_count()
logging.info(f"Split size: {split_size=}")
start = jax.process_index() * split_size
split = f"{split}[{start}:{start + split_size}]"

ds = dataset_builder.as_dataset(split=split)
options = tf.data.Options()
options.threading.private_threadpool_size = 48
ds.with_options(options)

ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
if config.data.cache:
ds = ds.cache()

ds = ds.repeat()
ds = ds.shuffle(16 * batch_size, seed=config.fanan.seed)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

return iter(tfds.as_numpy(ds))
ds = tfds.load(self._config.data.dataset_name, split=split, shuffle_files=True)
ds = ds.map(
partial(
process_image,
resolution=self._config.data.image_size,
input_dtype=input_dtype,
),
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
if self._config.data.cache:
ds = ds.cache()

ds = ds.repeat()
ds = ds.shuffle(16 * batch_size, seed=self._config.fanan.seed)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

def get_dataset(config: Config) -> Any:
train_ds = get_dataset_iterator(config, split="train")
val_ds = get_dataset_iterator(config, split="test")
return train_ds, val_ds
return iter(tfds.as_numpy(ds))
8 changes: 4 additions & 4 deletions fanan/fanan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fanan.config import Config
from fanan.core.cortex import Cortex
from fanan.data.tf_data import get_dataset
from fanan.data.tf_data import Dataset
from fanan.utils.parser import parse_args

logging.basicConfig(
Expand All @@ -32,12 +32,12 @@ def main() -> None:
config = Config.read_config_from_yaml(args.config_path)
logging.info(f"{config=}")

train_dl, val_dl = get_dataset(config)
dataset = Dataset(config=config)

cortex = Cortex(config)
cortex.train(
train_dataloader_iter=train_dl,
val_dataloader_iter=val_dl,
train_dataloader_iter=dataset.train_iter,
val_dataloader_iter=dataset.val_iter,
)


Expand Down
34 changes: 20 additions & 14 deletions fanan/modeling/architectures/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ def setup(self):
embedding_max_frequency=cfg.embedding_max_frequency,
)

def __call__(self, images, rng, train: bool):
images = self.normalizer(images, use_running_average=not train)
def __call__(self, images, rng, is_training: bool):
images = self.normalizer(images, use_running_average=not is_training)

rng_noises, rng_times = jax.random.split(rng)
noises = jax.random.normal(rng_noises, images.shape, images.dtype)
diffusion_times = jax.random.uniform(rng_times, (images.shape[0], 1, 1, 1), images.dtype)
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
noisy_images = signal_rates * images + noise_rates * noises

pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, train=train)
pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, is_training=is_training)
return noises, images, pred_noises, pred_images

def diffusion_schedule(
Expand All @@ -94,8 +94,8 @@ def diffusion_schedule(

return noise_rates, signal_rates

def denoise(self, noisy_images, noise_rates, signal_rates, train: bool):
pred_noises = self.network(noisy_images, noise_rates**2)
def denoise(self, noisy_images, noise_rates, signal_rates, is_training: bool):
pred_noises = self.network(noisy_images, noise_rates**2, is_training=is_training)
pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
return pred_noises, pred_images

Expand All @@ -113,7 +113,7 @@ def reverse_diffusion(self, initial_noise, diffusion_steps):
ones = jnp.ones((n_images, 1, 1, 1), dtype=initial_noise.dtype)
diffusion_times = ones - step * step_size
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, train=False)
pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, is_training=False)

next_diffusion_times = diffusion_times - step_size
next_noise_rates, next_signal_rates = self.diffusion_schedule(next_diffusion_times)
Expand Down Expand Up @@ -156,8 +156,7 @@ def initialization_input(self):
image_size = self.config.data.image_size
shape = (
self.config.data.batch_size,
image_size,
image_size,
*image_size,
self.config.data.num_channels,
)
return jnp.ones(shape, dtype=jnp.float32)
Expand All @@ -170,7 +169,7 @@ def _create_state(self):
key_init,
self.initialization_input,
key_diffusion,
train=True,
is_training=False,
)

tx, lr_schedule = self._create_optimizer()
Expand All @@ -188,12 +187,15 @@ def _create_state(self):
tabulate_fn = nn.tabulate(
DDIMModel(config=self._config.arch),
key_init,
show_repeated=True,
compute_flops=True,
compute_vjp_flops=True,
)

print(tabulate_fn(self.initialization_input, key_diffusion, False))
print(
tabulate_fn(
images=self.initialization_input,
rng=key_diffusion,
is_training=False,
)
)

return state, lr_schedule

Expand All @@ -213,7 +215,11 @@ def _loss(self, predictions: jnp.ndarray, targets: jnp.ndarray):
def _train_step(self, state, batch, rng):
def loss_fn(params):
outputs, mutated_vars = state.apply_fn(
{"params": params, "batch_stats": state.batch_stats}, batch, rng, train=True, mutable=["batch_stats"]
{"params": params, "batch_stats": state.batch_stats},
batch,
rng,
is_training=True,
mutable=["batch_stats"],
)
noises, images, pred_noises, pred_images = outputs

Expand Down
11 changes: 9 additions & 2 deletions fanan/modeling/modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def setup(self):
start = jnp.log(self.embedding_min_frequency)
stop = jnp.log(self.embedding_max_frequency)
frequencies = jnp.exp(jnp.linspace(start, stop, half_embedding_dim))
self.angular_speeds = 2.0 * jnp.pi * frequencies
self.angular_speeds = (2.0 * jnp.pi * frequencies).astype(self.dtype)

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
positional_embeddings = jnp.concatenate(
Expand All @@ -29,10 +29,17 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
class TimeEmbedding(nn.Module):
time_embedding_dim: int
sinusoidal_embedding_dim: int
sinusoidal_embedding_min_frequency: float = 1.0
sinusoidal_embedding_max_frequency: float = 10_000.0
dtype: jnp.dtype = jnp.float32

def setup(self):
self.positional_embedding = SinusoidalPositionalEmbedding(self.sinusoidal_embedding_dim, dtype=self.dtype)
self.positional_embedding = SinusoidalPositionalEmbedding(
embedding_dim=self.sinusoidal_embedding_dim,
embedding_min_frequencys=self.sinusoidal_embedding_min_frequency,
embedding_max_frequencys=self.sinusoidal_embedding_max_frequency,
dtype=self.dtype,
)
self.dense1 = nn.Dense(self.time_embedding_dim, dtype=self.dtype)
self.dense2 = nn.Dense(self.time_embedding_dim, dtype=self.dtype)

Expand Down
Loading

0 comments on commit ca0a20a

Please sign in to comment.