Skip to content

Commit

Permalink
QoL improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
AmrMKayid committed May 25, 2024
1 parent fa31440 commit 32b9447
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 55 deletions.
17 changes: 15 additions & 2 deletions configs/ddim.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,24 @@ arch:
embedding_dim: 32
embedding_max_frequency: 1000.0
diffusion:
diffusion_steps: 10
diffusion_steps: 80

optimization:
optimizer_type: "adamw"
optimizer_kwargs:
b1: 0.9
b2: 0.999
eps: 1.0e-8
weight_decay: 1.0e-4
lr_schedule:
schedule_type: "constant_warmup"
lr_kwargs:
value: 1.0e-3
warmup_steps: 128


training:
total_steps: 10_000
eval_every_steps: 100
eval_every_steps: 1000


3 changes: 2 additions & 1 deletion fanan/data/tf_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def get_dataset(self) -> Any:
def get_dataset_iterator(self, split: str = "train") -> Any:
if self._config.data.batch_size % jax.device_count() > 0:
raise ValueError(
f"batch size {self._config.data.batch_size} must be divisible by the number of devices {jax.device_count()}"
f"batch size {self._config.data.batch_size} must be divisible "
f"by the number of devices {jax.device_count()}"
)

batch_size = self._config.data.batch_size // jax.process_count()
Expand Down
61 changes: 13 additions & 48 deletions fanan/modeling/architectures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,13 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

import flax.linen as nn
import jax
from jax.sharding import PartitionSpec as PS

from fanan.config import Config

# Registry mapping lowercased class names to architecture classes.
_ARCHITECTURES: dict[str, Any] = {}


def register_architecture(cls):
    """Class decorator: add *cls* to the architecture registry.

    The registry key is the lowercased class name; *cls* is returned
    unchanged so the decorator is transparent to the decorated class.
    """
    key = cls.__name__.lower()
    _ARCHITECTURES[key] = cls
    return cls


class Architecture(ABC):
    """Abstract base for all model architectures.

    Stores the run configuration at construction time and exposes it
    read-only through the ``config`` property. Concrete architectures
    subclass this and are registered via ``register_architecture``.
    """

    def __init__(self, config: Config) -> None:
        """Keep a private reference to the run configuration."""
        self._config = config

    @property
    def config(self) -> Config:
        """The configuration this architecture was constructed with."""
        return self._config


from fanan.modeling.architectures.ddim import * # isort:skip
# from fanan.modeling.architectures.ddpm import * # isort:skip


def get_architecture(config: Config) -> Architecture:
    """Instantiate the architecture named in ``config.arch.architecture_name``.

    The (lowercased) name is looked up in the module registry and the
    matching class is constructed with ``config``.

    Raises:
        ValueError: if the name is empty/missing, or not registered.
            (Replaces the previous ``assert``, which is silently stripped
            under ``python -O``, and the raw ``KeyError`` on unknown names.)
    """
    name = config.arch.architecture_name
    if not name:
        raise ValueError("Arch config must specify 'architecture'.")
    try:
        arch_cls = _ARCHITECTURES[name.lower()]
    except KeyError as err:
        raise ValueError(
            f"Unknown architecture {name!r}; registered: {sorted(_ARCHITECTURES)}"
        ) from err
    return arch_cls(config)
# Names exported by ``from fanan.modeling.architectures import *`` — the
# package's public API: the base class, the registry helpers, and the
# concrete DDIM architecture.
__all__ = [
    "Architecture",
    "register_architecture",
    "get_architecture",
    "DDIM",
]

from fanan.modeling.architectures.base import Architecture
from fanan.modeling.architectures.ddim import DDIM
from fanan.modeling.architectures.registry import (
get_architecture,
register_architecture,
)
25 changes: 25 additions & 0 deletions fanan/modeling/architectures/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from fanan.config import Config


class Architecture:
    """Base class for all architectures.

    Concrete architectures receive the full run configuration at
    construction and can read it back via the ``config`` property.
    """

    def __init__(self, config: Config) -> None:
        # Private reference; exposed read-only through ``config`` below.
        self._config = config

    @property
    def config(self) -> Config:
        """Return the configuration this architecture was built with."""
        return self._config
8 changes: 5 additions & 3 deletions fanan/modeling/architectures/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from ml_collections.config_dict import ConfigDict

from fanan.config.base import ArchitectureConfig, Config
from fanan.modeling.architectures import Architecture, register_architecture
from fanan.modeling.architectures.base import Architecture
from fanan.modeling.architectures.registry import register_architecture
from fanan.modeling.modules.state import TrainState
from fanan.modeling.modules.unet import UNet
from fanan.optimization import lr_schedules, optimizers
Expand Down Expand Up @@ -209,7 +210,8 @@ def _create_optimizer(self):
return optimizer, lr_schedule

def _loss(self, predictions: jnp.ndarray, targets: jnp.ndarray):
    """Mean absolute error (L1) between predictions and targets.

    The stale ``return optax.l2_loss(...)`` line was removed: it executed
    first and made the L1 computation below unreachable dead code.
    """
    # l1 loss / mean_absolute_error
    return jnp.abs(predictions - targets).mean()

@partial(jax.jit, static_argnums=(0,))
def _train_step(self, state, batch, rng):
Expand Down Expand Up @@ -243,7 +245,7 @@ def train_step(self, batch):
self.global_step += 1
return loss

# @partial(jax.jit, static_argnums=(0,5))
@partial(jax.jit, static_argnums=(0, 5))
def _eval_step(self, state, params, rng, batch, diffusion_steps: int):
variables = {"params": params, "batch_stats": state.batch_stats}
generated_images = state.apply_fn(variables, rng, batch.shape, diffusion_steps, method=DDIMModel.generate)
Expand Down
19 changes: 19 additions & 0 deletions fanan/modeling/architectures/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import Any

from fanan.config import Config
from fanan.modeling.architectures.base import Architecture

# Registry of architecture classes, keyed by lowercased class name.
_ARCHITECTURES: dict[str, Any] = {}


def register_architecture(cls):
    """Class decorator registering *cls* under its lowercased class name.

    Returns *cls* unchanged so it can be stacked like any transparent
    decorator. No ``global`` statement is needed: the dict is mutated in
    place, the module-level name is never rebound.
    """
    _ARCHITECTURES[cls.__name__.lower()] = cls
    return cls


def get_architecture(config: Config) -> Architecture:
    """Instantiate the architecture named by ``config.arch.architecture_name``.

    Raises:
        ValueError: if the name is empty/missing, or not registered.
            (Replaces the previous ``assert``, which disappears under
            ``python -O``, and the raw ``KeyError`` on unknown names.)
    """
    name = config.arch.architecture_name
    if not name:
        raise ValueError("Arch config must specify 'architecture'.")
    try:
        arch_cls = _ARCHITECTURES[name.lower()]
    except KeyError as err:
        raise ValueError(
            f"Unknown architecture {name!r}; registered: {sorted(_ARCHITECTURES)}"
        ) from err
    return arch_cls(config)
2 changes: 1 addition & 1 deletion fanan/modeling/modules/attentions/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def apply_attention_dot(
attention_scores = (attention_scores * scale).astype(query_states.dtype) # [batch, num_heads, seq, seq]

# TODO: add mask
mask = None
# mask = None

attention_probs = jax.nn.softmax(attention_scores, axis=-1)

Expand Down

0 comments on commit 32b9447

Please sign in to comment.