
Commit

DDIM (#1)
* Initial `DDIM` implementation

* cleanup and improvements

* format

* downgrade python

* update ci

* update

* bring back cache
AmrMKayid authored May 20, 2024
1 parent 01a38e5 commit d93a660
Showing 26 changed files with 2,269 additions and 248 deletions.
28 changes: 0 additions & 28 deletions .github/tests.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -19,7 +19,7 @@ jobs:
         id: setup-python
         uses: actions/setup-python@v4
         with:
-          python-version: 3.11.9
+          python-version: 3.10.12

       - name: Install Poetry
         uses: snok/install-poetry@v1
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
+# fanan
+logs/
+models/
+output/
+data/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,7 +1,7 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.4.2
+    rev: v0.4.4
     hooks:
       # Run the linter.
       - id: ruff
2 changes: 1 addition & 1 deletion README.md
@@ -9,7 +9,7 @@ pipx ensurepath
 pipx install poetry=1.7

 # conda env
-conda create -n fanan python=3.11.9 -y --channel conda-forge
+conda create -n fanan python=3.10.12 -y --channel conda-forge
 conda activate fanan
 poetry install

30 changes: 30 additions & 0 deletions configs/ddim.yaml
@@ -0,0 +1,30 @@
fanan:
  seed: 37
  log_interval: 10


mesh:
  n_data_parallel: 1
  n_fsdp_parallel: 1
  n_sequence_parallel: 1
  n_tensors_parallel: 1


data:
  dataset_name: "oxford_flowers102"
  batch_size: 16
  cache: False
  image_size: 64
  num_channels: 3


arch:
  architecture_name: "ddim"
  diffusion:
    diffusion_steps: 80


training:
  total_steps: 10_000
  eval_every_steps: 100
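The YAML above maps onto the nested pydantic models added to fanan/config/base.py further down this diff. A minimal loading sketch, assuming read_config_from_yaml finishes by instantiating the class from the parsed dict (its body is truncated in the hunk below):

from fanan.config import Config

# Sketch only: assumes read_config_from_yaml ends in cls(**yaml_data),
# which the truncated diff does not show.
cfg = Config.read_config_from_yaml("configs/ddim.yaml")
print(cfg.data.dataset_name)               # "oxford_flowers102"
print(cfg.arch.diffusion.diffusion_steps)  # 80
print(cfg.training.total_steps)            # 10000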

4 changes: 2 additions & 2 deletions fanan/config/__init__.py
@@ -1,3 +1,3 @@
__all__ = ["Config"]
__all__ = ["Config", "_mesh_cfg"]

from fanan.config.base import Config
from fanan.config.base import Config, _mesh_cfg
77 changes: 73 additions & 4 deletions fanan/config/base.py
Expand Up @@ -8,7 +8,7 @@ class BaseConfig(BaseModel):
"""Base configuration class for all configurations."""

@classmethod
def read_config_from_yaml(cls, file_path: str) -> "BaseConfig":
def read_config_from_yaml(cls, file_path: str):
with open(file_path) as file:
yaml_data = yaml.safe_load(file)

@@ -36,19 +36,88 @@ class MeshConfig(BaseConfig):
     )


+class DataConfig(BaseConfig):
+    """data configuration class."""
+
+    dataset_name: str = "mnist"
+    batch_size: int = 64
+    cache: bool = False
+    image_size: int = 512
+    num_channels: int = 3
+
+
+class DiffusionConfig(BaseConfig):
+    """Diffuser configuration class."""
+
+    timesteps: int = 1000
+    beta_1: float = 1e-4
+    beta_T: float = 0.02
+    timestep_size: float = 0.001
+    noise_schedule: str = "linear"
+    ema_decay: float = 0.999
+    ema_update_every: int = 1
+    noise_schedule_kwargs: dict = {}
+    ema_decay_kwargs: dict = {}
+    diffusion_steps: int = 80
+
+
+class ArchitectureConfig(BaseConfig):
+    """Architecture configuration class."""
+
+    architecture_name: str = "ddpm"
+    diffusion: DiffusionConfig = Field(default_factory=DiffusionConfig)
+
+
+class LearningRateConfig(BaseConfig):
+    """Learning rate configuration class."""
+
+    schedule_type: str = "constant_warmup"
+    lr_kwargs: dict = {
+        "value": 1e-3,
+        "warmup_steps": 128,
+    }
+
+
+class OptimizationConfig(BaseConfig):
+    """Optimization configuration class."""
+
+    optimizer_type: str = "adamw"
+    optimizer_kwargs: dict = {
+        "b1": 0.9,
+        "b2": 0.999,
+        "eps": 1e-8,
+    }
+    max_grad_norm: float = 1.0
+    grad_accum_steps: int = 1
+    lr_schedule: LearningRateConfig = Field(default_factory=LearningRateConfig)
+
+
+class TrainingConfig(BaseConfig):
+    """training configuration class."""
+
+    total_steps: int = 100_000
+    eval_every_steps: int = 10
+    loss_type: str = "l1"
+    half_precision: bool = True
+    save_and_sample_every: int = 1000
+    num_sample: int = 64
+
+
 class FananConfig(BaseConfig):
     """fanan configuration class."""

     PRNGKey: int = 0
     seed: int = 37
     timestamp: str = datetime.now().strftime("%Y%m%d_%H%M%S")
-    total_steps: int = 1000
-    log_interval: int = 10
+    log_every_steps: int = 10


 class Config(BaseConfig):
     fanan: FananConfig = Field(default_factory=FananConfig)
     mesh: MeshConfig = Field(default_factory=MeshConfig)
+    data: DataConfig = Field(default_factory=DataConfig)
+    arch: ArchitectureConfig = Field(default_factory=ArchitectureConfig)
+    optimization: OptimizationConfig = Field(default_factory=OptimizationConfig)
+    training: TrainingConfig = Field(default_factory=TrainingConfig)


 _mesh_cfg = MeshConfig()
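Worth noting: every sub-config on Config uses Field(default_factory=...), so each Config() builds fresh sub-model instances instead of sharing one mutable default across all instances. A small sketch of that behavior:

from fanan.config.base import Config, DiffusionConfig

cfg = Config()                                # defaults built by each factory
assert cfg.training.total_steps == 100_000

# Overriding one instance leaves new instances untouched, because
# default_factory constructs a new DiffusionConfig per Config().
cfg.arch.diffusion = DiffusionConfig(diffusion_steps=40)
assert Config().arch.diffusion.diffusion_steps == 80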
Empty file added fanan/core/__init__.py
Empty file.
93 changes: 93 additions & 0 deletions fanan/core/cortex.py
@@ -0,0 +1,93 @@
import logging

import jax
import numpy as np
import tensorflow as tf
from jax.experimental import mesh_utils
from tqdm import tqdm

from fanan.config import Config
from fanan.modeling.architectures import get_architecture


class Cortex:
    """The Cortex class represents the core component of the neural network
    model. It is responsible for initializing the model, training the model,
    and storing the model state.

    Args:
        config (Config): The configuration object containing the model settings.

    Attributes:
        config (Config): The configuration object containing the model settings.
        devices (list): The list of devices used for computation.
        mesh (Mesh): The mesh object representing the distributed computation mesh.
        architecture (Architecture): The architecture object representing the neural network architecture.
        state (TrainState): The train state object representing the current state of the model.

    Methods:
        __init__(self, config: Config) -> None: Initializes the Cortex object.
        initialize_train_state(self) -> None: Initializes the train state of the model.
        train(self, dataset) -> None: Trains the model using the given dataset.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        self.devices = mesh_utils.create_device_mesh(
            devices=jax.devices(),
            mesh_shape=(
                self.config.mesh.n_data_parallel,
                self.config.mesh.n_fsdp_parallel,
                self.config.mesh.n_sequence_parallel,
                self.config.mesh.n_tensors_parallel,
            ),
            contiguous_submeshes=True,
        )
        logging.info(f"{self.devices=}")

        self.mesh = jax.sharding.Mesh(
            devices=self.devices,
            axis_names=self.config.mesh.mesh_axis_names,
        )
        logging.info(f"{self.mesh=}")

        self.architecture = get_architecture(self.config)
        self._writer = tf.summary.create_file_writer("./logs")

    def train(self, train_dataloader_iter, val_dataloader_iter) -> None:
        """Trains the model by iterating over the training dataloader,
        performing a training step per batch, and periodically running an
        eval step whose generated samples are logged to TensorBoard.

        Args:
            train_dataloader_iter: Iterator over training batches.
            val_dataloader_iter: Iterator over validation batches.

        Returns:
            None
        """
        # main loop
        losses = []
        pbar = tqdm(range(self.config.training.total_steps))
        for step in pbar:
            batch = next(train_dataloader_iter)
            loss = self.architecture.train_step(batch=batch)
            losses.append(loss)

            if step % self.config.training.eval_every_steps == 0:
                batch = next(val_dataloader_iter)
                generated_images = self.architecture.eval_step(batch=batch)
                with self._writer.as_default():
                    tf.summary.image("generated", generated_images, step=step, max_outputs=8)

            avg_loss = np.mean(losses)
            pbar.set_postfix(
                {
                    "step_loss": f"{loss:.5f}",
                    "avg_loss": f"{avg_loss:.5f}",
                }
            )

            with self._writer.as_default():
                tf.summary.scalar("loss", avg_loss, step=step)
Empty file added fanan/data/__init__.py
Empty file.
88 changes: 88 additions & 0 deletions fanan/data/tf_data.py
@@ -0,0 +1,88 @@
import logging
from typing import Any

import jax
import tensorflow as tf
import tensorflow_datasets as tfds

from fanan.config.base import Config


def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


def crop_and_resize(image: tf.Tensor, resolution: int = 64) -> tf.Tensor:
    # Center-crop to a square of the shorter side, then resize.
    height, width = tf.shape(image)[0], tf.shape(image)[1]
    crop_size = tf.minimum(height, width)
    # image = image[
    #     (height - crop) // 2 : (height + crop) // 2,
    #     (width - crop) // 2 : (width + crop) // 2,
    # ]
    image = tf.image.crop_to_bounding_box(
        image=image,
        offset_height=(height - crop_size) // 2,
        offset_width=(width - crop_size) // 2,
        target_height=crop_size,
        target_width=crop_size,
    )
    image = tf.image.resize(
        image,
        size=(resolution, resolution),
        antialias=True,
        method=tf.image.ResizeMethod.BICUBIC,
    )
    return tf.cast(image, tf.uint8)


def get_dataset_iterator(config: Config, split: str = "train") -> Any:
    if config.data.batch_size % jax.device_count() > 0:
        raise ValueError(
            f"batch size {config.data.batch_size} must be divisible by the number of devices {jax.device_count()}"
        )

    batch_size = config.data.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform
    input_dtype = (tf.bfloat16 if platform == "tpu" else tf.float16) if config.training.half_precision else tf.float32

    dataset_builder = tfds.builder(config.data.dataset_name)
    dataset_builder.download_and_prepare()

    def preprocess_fn(d: dict) -> tf.Tensor:
        image = d.get("image")
        image = crop_and_resize(image=image, resolution=config.data.image_size)
        # image = tf.image.flip_left_right(image)
        image = tf.image.convert_image_dtype(image, input_dtype)
        # return {"image": image}
        return image

    # create split for current process
    num_examples = dataset_builder.info.splits[split].num_examples
    logging.info(f"Total {split=} examples: {num_examples=}")
    split_size = num_examples // jax.process_count()
    logging.info(f"Split size: {split_size=}")
    start = jax.process_index() * split_size
    split = f"{split}[{start}:{start + split_size}]"

    ds = dataset_builder.as_dataset(split=split)
    options = tf.data.Options()
    options.threading.private_threadpool_size = 48
    # Reassign: with_options returns a new dataset, so the result must be kept.
    ds = ds.with_options(options)

    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if config.data.cache:
        ds = ds.cache()

    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=config.fanan.seed)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return iter(tfds.as_numpy(ds))


def get_dataset(config: Config) -> Any:
    train_ds = get_dataset_iterator(config, split="train")
    val_ds = get_dataset_iterator(config, split="test")
    return train_ds, val_ds
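Two details worth flagging: normalize_to_neg_one_to_one is defined but never called in this file (diffusion models are typically trained on inputs in [-1, 1], so it presumably gets wired in elsewhere), and the iterators yield NumPy batches because tfds.as_numpy wraps the pipeline. A self-contained sanity-check sketch, assuming the configs/ddim.yaml settings on a single host:

from fanan.config import Config
from fanan.data.tf_data import get_dataset

cfg = Config.read_config_from_yaml("configs/ddim.yaml")
train_iter, _ = get_dataset(cfg)

batch = next(train_iter)
# Expected: shape (16, 64, 64, 3); values in [0, 1] after
# convert_image_dtype; float16/bfloat16 when half_precision is on.
print(batch.shape, batch.dtype)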