Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data-efficient GANs with Adaptive Discriminator Augmentation to keras 3.0 (Tensorflow backend only) #2035

Merged
merged 4 commits into from
Feb 6, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 63 additions & 50 deletions examples/generative/gan_ada.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Title: Data-efficient GANs with Adaptive Discriminator Augmentation
Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
Date created: 2021/10/28
Last modified: 2021/10/28
Last modified: 2025/01/23
Description: Generating images from limited data using the Caltech Birds dataset.
Accelerator: GPU
"""
Expand Down Expand Up @@ -62,12 +62,17 @@ class of generative deep learning models, commonly used for image generation. Th
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras import ops
from keras import layers

"""
## Hyperparameterers
Expand Down Expand Up @@ -115,46 +120,47 @@ class of generative deep learning models, commonly used for image generation. Th


def round_to_int(float_value):
return tf.cast(tf.math.round(float_value), dtype=tf.int32)
return ops.cast(ops.round(float_value), "int32")


def preprocess_image(data):
# unnormalize bounding box coordinates
height = tf.cast(tf.shape(data["image"])[0], dtype=tf.float32)
width = tf.cast(tf.shape(data["image"])[1], dtype=tf.float32)
bounding_box = data["bbox"] * tf.stack([height, width, height, width])
height = ops.cast(ops.shape(data["image"])[0], "float32")
width = ops.cast(ops.shape(data["image"])[1], "float32")
bounding_box = data["bbox"] * ops.stack([height, width, height, width])

# calculate center and length of longer side, add padding
target_center_y = 0.5 * (bounding_box[0] + bounding_box[2])
target_center_x = 0.5 * (bounding_box[1] + bounding_box[3])
target_size = tf.maximum(
target_size = ops.maximum(
(1.0 + padding) * (bounding_box[2] - bounding_box[0]),
(1.0 + padding) * (bounding_box[3] - bounding_box[1]),
)

# modify crop size to fit into image
target_height = tf.reduce_min(
target_height = ops.min(
[target_size, 2.0 * target_center_y, 2.0 * (height - target_center_y)]
)
target_width = tf.reduce_min(
target_width = ops.min(
[target_size, 2.0 * target_center_x, 2.0 * (width - target_center_x)]
)

# crop image
image = tf.image.crop_to_bounding_box(
# crop image, `ops.image.crop_images` only works with non-tensor croppings
image = ops.slice(
data["image"],
offset_height=round_to_int(target_center_y - 0.5 * target_height),
offset_width=round_to_int(target_center_x - 0.5 * target_width),
target_height=round_to_int(target_height),
target_width=round_to_int(target_width),
start_indices=(
round_to_int(target_center_y - 0.5 * target_height),
round_to_int(target_center_x - 0.5 * target_width),
0,
),
shape=(round_to_int(target_height), round_to_int(target_width), 3),
)

# resize and clip
# for image downsampling, area interpolation is the preferred method
image = tf.image.resize(
image, size=[image_size, image_size], method=tf.image.ResizeMethod.AREA
)
return tf.clip_by_value(image / 255.0, 0.0, 1.0)
image = ops.cast(image, "float32")
image = ops.image.resize(image, [image_size, image_size])

return ops.clip(image / 255.0, 0.0, 1.0)


def prepare_dataset(split):
Expand Down Expand Up @@ -231,8 +237,10 @@ def __init__(self, name="kid", **kwargs):
)

def polynomial_kernel(self, features_1, features_2):
feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype=tf.float32)
return (features_1 @ tf.transpose(features_2) / feature_dimensions + 1.0) ** 3.0
feature_dimensions = ops.cast(ops.shape(features_1)[1], "float32")
return (
features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
) ** 3.0

def update_state(self, real_images, generated_images, sample_weight=None):
real_features = self.encoder(real_images, training=False)
Expand All @@ -246,15 +254,15 @@ def update_state(self, real_images, generated_images, sample_weight=None):
kernel_cross = self.polynomial_kernel(real_features, generated_features)

# estimate the squared maximum mean discrepancy using the average kernel values
batch_size = tf.shape(real_features)[0]
batch_size_f = tf.cast(batch_size, dtype=tf.float32)
mean_kernel_real = tf.reduce_sum(kernel_real * (1.0 - tf.eye(batch_size))) / (
batch_size = ops.shape(real_features)[0]
batch_size_f = ops.cast(batch_size, "float32")
mean_kernel_real = ops.sum(kernel_real * (1.0 - ops.eye(batch_size))) / (
batch_size_f * (batch_size_f - 1.0)
)
mean_kernel_generated = tf.reduce_sum(
kernel_generated * (1.0 - tf.eye(batch_size))
mean_kernel_generated = ops.sum(
kernel_generated * (1.0 - ops.eye(batch_size))
) / (batch_size_f * (batch_size_f - 1.0))
mean_kernel_cross = tf.reduce_mean(kernel_cross)
mean_kernel_cross = ops.mean(kernel_cross)
kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

# update the average KID estimate
Expand Down Expand Up @@ -299,7 +307,7 @@ def reset_state(self):
# "hard sigmoid", useful for binary accuracy calculation from logits
def step(values):
# negative values -> 0.0, positive values -> 1.0
return 0.5 * (1.0 + tf.sign(values))
return 0.5 * (1.0 + ops.sign(values))


# augments images with a probability that is dynamically updated during training
Expand All @@ -308,7 +316,8 @@ def __init__(self):
super().__init__()

# stores the current probability of an image being augmented
self.probability = tf.Variable(0.0)
self.probability = keras.Variable(0.0)
self.seed_generator = keras.random.SeedGenerator(42)

# the corresponding augmentation names from the paper are shown above each layer
# the authors show (see figure 4), that the blitting and geometric augmentations
Expand Down Expand Up @@ -336,28 +345,26 @@ def __init__(self):

def call(self, images, training):
if training:
augmented_images = self.augmenter(images, training)
augmented_images = self.augmenter(images, training=training)

# during training either the original or the augmented images are selected
# based on self.probability
augmentation_values = tf.random.uniform(
shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
augmentation_values = keras.random.uniform(
shape=(batch_size, 1, 1, 1), seed=self.seed_generator
)
augmentation_bools = tf.math.less(augmentation_values, self.probability)
augmentation_bools = ops.less(augmentation_values, self.probability)

images = tf.where(augmentation_bools, augmented_images, images)
images = ops.where(augmentation_bools, augmented_images, images)
return images

def update(self, real_logits):
current_accuracy = tf.reduce_mean(step(real_logits))
current_accuracy = ops.mean(step(real_logits))

# the augmentation probability is updated based on the discriminator's
# accuracy on real images
accuracy_error = current_accuracy - target_accuracy
self.probability.assign(
tf.clip_by_value(
self.probability + accuracy_error / integration_steps, 0.0, 1.0
)
ops.clip(self.probability + accuracy_error / integration_steps, 0.0, 1.0)
)


Expand Down Expand Up @@ -445,13 +452,17 @@ class GAN_ADA(keras.Model):
def __init__(self):
super().__init__()

self.seed_generator = keras.random.SeedGenerator(seed=42)
self.augmenter = AdaptiveAugmenter()
self.generator = get_generator()
self.ema_generator = keras.models.clone_model(self.generator)
self.discriminator = get_discriminator()

self.generator.summary()
self.discriminator.summary()
# we have created all layers at this point, so we can mark the model
# as having been built
self.built = True

def compile(self, generator_optimizer, discriminator_optimizer, **kwargs):
super().compile(**kwargs)
Expand Down Expand Up @@ -479,32 +490,34 @@ def metrics(self):
]

def generate(self, batch_size, training):
latent_samples = tf.random.normal(shape=(batch_size, noise_size))
latent_samples = keras.random.normal(
shape=(batch_size, noise_size), seed=self.seed_generator
)
# use ema_generator during inference
if training:
generated_images = self.generator(latent_samples, training)
generated_images = self.generator(latent_samples, training=training)
else:
generated_images = self.ema_generator(latent_samples, training)
generated_images = self.ema_generator(latent_samples, training=training)
return generated_images

def adversarial_loss(self, real_logits, generated_logits):
# this is usually called the non-saturating GAN loss

real_labels = tf.ones(shape=(batch_size, 1))
generated_labels = tf.zeros(shape=(batch_size, 1))
real_labels = ops.ones(shape=(batch_size, 1))
generated_labels = ops.zeros(shape=(batch_size, 1))

# the generator tries to produce images that the discriminator considers as real
generator_loss = keras.losses.binary_crossentropy(
real_labels, generated_logits, from_logits=True
)
# the discriminator tries to determine if images are real or generated
discriminator_loss = keras.losses.binary_crossentropy(
tf.concat([real_labels, generated_labels], axis=0),
tf.concat([real_logits, generated_logits], axis=0),
ops.concatenate([real_labels, generated_labels], axis=0),
ops.concatenate([real_logits, generated_logits], axis=0),
from_logits=True,
)

return tf.reduce_mean(generator_loss), tf.reduce_mean(discriminator_loss)
return ops.mean(generator_loss), ops.mean(discriminator_loss)

def train_step(self, real_images):
real_images = self.augmenter(real_images, training=True)
Expand Down Expand Up @@ -604,8 +617,8 @@ def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5)
)

# save the best model based on the validation KID metric
checkpoint_path = "gan_model"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path = "gan_model.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
monitor="val_kid",
Expand Down
Loading