From 32b9447f693d3ff4d89087a8cc79a8fbeced2848 Mon Sep 17 00:00:00 2001
From: Amr Kayid
Date: Sat, 25 May 2024 03:33:39 +0000
Subject: [PATCH] QoL improvements

Split the architecture registry and the Architecture base class out of
fanan/modeling/architectures/__init__.py into registry.py and base.py, so
that ddim.py can import both directly instead of through the package
__init__. Also: switch the DDIM training loss from L2 to L1 (mean absolute
error), re-enable jit for _eval_step with diffusion_steps marked static,
stop binding the unused attention mask to None, raise diffusion_steps from
10 to 80 and eval_every_steps from 100 to 1000, and move optimizer and
lr-schedule settings into an explicit `optimization:` config block.
---
 configs/ddim.yaml                            | 17 +++++-
 fanan/data/tf_data.py                        |  3 +-
 fanan/modeling/architectures/__init__.py     | 61 ++++---------------
 fanan/modeling/architectures/base.py         | 25 ++++++++
 fanan/modeling/architectures/ddim.py         |  8 ++-
 fanan/modeling/architectures/registry.py     | 19 ++++++
 .../modeling/modules/attentions/attention.py |  2 +-
 7 files changed, 80 insertions(+), 55 deletions(-)
 create mode 100644 fanan/modeling/architectures/base.py
 create mode 100644 fanan/modeling/architectures/registry.py

diff --git a/configs/ddim.yaml b/configs/ddim.yaml
index b06c4cb..0c4dd5b 100644
--- a/configs/ddim.yaml
+++ b/configs/ddim.yaml
@@ -26,11 +26,24 @@ arch:
     embedding_dim: 32
     embedding_max_frequency: 1000.0
   diffusion:
-    diffusion_steps: 10
+    diffusion_steps: 80
+
+optimization:
+  optimizer_type: "adamw"
+  optimizer_kwargs:
+    b1: 0.9
+    b2: 0.999
+    eps: 1.0e-8
+    weight_decay: 1.0e-4
+  lr_schedule:
+    schedule_type: "constant_warmup"
+    lr_kwargs:
+      value: 1.0e-3
+      warmup_steps: 128
 
 training:
   total_steps: 10_000
-  eval_every_steps: 100
+  eval_every_steps: 1000
 
 
 
diff --git a/fanan/data/tf_data.py b/fanan/data/tf_data.py
index aa59319..49e9ffc 100644
--- a/fanan/data/tf_data.py
+++ b/fanan/data/tf_data.py
@@ -37,7 +37,8 @@ def get_dataset(self) -> Any:
     def get_dataset_iterator(self, split: str = "train") -> Any:
         if self._config.data.batch_size % jax.device_count() > 0:
             raise ValueError(
-                f"batch size {self._config.data.batch_size} must be divisible by the number of devices {jax.device_count()}"
+                f"batch size {self._config.data.batch_size} must be divisible "
+                f"by the number of devices {jax.device_count()}"
             )
 
         batch_size = self._config.data.batch_size // jax.process_count()
diff --git a/fanan/modeling/architectures/__init__.py b/fanan/modeling/architectures/__init__.py
index c7780d1..feffcc6 100644
--- a/fanan/modeling/architectures/__init__.py
+++ b/fanan/modeling/architectures/__init__.py
@@ -1,48 +1,13 @@
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from typing import Any
-
-import flax.linen as nn
-import jax
-from jax.sharding import PartitionSpec as PS
-
-from fanan.config import Config
-
-_ARCHITECTURES: dict[str, Any] = {}  # registry
-
-
-def register_architecture(cls):
-    _ARCHITECTURES[cls.__name__.lower()] = cls
-    return cls
-
-
-# class Architecture(ABC, nn.Module):
-class Architecture(ABC):
-    """Base class for all architectures."""
-
-    def __init__(self, config: Config) -> None:
-        self._config = config
-
-    @property
-    def config(self) -> Config:
-        return self._config
-
-    # @abstractmethod
-    # def __call__(
-    #     self, batch: dict[str, jax.Array], training: bool
-    # ) -> dict[str, jax.Array]:
-    #     pass
-
-    # @abstractmethod
-    # def shard(self, ps: PS) -> tuple[Architecture, PS]:
-    #     pass
-
-
-from fanan.modeling.architectures.ddim import *  # isort:skip
-# from fanan.modeling.architectures.ddpm import *  # isort:skip
-
-
-def get_architecture(config: Config) -> Architecture:
-    assert config.arch.architecture_name, "Arch config must specify 'architecture'."
-    return _ARCHITECTURES[config.arch.architecture_name.lower()](config)
+__all__ = [
+    "Architecture",
+    "register_architecture",
+    "get_architecture",
+    "DDIM",
+]
+
+from fanan.modeling.architectures.base import Architecture
+from fanan.modeling.architectures.ddim import DDIM
+from fanan.modeling.architectures.registry import (
+    get_architecture,
+    register_architecture,
+)
diff --git a/fanan/modeling/architectures/base.py b/fanan/modeling/architectures/base.py
new file mode 100644
index 0000000..fdaf903
--- /dev/null
+++ b/fanan/modeling/architectures/base.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from fanan.config import Config
+
+
+# class Architecture(ABC, nn.Module):
+class Architecture:  # (ABC):
+    """Base class for all architectures."""
+
+    def __init__(self, config: Config) -> None:
+        self._config = config
+
+    @property
+    def config(self) -> Config:
+        return self._config
+
+    # @abstractmethod
+    # def __call__(
+    #     self, batch: dict[str, jax.Array], training: bool
+    # ) -> dict[str, jax.Array]:
+    #     pass
+
+    # @abstractmethod
+    # def shard(self, ps: PS) -> tuple[Architecture, PS]:
+    #     pass
diff --git a/fanan/modeling/architectures/ddim.py b/fanan/modeling/architectures/ddim.py
index 5ba00f8..c9930ca 100644
--- a/fanan/modeling/architectures/ddim.py
+++ b/fanan/modeling/architectures/ddim.py
@@ -10,7 +10,8 @@
 from ml_collections.config_dict import ConfigDict
 
 from fanan.config.base import ArchitectureConfig, Config
-from fanan.modeling.architectures import Architecture, register_architecture
+from fanan.modeling.architectures.base import Architecture
+from fanan.modeling.architectures.registry import register_architecture
 from fanan.modeling.modules.state import TrainState
 from fanan.modeling.modules.unet import UNet
 from fanan.optimization import lr_schedules, optimizers
@@ -209,7 +210,8 @@ def _create_optimizer(self):
         return optimizer, lr_schedule
 
     def _loss(self, predictions: jnp.ndarray, targets: jnp.ndarray):
-        return optax.l2_loss(predictions, targets).mean()  # type:
+        # l1 loss / mean_absolute_error
+        return jnp.abs(predictions - targets).mean()
 
     @partial(jax.jit, static_argnums=(0,))
     def _train_step(self, state, batch, rng):
@@ -243,7 +245,7 @@ def train_step(self, batch):
         self.global_step += 1
         return loss
 
-    # @partial(jax.jit, static_argnums=(0,5))
+    @partial(jax.jit, static_argnums=(0, 5))
     def _eval_step(self, state, params, rng, batch, diffusion_steps: int):
         variables = {"params": params, "batch_stats": state.batch_stats}
         generated_images = state.apply_fn(variables, rng, batch.shape, diffusion_steps, method=DDIMModel.generate)
diff --git a/fanan/modeling/architectures/registry.py b/fanan/modeling/architectures/registry.py
new file mode 100644
index 0000000..4745151
--- /dev/null
+++ b/fanan/modeling/architectures/registry.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from typing import Any
+
+from fanan.config import Config
+from fanan.modeling.architectures.base import Architecture
+
+_ARCHITECTURES: dict[str, Any] = {}  # registry
+
+
+def register_architecture(cls):
+    global _ARCHITECTURES
+    _ARCHITECTURES[cls.__name__.lower()] = cls
+    return cls
+
+
+def get_architecture(config: Config) -> Architecture:
+    assert config.arch.architecture_name, "Arch config must specify 'architecture_name'."
+    return _ARCHITECTURES[config.arch.architecture_name.lower()](config)
diff --git a/fanan/modeling/modules/attentions/attention.py b/fanan/modeling/modules/attentions/attention.py
index ed95027..5ae841a 100644
--- a/fanan/modeling/modules/attentions/attention.py
+++ b/fanan/modeling/modules/attentions/attention.py
@@ -85,7 +85,7 @@ def apply_attention_dot(
     attention_scores = (attention_scores * scale).astype(query_states.dtype)  # [batch, num_heads, seq, seq]
 
     # TODO: add mask
-    mask = None
+    # mask = None
 
     attention_probs = jax.nn.softmax(attention_scores, axis=-1)
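
A few notes on the patch, each with a small runnable sketch; none of the
code below is part of the diff itself. First, the new `optimization:` block
is consumed by fanan.optimization.optimizers and lr_schedules, which this
patch does not touch. Assuming those factories map the config strings onto
optax in the obvious way, the configured values correspond roughly to the
sketch below; the join_schedules composition is a guess at what
"constant_warmup" means (linear warmup over 128 steps, then a constant
1e-3).

    import optax

    # Assumed reading of schedule_type="constant_warmup" with the
    # configured lr_kwargs: warm up linearly, then hold constant.
    lr_schedule = optax.join_schedules(
        schedules=[
            optax.linear_schedule(init_value=0.0, end_value=1.0e-3, transition_steps=128),
            optax.constant_schedule(1.0e-3),
        ],
        boundaries=[128],
    )

    # optimizer_type="adamw" with the configured optimizer_kwargs.
    optimizer = optax.adamw(
        learning_rate=lr_schedule,
        b1=0.9,
        b2=0.999,
        eps=1.0e-8,
        weight_decay=1.0e-4,
    )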
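The registry half of the refactor keys classes by their lowercased class
name and instantiates them from config.arch.architecture_name. A minimal
usage sketch; ToyArchitecture and the SimpleNamespace config are
hypothetical stand-ins for a registered model and a real fanan Config:

    from types import SimpleNamespace

    from fanan.modeling.architectures.base import Architecture
    from fanan.modeling.architectures.registry import (
        get_architecture,
        register_architecture,
    )


    @register_architecture
    class ToyArchitecture(Architecture):
        """Hypothetical architecture, registered under 'toyarchitecture'."""


    # get_architecture lowercases architecture_name before the lookup,
    # so the config value is effectively case-insensitive.
    config = SimpleNamespace(arch=SimpleNamespace(architecture_name="ToyArchitecture"))
    arch = get_architecture(config)
    assert isinstance(arch, ToyArchitecture)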
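In ddim.py, _loss now returns mean absolute error instead of
optax.l2_loss, which optax defines elementwise as 0.5 * (predictions -
targets) ** 2, so the old objective was half the mean squared error. A
standalone comparison of the two reductions, assuming only jax and optax:

    import jax.numpy as jnp
    import optax

    predictions = jnp.array([0.2, -0.7, 1.5])
    targets = jnp.array([0.0, -1.0, 1.0])

    # Old objective: optax.l2_loss(p, t) == 0.5 * (p - t) ** 2, elementwise.
    l2 = optax.l2_loss(predictions, targets).mean()

    # New objective from the patch: L1 / mean absolute error.
    l1 = jnp.abs(predictions - targets).mean()

    print(float(l2), float(l1))  # ~0.0633 and ~0.3333 for this toy data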
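Finally, re-enabling @partial(jax.jit, static_argnums=(0, 5)) on
_eval_step marks self and diffusion_steps as static, which lets the
Python-level sampling loop unroll at trace time; the cost is one
recompilation per distinct diffusion_steps value. A toy illustration of
that trade-off (run_steps is illustrative, not repo code):

    from functools import partial

    import jax
    import jax.numpy as jnp


    @partial(jax.jit, static_argnums=(1,))
    def run_steps(x: jnp.ndarray, n_steps: int) -> jnp.ndarray:
        # n_steps is static, so this Python loop is unrolled during tracing.
        for _ in range(n_steps):
            x = 0.5 * x
        return x


    x = jnp.ones((4,))
    run_steps(x, 10)  # traces and compiles for n_steps=10
    run_steps(x, 10)  # cache hit: reuses the compiled executable
    run_steps(x, 80)  # new static value -> one more compilation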