diff --git a/configs/ddim.yaml b/configs/ddim.yaml
index 8c22707..f97578c 100644
--- a/configs/ddim.yaml
+++ b/configs/ddim.yaml
@@ -12,19 +12,21 @@ mesh:
 
 data:
   dataset_name: "oxford_flowers102"
-  batch_size: 8
-  cache: False
-  image_size: 32
+  image_size: [64, 64]
   num_channels: 3
+  batch_size: 64
+  cache: False
 
 
 arch:
   architecture_name: "ddim"
-  image_size: [32, 32]
-  feature_stages: [32, 64]
+  image_size: [64, 64]
+  feature_stages: [32, 64, 96, 128]
   block_depth: 2
+  embedding_dim: 32
+  embedding_max_frequency: 1000.0
   diffusion:
-    diffusion_steps: 128
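+    # reverse-diffusion sampling steps; DDIM can sample with far fewer steps than the training noise schedule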
+    diffusion_steps: 10
 
 
 training:
diff --git a/fanan/config/base.py b/fanan/config/base.py
index 4816beb..8ca91a7 100644
--- a/fanan/config/base.py
+++ b/fanan/config/base.py
@@ -33,10 +33,10 @@ class DataConfig(ConfigDict):
     def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
         super().__init__(initial_dictionary=initial_dictionary, **kwargs)
         self.dataset_name: str = "mnist"
+        self.image_size: list[int] = [64, 64]
+        self.num_channels: int = 3
         self.batch_size: int = 64
         self.cache: bool = False
-        self.image_size: int = 512
-        self.num_channels: int = 3
 
 
 class DiffusionConfig(ConfigDict):
diff --git a/fanan/data/tf_data.py b/fanan/data/tf_data.py
index ec59d0c..aa59319 100644
--- a/fanan/data/tf_data.py
+++ b/fanan/data/tf_data.py
@@ -1,88 +1,67 @@
-import logging
+from functools import partial
 from typing import Any
 
 import jax
 import tensorflow as tf
 import tensorflow_datasets as tfds
-
-from fanan.config.base import Config
-
-
-def normalize_to_neg_one_to_one(img):
-    return img * 2 - 1
-
-
-def crop_and_resize(image: tf.Tensor, resolution: int = 64) -> tf.Tensor:
-    height, width = tf.shape(image)[0], tf.shape(image)[1]
-    crop_size = tf.minimum(height, width)
-    # image = image[
-    #     (height - crop) // 2 : (height + crop) // 2,
-    #     (width - crop) // 2 : (width + crop) // 2,
-    # ]
-    image = tf.image.crop_to_bounding_box(
-        image=image,
-        offset_height=(height - crop_size) // 2,
-        offset_width=(width - crop_size) // 2,
-        target_height=crop_size,
-        target_width=crop_size,
-    )
-    image = tf.image.resize(
-        image,
-        size=(resolution, resolution),
-        antialias=True,
-        method=tf.image.ResizeMethod.BICUBIC,
-    )
-    return tf.cast(image, tf.uint8)
-
-
-def get_dataset_iterator(config: Config, split: str = "train") -> Any:
-    if config.data.batch_size % jax.device_count() > 0:
-        raise ValueError(
-            f"batch size {config.data.batch_size} must be divisible by the number of devices {jax.device_count()}"
+from ml_collections.config_dict import ConfigDict
+
+from fanan.config.base import Config, DataConfig
+from fanan.utils.image_utils import process_image
+
+
+class DefaultDataConfig(DataConfig):
+    def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
+        super().__init__(initial_dictionary=initial_dictionary, **kwargs)
+        self.dataset_name: str = "oxford_flowers102"
+        self.image_size: list[int] = [64, 64]
+        self.num_channels: int = 3
+        self.batch_size: int = 64
+        self.cache: bool = False
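+        # Re-apply user-supplied values over the defaults above, resolving any field references.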
+        self.update(ConfigDict(initial_dictionary).copy_and_resolve_references())
+
+
+class Dataset:
+    def __init__(self, config: Config):
+        self._config = config
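+        # Re-wrap the data config so any fields the user omitted fall back to the defaults.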
+        self._config.data = DefaultDataConfig(self._config.data)
+        self.train_iter, self.val_iter = self.get_dataset()
+
+    def get_dataset(self) -> tuple[Any, Any]:
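+        # The official splits are small, so pool 80% of train/validation/test for training
+        # and hold out the remaining 20% for validation.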
+        train_iter = self.get_dataset_iterator(split="train[:80%]+validation[:80%]+test[:80%]")
+        val_iter = self.get_dataset_iterator(split="train[80%:]+validation[80%:]+test[80%:]")
+        return train_iter, val_iter
+
+    def get_dataset_iterator(self, split: str = "train") -> Any:
+        if self._config.data.batch_size % jax.device_count() > 0:
+            raise ValueError(
+                f"batch size {self._config.data.batch_size} must be divisible by the number of devices {jax.device_count()}"
+            )
+
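+        # Shard the global batch evenly across hosts.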
+        batch_size = self._config.data.batch_size // jax.process_count()
+
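+        # bfloat16 on TPU, float16 elsewhere when half precision is enabled; float32 otherwise.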
+        platform = jax.local_devices()[0].platform
+        input_dtype = (
+            (tf.bfloat16 if platform == "tpu" else tf.float16) if self._config.training.half_precision else tf.float32
         )
 
-    batch_size = config.data.batch_size // jax.process_count()
-
-    platform = jax.local_devices()[0].platform
-    input_dtype = (tf.bfloat16 if platform == "tpu" else tf.float16) if config.training.half_precision else tf.float32
-
-    dataset_builder = tfds.builder(config.data.dataset_name)
-    dataset_builder.download_and_prepare()
-
-    def preprocess_fn(d: dict) -> dict[str, Any]:
-        image = d.get("image")
-        image = crop_and_resize(image=image, resolution=config.data.image_size)
-        # image = tf.image.flip_left_right(image)
-        image = tf.image.convert_image_dtype(image, input_dtype)
-        # return {"image": image}
-        return image
-
-    # create split for current process
-    num_examples = dataset_builder.info.splits[split].num_examples
-    logging.info(f"Total {split=} examples: {num_examples=}")
-    split_size = num_examples // jax.process_count()
-    logging.info(f"Split size: {split_size=}")
-    start = jax.process_index() * split_size
-    split = f"{split}[{start}:{start + split_size}]"
-
-    ds = dataset_builder.as_dataset(split=split)
-    options = tf.data.Options()
-    options.threading.private_threadpool_size = 48
-    ds.with_options(options)
-
-    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-    if config.data.cache:
-        ds = ds.cache()
-
-    ds = ds.repeat()
-    ds = ds.shuffle(16 * batch_size, seed=config.fanan.seed)
-    ds = ds.batch(batch_size, drop_remainder=True)
-    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
-
-    return iter(tfds.as_numpy(ds))
+        ds = tfds.load(self._config.data.dataset_name, split=split, shuffle_files=True)
+        ds = ds.map(
+            partial(
+                process_image,
+                resolution=self._config.data.image_size,
+                input_dtype=input_dtype,
+            ),
+            num_parallel_calls=tf.data.experimental.AUTOTUNE,
+        )
+        if self._config.data.cache:
+            ds = ds.cache()
 
+        ds = ds.repeat()
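+        # Seeded shuffle with a 16-batch buffer for reproducibility across runs.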
+        ds = ds.shuffle(16 * batch_size, seed=self._config.fanan.seed)
+        ds = ds.batch(batch_size, drop_remainder=True)
+        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
 
-def get_dataset(config: Config) -> Any:
-    train_ds = get_dataset_iterator(config, split="train")
-    val_ds = get_dataset_iterator(config, split="test")
-    return train_ds, val_ds
+        return iter(tfds.as_numpy(ds))
diff --git a/fanan/fanan.py b/fanan/fanan.py
index d3d5730..4496975 100644
--- a/fanan/fanan.py
+++ b/fanan/fanan.py
@@ -5,7 +5,7 @@
 
 from fanan.config import Config
 from fanan.core.cortex import Cortex
-from fanan.data.tf_data import get_dataset
+from fanan.data.tf_data import Dataset
 from fanan.utils.parser import parse_args
 
 logging.basicConfig(
@@ -32,12 +32,12 @@ def main() -> None:
     config = Config.read_config_from_yaml(args.config_path)
     logging.info(f"{config=}")
 
-    train_dl, val_dl = get_dataset(config)
+    dataset = Dataset(config=config)
 
     cortex = Cortex(config)
     cortex.train(
-        train_dataloader_iter=train_dl,
-        val_dataloader_iter=val_dl,
+        train_dataloader_iter=dataset.train_iter,
+        val_dataloader_iter=dataset.val_iter,
     )
 
 
diff --git a/fanan/modeling/architectures/ddim.py b/fanan/modeling/architectures/ddim.py
index c2f5020..5ba00f8 100644
--- a/fanan/modeling/architectures/ddim.py
+++ b/fanan/modeling/architectures/ddim.py
@@ -65,8 +65,8 @@ def setup(self):
             embedding_max_frequency=cfg.embedding_max_frequency,
         )
 
-    def __call__(self, images, rng, train: bool):
-        images = self.normalizer(images, use_running_average=not train)
+    def __call__(self, images, rng, is_training: bool):
+        images = self.normalizer(images, use_running_average=not is_training)
 
         rng_noises, rng_times = jax.random.split(rng)
         noises = jax.random.normal(rng_noises, images.shape, images.dtype)
@@ -74,7 +74,7 @@ def __call__(self, images, rng, train: bool):
         noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
         noisy_images = signal_rates * images + noise_rates * noises
 
-        pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, train=train)
+        pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, is_training=is_training)
         return noises, images, pred_noises, pred_images
 
     def diffusion_schedule(
@@ -94,8 +94,8 @@ def diffusion_schedule(
 
         return noise_rates, signal_rates
 
-    def denoise(self, noisy_images, noise_rates, signal_rates, train: bool):
-        pred_noises = self.network(noisy_images, noise_rates**2)
+    def denoise(self, noisy_images, noise_rates, signal_rates, is_training: bool):
+        pred_noises = self.network(noisy_images, noise_rates**2, is_training=is_training)
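+        # Invert the forward process: x0_hat = (x_t - noise_rate * eps_hat) / signal_rate.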
         pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
         return pred_noises, pred_images
 
@@ -113,7 +113,7 @@ def reverse_diffusion(self, initial_noise, diffusion_steps):
             ones = jnp.ones((n_images, 1, 1, 1), dtype=initial_noise.dtype)
             diffusion_times = ones - step * step_size
             noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
-            pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, train=False)
+            pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates, is_training=False)
 
             next_diffusion_times = diffusion_times - step_size
             next_noise_rates, next_signal_rates = self.diffusion_schedule(next_diffusion_times)
@@ -156,8 +156,7 @@ def initialization_input(self):
         image_size = self.config.data.image_size
         shape = (
             self.config.data.batch_size,
-            image_size,
-            image_size,
+            *image_size,
             self.config.data.num_channels,
         )
         return jnp.ones(shape, dtype=jnp.float32)
@@ -170,7 +169,7 @@ def _create_state(self):
             key_init,
             self.initialization_input,
             key_diffusion,
-            train=True,
+            is_training=False,
         )
 
         tx, lr_schedule = self._create_optimizer()
@@ -188,12 +187,15 @@ def _create_state(self):
         tabulate_fn = nn.tabulate(
             DDIMModel(config=self._config.arch),
             key_init,
-            show_repeated=True,
-            compute_flops=True,
-            compute_vjp_flops=True,
         )
 
-        print(tabulate_fn(self.initialization_input, key_diffusion, False))
+        print(
+            tabulate_fn(
+                images=self.initialization_input,
+                rng=key_diffusion,
+                is_training=False,
+            )
+        )
 
         return state, lr_schedule
 
@@ -213,7 +215,11 @@ def _loss(self, predictions: jnp.ndarray, targets: jnp.ndarray):
     def _train_step(self, state, batch, rng):
         def loss_fn(params):
             outputs, mutated_vars = state.apply_fn(
-                {"params": params, "batch_stats": state.batch_stats}, batch, rng, train=True, mutable=["batch_stats"]
+                {"params": params, "batch_stats": state.batch_stats},
+                batch,
+                rng,
+                is_training=True,
+                mutable=["batch_stats"],
             )
             noises, images, pred_noises, pred_images = outputs
 
diff --git a/fanan/modeling/modules/embedding.py b/fanan/modeling/modules/embedding.py
index 307a076..81d083f 100644
--- a/fanan/modeling/modules/embedding.py
+++ b/fanan/modeling/modules/embedding.py
@@ -13,7 +13,7 @@ def setup(self):
         start = jnp.log(self.embedding_min_frequency)
         stop = jnp.log(self.embedding_max_frequency)
         frequencies = jnp.exp(jnp.linspace(start, stop, half_embedding_dim))
-        self.angular_speeds = 2.0 * jnp.pi * frequencies
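+        # Cast the precomputed speeds so half-precision modules don't silently upcast.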
+        self.angular_speeds = (2.0 * jnp.pi * frequencies).astype(self.dtype)
 
     def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
         positional_embeddings = jnp.concatenate(
@@ -29,10 +29,17 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
 class TimeEmbedding(nn.Module):
     time_embedding_dim: int
     sinusoidal_embedding_dim: int
+    sinusoidal_embedding_min_frequency: float = 1.0
+    sinusoidal_embedding_max_frequency: float = 10_000.0
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
-        self.positional_embedding = SinusoidalPositionalEmbedding(self.sinusoidal_embedding_dim, dtype=self.dtype)
+        self.positional_embedding = SinusoidalPositionalEmbedding(
+            embedding_dim=self.sinusoidal_embedding_dim,
+            embedding_min_frequency=self.sinusoidal_embedding_min_frequency,
+            embedding_max_frequency=self.sinusoidal_embedding_max_frequency,
+            dtype=self.dtype,
+        )
         self.dense1 = nn.Dense(self.time_embedding_dim, dtype=self.dtype)
         self.dense2 = nn.Dense(self.time_embedding_dim, dtype=self.dtype)
 
diff --git a/fanan/modeling/modules/unet.py b/fanan/modeling/modules/unet.py
index 0976165..4743f9b 100644
--- a/fanan/modeling/modules/unet.py
+++ b/fanan/modeling/modules/unet.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Tuple
+from typing import Any, Tuple
 
 import flax.linen as nn
 import jax
@@ -9,34 +9,42 @@
 
 class UNetResidualBlock(nn.Module):
     output_channels_width: int
-    num_groups: Optional[int] = 8
     dtype: Any = jnp.float32
 
     def setup(self):
-        self.conv1 = nn.Conv(self.output_channels_width, kernel_size=(1, 1), name="conv1")
-        self.conv2 = nn.Conv(self.output_channels_width, kernel_size=(3, 3), padding="SAME", name="conv2")
-        self.conv3 = nn.Conv(self.output_channels_width, kernel_size=(3, 3), padding="SAME", name="conv3")
-        self.group_norm = nn.GroupNorm(
-            num_groups=self.num_groups,
-            epsilon=1e-5,
-            use_bias=False,
-            use_scale=False,
-            dtype=self.dtype,
+        self.conv1 = nn.Conv(features=self.output_channels_width, kernel_size=(1, 1), name="conv1")
+        self.bn = nn.BatchNorm(use_bias=False, use_scale=False)
+        self.conv2 = nn.Conv(
+            features=self.output_channels_width,
+            kernel_size=(3, 3),
+            padding="SAME",
+            name="conv2",
+        )
+        self.conv3 = nn.Conv(
+            features=self.output_channels_width,
+            kernel_size=(3, 3),
+            padding="SAME",
+            name="conv3",
         )
 
     @nn.compact
-    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
-        input_width = x.shape[-1]
-
-        residual = self.conv1(x) if input_width != self.output_channels_width else x
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        is_training: bool,
+    ) -> jnp.ndarray:
+        input_width = x.shape[3]
+        residual = x if input_width == self.output_channels_width else self.conv1(x)
 
-        x = self.group_norm(x)
-        x = nn.swish(x)
+        x = self.bn(
+            x,
+            use_running_average=not is_training,
+        )
         x = self.conv2(x)
         x = nn.swish(x)
         x = self.conv3(x)
 
-        x = x + residual
+        x += residual
 
         return x
 
@@ -54,9 +62,15 @@ def setup(self):
             for i in range(self.block_depth)
         ]
 
-    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        skips: list[jnp.ndarray],
+        is_training: bool,
+    ) -> jnp.ndarray:
         for block in self.residual_blocks:
-            x = block(x)
+            x = block(x, is_training=is_training)
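+            # Stash each block's output; the matching UpBlock pops them in reverse (LIFO).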
+            skips.append(x)
         x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
         return x
 
@@ -80,11 +94,16 @@ def upsample2d(self, x: jnp.ndarray, scale: int = 2) -> jnp.ndarray:
         x = jax.image.resize(x, shape=upsampled_shape, method="bilinear")
         return x
 
-    def __call__(self, x: jnp.ndarray, skip: jnp.ndarray) -> jnp.ndarray:
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        skips: list[jnp.ndarray],
+        is_training: bool,
+    ) -> jnp.ndarray:
         x = self.upsample2d(x)
-        x = jnp.concatenate([x, skip], axis=-1)
         for block in self.residual_blocks:
-            x = block(x)
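+            # Pop the most recent skip so spatial resolutions line up block-for-block with the DownBlock stack.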
+            x = jnp.concatenate([x, skips.pop()], axis=-1)
+            x = block(x, is_training=is_training)
         return x
 
 
@@ -117,13 +136,13 @@ def __call__(
         self,
         noisy_images: jnp.ndarray,
         noise_variances: jnp.ndarray,
+        is_training: bool = True,
     ) -> jnp.ndarray:
         embedding = self.sinusoidal_embedding(noise_variances)
         # TODO: util function for this?
         upsampled_shape = (
             noisy_images.shape[0],
-            self.image_size[0],
-            self.image_size[1],
+            *self.image_size,
             self.embedding_dim,
         )
         embedding = jax.image.resize(embedding, upsampled_shape, method="nearest")
@@ -133,14 +152,13 @@ def __call__(
 
         skips = []
         for block in self.down_blocks:
-            skips.append(x)
-            x = block(x)
+            x = block(x, skips, is_training)
 
         for block in self.residual_blocks:
-            x = block(x)
+            x = block(x, is_training)
 
-        for block, skip in zip(self.up_blocks, reversed(skips)):
-            x = block(x, skip)
+        for block in self.up_blocks:
+            x = block(x, skips, is_training)
 
         outputs = self.conv2(x)
         return outputs
diff --git a/fanan/utils/image_utils.py b/fanan/utils/image_utils.py
new file mode 100644
index 0000000..8dbf3ff
--- /dev/null
+++ b/fanan/utils/image_utils.py
@@ -0,0 +1,56 @@
+import jax
+import tensorflow as tf
+
+
+def normalize_to_neg_one_to_one(img):
+    return img * 2 - 1
+
+
+def crop_and_resize(image: tf.Tensor, resolution: tuple[int, int] = (64, 64)) -> tf.Tensor:
+    height, width = tf.shape(image)[0], tf.shape(image)[1]
+    crop_size = tf.minimum(height, width)
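+    # Center-crop to a square of the shorter side so the resize preserves aspect ratio.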
+    image = tf.image.crop_to_bounding_box(
+        image=image,
+        offset_height=(height - crop_size) // 2,
+        offset_width=(width - crop_size) // 2,
+        target_height=crop_size,
+        target_width=crop_size,
+    )
+    image = tf.image.resize(
+        image,
+        size=resolution,
+        antialias=True,
+        method=tf.image.ResizeMethod.BICUBIC,
+    )
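+    # Bicubic resampling can overshoot the input range; rescale to [0, 1] and clip.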
+    return tf.clip_by_value(image / 255.0, 0.0, 1.0)
+
+
+def process_image(
+    data: dict[str, tf.Tensor],
+    resolution: list[int],
+    input_dtype: tf.DType = tf.float32,
+) -> tf.Tensor:
+    image = data.get("image")
+    image = crop_and_resize(image=image, resolution=resolution)
+    # image = normalize_to_neg_one_to_one(image)
+    image = tf.image.convert_image_dtype(image, input_dtype)
+    return image
+
+
+def upsample2d(x, scale: int | tuple[int, int], method: str = "bilinear"):
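+    # Assumes NHWC layout; each spatial dim is scaled by an integer factor.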
+    b, h, w, c = x.shape
+
+    if isinstance(scale, int):
+        h_out, w_out = scale * h, scale * w
+    elif len(scale) == 2:
+        h_out, w_out = scale[0] * h, scale[1] * w
+    else:
+        raise ValueError("scale argument should be either int" "or Tuple[int, int]")
+
+    return jax.image.resize(x, shape=(b, h_out, w_out, c), method=method)