DDIM improvements #2

Merged · 2 commits · May 23, 2024
Changes from all commits
14 changes: 8 additions & 6 deletions configs/ddim.yaml
@@ -12,19 +12,21 @@ mesh:

data:
  dataset_name: "oxford_flowers102"
  batch_size: 8
  cache: False
  image_size: 32
  image_size: [64, 64]
  num_channels: 3
  batch_size: 64
  cache: False


arch:
  architecture_name: "ddim"
  image_size: [32, 32]
  feature_stages: [32, 64]
  image_size: [64, 64]
  feature_stages: [32, 64, 96, 128]
  block_depth: 2
  embedding_dim: 32
  embedding_max_frequency: 1000.0
diffusion:
  diffusion_steps: 128
  diffusion_steps: 10


training:
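For reference, a minimal sketch of how the updated config could be consumed, assuming `Config.read_config_from_yaml` (used in `fanan/fanan.py` below) and that the YAML sections are exposed as attributes in the usual `ml_collections` style:

```python
from fanan.config import Config

# Hypothetical usage of the updated DDIM config.
config = Config.read_config_from_yaml("configs/ddim.yaml")

# image_size is now a [height, width] pair instead of a single int,
# so non-square resolutions can be expressed.
height, width = config.data.image_size
print(height, width, config.data.num_channels)  # e.g. 64 64 3
print(config.arch.feature_stages)               # e.g. [32, 64, 96, 128]
```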
4 changes: 2 additions & 2 deletions fanan/config/base.py
@@ -33,10 +33,10 @@ class DataConfig(ConfigDict):
    def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
        super().__init__(initial_dictionary=initial_dictionary, **kwargs)
        self.dataset_name: str = "mnist"
        self.image_size: list[int] = [64, 64]
        self.num_channels: int = 3
        self.batch_size: int = 64
        self.cache: bool = False
        self.image_size: int = 512
        self.num_channels: int = 3


class DiffusionConfig(ConfigDict):
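A minimal sketch of the resulting `DataConfig` defaults (values taken from the diff above):

```python
from fanan.config.base import DataConfig

# With no overrides, the new defaults apply: image_size is a [height, width]
# list instead of a single int; the other fields keep their previous values.
cfg = DataConfig()
print(cfg.image_size)    # [64, 64]
print(cfg.num_channels)  # 3
print(cfg.batch_size)    # 64
print(cfg.cache)         # False
```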
137 changes: 58 additions & 79 deletions fanan/data/tf_data.py
@@ -1,88 +1,67 @@
import logging
from functools import partial
from typing import Any

import jax
import tensorflow as tf
import tensorflow_datasets as tfds

from fanan.config.base import Config


def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


def crop_and_resize(image: tf.Tensor, resolution: int = 64) -> tf.Tensor:
    height, width = tf.shape(image)[0], tf.shape(image)[1]
    crop_size = tf.minimum(height, width)
    # image = image[
    #     (height - crop) // 2 : (height + crop) // 2,
    #     (width - crop) // 2 : (width + crop) // 2,
    # ]
    image = tf.image.crop_to_bounding_box(
        image=image,
        offset_height=(height - crop_size) // 2,
        offset_width=(width - crop_size) // 2,
        target_height=crop_size,
        target_width=crop_size,
    )
    image = tf.image.resize(
        image,
        size=(resolution, resolution),
        antialias=True,
        method=tf.image.ResizeMethod.BICUBIC,
    )
    return tf.cast(image, tf.uint8)


def get_dataset_iterator(config: Config, split: str = "train") -> Any:
    if config.data.batch_size % jax.device_count() > 0:
        raise ValueError(
            f"batch size {config.data.batch_size} must be divisible by the number of devices {jax.device_count()}"
from ml_collections.config_dict import ConfigDict

from fanan.config.base import Config, DataConfig
from fanan.utils.image_utils import process_image


class DefaultDataConfig(DataConfig):
    def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
        super().__init__(initial_dictionary=initial_dictionary, **kwargs)
        self.dataset_name: str = "oxford_flowers102"
        self.image_size: list[int] = [64, 64]
        self.num_channels: int = 3
        self.batch_size: int = 64
        self.cache: bool = False
        self.update(ConfigDict(initial_dictionary).copy_and_resolve_references())


class Dataset:
    def __init__(self, config: Config):
        self._config = config
        self._config.data = DefaultDataConfig(self._config.data)
        self.train_iter, self.val_iter = self.get_dataset()

    def get_dataset(self) -> Any:
        # train_iter = self.get_dataset_iterator(split="train")
        # val_iter = self.get_dataset_iterator(split="test")
        train_iter = self.get_dataset_iterator(split="train[:80%]+validation[:80%]+test[:80%]")
        val_iter = self.get_dataset_iterator(split="train[80%:]+validation[80%:]+test[80%:]")
        return train_iter, val_iter

    def get_dataset_iterator(self, split: str = "train") -> Any:
        if self._config.data.batch_size % jax.device_count() > 0:
            raise ValueError(
                f"batch size {self._config.data.batch_size} must be divisible by the number of devices {jax.device_count()}"
            )

        batch_size = self._config.data.batch_size // jax.process_count()

        platform = jax.local_devices()[0].platform
        input_dtype = (
            (tf.bfloat16 if platform == "tpu" else tf.float16) if self._config.training.half_precision else tf.float32
        )

    batch_size = config.data.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform
    input_dtype = (tf.bfloat16 if platform == "tpu" else tf.float16) if config.training.half_precision else tf.float32

    dataset_builder = tfds.builder(config.data.dataset_name)
    dataset_builder.download_and_prepare()

    def preprocess_fn(d: dict) -> dict[str, Any]:
        image = d.get("image")
        image = crop_and_resize(image=image, resolution=config.data.image_size)
        # image = tf.image.flip_left_right(image)
        image = tf.image.convert_image_dtype(image, input_dtype)
        # return {"image": image}
        return image

    # create split for current process
    num_examples = dataset_builder.info.splits[split].num_examples
    logging.info(f"Total {split=} examples: {num_examples=}")
    split_size = num_examples // jax.process_count()
    logging.info(f"Split size: {split_size=}")
    start = jax.process_index() * split_size
    split = f"{split}[{start}:{start + split_size}]"

    ds = dataset_builder.as_dataset(split=split)
    options = tf.data.Options()
    options.threading.private_threadpool_size = 48
    ds.with_options(options)

    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if config.data.cache:
        ds = ds.cache()

    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=config.fanan.seed)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return iter(tfds.as_numpy(ds))
        ds = tfds.load(self._config.data.dataset_name, split=split, shuffle_files=True)
        ds = ds.map(
            partial(
                process_image,
                resolution=self._config.data.image_size,
                input_dtype=input_dtype,
            ),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
        if self._config.data.cache:
            ds = ds.cache()

        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=self._config.fanan.seed)
        ds = ds.batch(batch_size, drop_remainder=True)
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

def get_dataset(config: Config) -> Any:
    train_ds = get_dataset_iterator(config, split="train")
    val_ds = get_dataset_iterator(config, split="test")
    return train_ds, val_ds
        return iter(tfds.as_numpy(ds))
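A minimal usage sketch of the new `Dataset` wrapper, mirroring how `fanan/fanan.py` consumes it below; the batch structure depends on what `process_image` returns, so the shape comment is an assumption based on the old pipeline, which yielded plain image arrays:

```python
import jax
from fanan.config import Config
from fanan.data.tf_data import Dataset

config = Config.read_config_from_yaml("configs/ddim.yaml")
dataset = Dataset(config=config)  # builds infinite train/val numpy iterators

batch = next(dataset.train_iter)
# Assumed shape: (batch_size // process_count, height, width, channels),
# e.g. (64, 64, 64, 3) with the configs/ddim.yaml defaults on one process.
print(jax.tree_util.tree_map(lambda x: x.shape, batch))
```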
8 changes: 4 additions & 4 deletions fanan/fanan.py
@@ -5,7 +5,7 @@

from fanan.config import Config
from fanan.core.cortex import Cortex
from fanan.data.tf_data import get_dataset
from fanan.data.tf_data import Dataset
from fanan.utils.parser import parse_args

logging.basicConfig(
@@ -32,12 +32,12 @@ def main() -> None:
    config = Config.read_config_from_yaml(args.config_path)
    logging.info(f"{config=}")

    train_dl, val_dl = get_dataset(config)
    dataset = Dataset(config=config)

    cortex = Cortex(config)
    cortex.train(
        train_dataloader_iter=train_dl,
        val_dataloader_iter=val_dl,
        train_dataloader_iter=dataset.train_iter,
        val_dataloader_iter=dataset.val_iter,
    )


34 changes: 20 additions & 14 deletions fanan/modeling/architectures/ddim.py
@@ -65,16 +65,16 @@ def setup(self):
            embedding_max_frequency=cfg.embedding_max_frequency,
        )

    def __call__(self, images, rng, train: bool):
        images = self.normalizer(images, use_running_average=not train)
    def __call__(self, images, rng, is_training: bool):
        images = self.normalizer(images, use_running_average=not is_training)

        rng_noises, rng_times = jax.random.split(rng)
        noises = jax.random.normal(rng_noises, images.shape, images.dtype)
        diffusion_times = jax.random.uniform(rng_times, (images.shape[0], 1, 1, 1), images.dtype)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy_images = signal_rates * images + noise_rates * noises

        pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, train=train)
        pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, is_training=is_training)
        return noises, images, pred_noises, pred_images

    def diffusion_schedule(
@@ -94,8 +94,8 @@ def diffusion_schedule(

        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, train: bool):
        pred_noises = self.network(noisy_images, noise_rates**2)
    def denoise(self, noisy_images, noise_rates, signal_rates, is_training: bool):
        pred_noises = self.network(noisy_images, noise_rates**2, is_training=is_training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
        return pred_noises, pred_images

@@ -113,7 +113,7 @@ def reverse_diffusion(self, initial_noise, diffusion_steps):
            ones = jnp.ones((n_images, 1, 1, 1), dtype=initial_noise.dtype)
            diffusion_times = ones - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, train=False)
            pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, is_training=False)

            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(next_diffusion_times)
@@ -156,8 +156,7 @@ def initialization_input(self):
        image_size = self.config.data.image_size
        shape = (
            self.config.data.batch_size,
            image_size,
            image_size,
            *image_size,
            self.config.data.num_channels,
        )
        return jnp.ones(shape, dtype=jnp.float32)
@@ -170,7 +169,7 @@ def _create_state(self):
            key_init,
            self.initialization_input,
            key_diffusion,
            train=True,
            is_training=False,
        )

        tx, lr_schedule = self._create_optimizer()
@@ -188,12 +187,15 @@ def _create_state(self):
        tabulate_fn = nn.tabulate(
            DDIMModel(config=self._config.arch),
            key_init,
            show_repeated=True,
            compute_flops=True,
            compute_vjp_flops=True,
        )

        print(tabulate_fn(self.initialization_input, key_diffusion, False))
        print(
            tabulate_fn(
                images=self.initialization_input,
                rng=key_diffusion,
                is_training=False,
            )
        )

        return state, lr_schedule

@@ -213,7 +215,11 @@ def _loss(self, predictions: jnp.ndarray, targets: jnp.ndarray):
    def _train_step(self, state, batch, rng):
        def loss_fn(params):
            outputs, mutated_vars = state.apply_fn(
                {"params": params, "batch_stats": state.batch_stats}, batch, rng, train=True, mutable=["batch_stats"]
                {"params": params, "batch_stats": state.batch_stats},
                batch,
                rng,
                is_training=True,
                mutable=["batch_stats"],
            )
            noises, images, pred_noises, pred_images = outputs

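For context, a sketch of the sampling update that the `reverse_diffusion` fragment above is driving. The `pred_images` expression matches `denoise` above; the final remixing line is not shown in the diff and is the standard deterministic DDIM step, so treat it as an assumption:

```python
import jax.numpy as jnp

def ddim_step(noisy_images, noise_rates, signal_rates,
              next_noise_rates, next_signal_rates, pred_noises):
    """One hypothetical reverse-diffusion update (eta = 0)."""
    # Estimate the clean image from the current noisy sample
    # (same expression as DDIMModel.denoise).
    pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
    # Re-noise the estimate at the next, lower noise level.
    next_noisy_images = next_signal_rates * pred_images + next_noise_rates * pred_noises
    return next_noisy_images, pred_images
```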
11 changes: 9 additions & 2 deletions fanan/modeling/modules/embedding.py
@@ -13,7 +13,7 @@ def setup(self):
        start = jnp.log(self.embedding_min_frequency)
        stop = jnp.log(self.embedding_max_frequency)
        frequencies = jnp.exp(jnp.linspace(start, stop, half_embedding_dim))
        self.angular_speeds = 2.0 * jnp.pi * frequencies
        self.angular_speeds = (2.0 * jnp.pi * frequencies).astype(self.dtype)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        positional_embeddings = jnp.concatenate(
@@ -29,10 +29,17 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
class TimeEmbedding(nn.Module):
    time_embedding_dim: int
    sinusoidal_embedding_dim: int
    sinusoidal_embedding_min_frequency: float = 1.0
    sinusoidal_embedding_max_frequency: float = 10_000.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.positional_embedding = SinusoidalPositionalEmbedding(self.sinusoidal_embedding_dim, dtype=self.dtype)
        self.positional_embedding = SinusoidalPositionalEmbedding(
            embedding_dim=self.sinusoidal_embedding_dim,
            embedding_min_frequency=self.sinusoidal_embedding_min_frequency,
            embedding_max_frequency=self.sinusoidal_embedding_max_frequency,
            dtype=self.dtype,
        )
        self.dense1 = nn.Dense(self.time_embedding_dim, dtype=self.dtype)
        self.dense2 = nn.Dense(self.time_embedding_dim, dtype=self.dtype)

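A standalone sketch of the embedding this module computes, using the log-spaced frequencies from `setup` and the new dtype cast; the sin/cos concatenation layout in `__call__` is not fully shown in the diff, so it is an assumption here:

```python
import jax.numpy as jnp

def sinusoidal_embedding(x, embedding_dim=32, min_frequency=1.0,
                         max_frequency=1000.0, dtype=jnp.float32):
    # Frequencies spaced evenly in log space, as in
    # SinusoidalPositionalEmbedding.setup above.
    half_dim = embedding_dim // 2
    frequencies = jnp.exp(jnp.linspace(jnp.log(min_frequency),
                                       jnp.log(max_frequency), half_dim))
    angular_speeds = (2.0 * jnp.pi * frequencies).astype(dtype)
    # Assumed layout: sin features followed by cos features on the last axis.
    return jnp.concatenate([jnp.sin(angular_speeds * x),
                            jnp.cos(angular_speeds * x)], axis=-1)

# e.g. embedding diffusion times shaped (batch, 1, 1, 1) -> (batch, 1, 1, 32)
emb = sinusoidal_embedding(jnp.ones((8, 1, 1, 1)))
```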