Updated Semi-supervision and domain adaptation with AdaMatch example for Keras v3 #1785

Closed · wants to merge 1 commit into from
107 changes: 64 additions & 43 deletions examples/vision/adamatch.py
@@ -2,7 +2,7 @@
 Title: Semi-supervision and domain adaptation with AdaMatch
 Author: [Sayak Paul](https://twitter.com/RisingSayak)
 Date created: 2021/06/19
-Last modified: 2021/06/19
+Last modified: 2024/03/10
 Description: Unifying semi-supervised learning and unsupervised domain adaptation with AdaMatch.
 Accelerator: GPU
 """
@@ -18,7 +18,7 @@
 (UDA) under one framework. It thereby provides a way to perform semi-supervised domain
 adaptation (SSDA).

-This example requires TensorFlow 2.5 or higher, as well as TensorFlow Models, which can
+This example requires TensorFlow 2.15 or higher, as well as TensorFlow Models, which can
 be installed using the following command:
 """
@@ -62,15 +62,17 @@
 ## Setup
 """

-import tensorflow as tf
+import os

-tf.random.set_seed(42)
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import tensorflow as tf

 import numpy as np

-from tensorflow import keras
-from tensorflow.keras import layers
-from tensorflow.keras import regularizers
+import keras
+from keras import layers
+from keras import regularizers
 from keras_cv.layers import RandAugment

 import tensorflow_datasets as tfds

Review comment (Contributor), on the removed `tf.random.set_seed(42)` line: You might want to do `keras.utils.set_random_seed(42)` to replace this line.
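A minimal sketch of the reviewer's suggestion (illustrative, not part of this commit): `keras.utils.set_random_seed` seeds Python's `random` module, NumPy, and the backend RNG in one call, making it a closer Keras 3 idiom than the removed `tf.random.set_seed`.

```python
# Sketch of the suggested replacement (assumes the TensorFlow backend set above).
import keras

keras.utils.set_random_seed(42)  # seeds random, numpy.random, and tf.random
```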
@@ -88,11 +90,11 @@
 ) = keras.datasets.mnist.load_data()

 # Add a channel dimension
-mnist_x_train = tf.expand_dims(mnist_x_train, -1)
-mnist_x_test = tf.expand_dims(mnist_x_test, -1)
+mnist_x_train = keras.ops.expand_dims(mnist_x_train, -1)
+mnist_x_test = keras.ops.expand_dims(mnist_x_test, -1)

 # Convert the labels to one-hot encoded vectors
-mnist_y_train = tf.one_hot(mnist_y_train, 10).numpy()
+mnist_y_train = keras.ops.one_hot(mnist_y_train, 10).numpy()

 # SVHN
 svhn_train, svhn_test = tfds.load(
@@ -134,26 +136,26 @@


 def weak_augment(image, source=True):
-    if image.dtype != tf.float32:
-        image = tf.cast(image, tf.float32)
+    if image.dtype != "float32":
+        image = keras.ops.cast(image, dtype="float32")

     # MNIST images are grayscale, this is why we first convert them to
     # RGB images.
     if source:
         image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
-        image = tf.tile(image, [1, 1, 3])
+        image = keras.ops.tile(image, [1, 1, 3])
     image = tf.image.random_flip_left_right(image)
     image = tf.image.random_crop(image, (RESIZE_TO, RESIZE_TO, 3))
     return image


 def strong_augment(image, source=True):
-    if image.dtype != tf.float32:
-        image = tf.cast(image, tf.float32)
+    if image.dtype != "float32":
+        image = keras.ops.cast(image, dtype="float32")

     if source:
         image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
-        image = tf.tile(image, [1, 1, 3])
+        image = keras.ops.tile(image, [1, 1, 3])
         image = augmenter(image)
     return image
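For context, a hedged sketch of how these two augmenters are typically consumed: the full example maps them over `tf.data` pipelines so each dataset yields augmented views. The helper below is illustrative only; `AUTO` and `SOURCE_BATCH_SIZE` mirror constants defined elsewhere in the script, and the actual helper in the example may differ.

```python
# Illustrative pipeline sketch (assumption, not part of this diff).
import tensorflow as tf

AUTO = tf.data.AUTOTUNE
SOURCE_BATCH_SIZE = 64

def create_augmented_ds(ds, aug_func, source=True):
    # Apply one augmentation policy to the images; labels pass through as-is.
    ds = ds.shuffle(SOURCE_BATCH_SIZE * 10, seed=42)
    ds = ds.map(lambda x, y: (aug_func(x, source), y), num_parallel_calls=AUTO)
    return ds.batch(SOURCE_BATCH_SIZE).prefetch(AUTO)
```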

@@ -217,16 +219,16 @@ def compute_loss_source(source_labels, logits_source_w, logits_source_s):

 def compute_loss_target(target_pseudo_labels_w, logits_target_s, mask):
     loss_func = keras.losses.CategoricalCrossentropy(from_logits=True, reduction="none")
-    target_pseudo_labels_w = tf.stop_gradient(target_pseudo_labels_w)
+    target_pseudo_labels_w = keras.ops.stop_gradient(target_pseudo_labels_w)
     # For calculating loss for the target samples, we treat the pseudo labels
     # as the ground-truth. These are not considered during backpropagation
     # which is a standard SSL practice.
     target_loss = loss_func(target_pseudo_labels_w, logits_target_s)

     # More on `mask` later.
-    mask = tf.cast(mask, target_loss.dtype)
+    mask = keras.ops.cast(mask, target_loss.dtype)
     target_loss *= mask
-    return tf.reduce_mean(target_loss, 0)
+    return keras.ops.mean(target_loss, 0)


 """
@@ -262,9 +264,16 @@ def __init__(self, model, total_steps, tau=0.9):
         super().__init__()
         self.model = model
         self.tau = tau  # Denotes the confidence threshold
-        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
+        self.loss_tracker = keras.metrics.Mean(name="loss")
         self.total_steps = total_steps
-        self.current_step = tf.Variable(0, dtype="int64")
+        self.current_step = self.add_weight(
+            shape=[
+                1,
+            ],
+            dtype="int64",
+            initializer="zeros",
+        )
+        self.seed_generator = keras.random.SeedGenerator(1337)

     @property
     def metrics(self):
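Two Keras 3 idioms appear in this hunk: persistent training state is registered with `add_weight` (yielding a backend-agnostic `keras.Variable` instead of a `tf.Variable`), and random ops draw from an explicit `keras.random.SeedGenerator`. A minimal, self-contained sketch of the state-variable pattern (the `Counter` class is hypothetical):

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras

class Counter(keras.Model):
    def __init__(self):
        super().__init__()
        # Same pattern as the PR: non-trainable int64 state tracked by the model.
        self.current_step = self.add_weight(
            shape=(1,), dtype="int64", initializer="zeros", trainable=False
        )

counter = Counter()
counter.current_step.assign_add(1)  # keras.Variable supports assign/assign_add
print(keras.ops.convert_to_numpy(counter.current_step))  # [1]
```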
@@ -274,9 +283,13 @@ def metrics(self):
     # loss contributed by the target unlabeled samples. More
     # on this in the text.
     def compute_mu(self):
-        pi = tf.constant(np.pi, dtype="float32")
-        step = tf.cast(self.current_step, dtype="float32")
-        return 0.5 - tf.cos(tf.math.minimum(pi, (2 * pi * step) / self.total_steps)) / 2
+        pi = keras.ops.array(np.pi, dtype="float32")
+        step = keras.ops.cast(self.current_step, dtype="float32")
+        return (
+            0.5
+            - keras.ops.cos(keras.ops.minimum(pi, (2 * pi * step) / self.total_steps))
+            / 2
+        )

     def train_step(self, data):
         ## Unpack and organize the data ##
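The schedule itself is unchanged by the port: mu(t) = 0.5 - cos(min(pi, 2*pi*t/T)) / 2 ramps the target-loss weight from 0 at step 0 up to 1 at t = T/2 and holds it there, easing in the pseudo-label loss. A quick NumPy check of those endpoints (illustrative only):

```python
# Verify the warm-up endpoints of the mu schedule.
import numpy as np

def mu(step, total_steps):
    return 0.5 - np.cos(np.minimum(np.pi, (2 * np.pi * step) / total_steps)) / 2

for step in [0, 250, 500, 1000]:
    print(step, round(float(mu(step, 1000)), 3))  # 0.0, 0.5, 1.0, 1.0
```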
@@ -287,11 +300,15 @@ def train_step(self, data):
             (target_s, _),
         ) = target_ds  # Notice that we are NOT using any labels here.

-        combined_images = tf.concat([source_w, source_s, target_w, target_s], 0)
-        combined_source = tf.concat([source_w, source_s], 0)
+        combined_images = keras.ops.concatenate(
+            [source_w, source_s, target_w, target_s], 0
+        )
+        combined_source = keras.ops.concatenate([source_w, source_s], 0)

-        total_source = tf.shape(combined_source)[0]
-        total_target = tf.shape(tf.concat([target_w, target_s], 0))[0]
+        total_source = keras.ops.shape(combined_source)[0]
+        total_target = keras.ops.shape(keras.ops.concatenate([target_w, target_s], 0))[
+            0
+        ]

         with tf.GradientTape() as tape:
             ## Forward passes ##
@@ -302,42 +319,46 @@
             z_prime_source = combined_logits[:total_source]

             ## 1. Random logit interpolation for the source images ##
-            lambd = tf.random.uniform((total_source, 10), 0, 1)
+            lambd = keras.random.uniform(
+                (total_source, 10), 0, 1, seed=self.seed_generator
+            )
             final_source_logits = (lambd * z_prime_source) + (
                 (1 - lambd) * z_d_prime_source
             )

             ## 2. Distribution alignment (only consider weakly augmented images) ##
             # Compute softmax for logits of the WEAKLY augmented SOURCE images.
-            y_hat_source_w = tf.nn.softmax(final_source_logits[: tf.shape(source_w)[0]])
+            y_hat_source_w = keras.ops.softmax(
+                final_source_logits[: keras.ops.shape(source_w)[0]]
+            )

             # Extract logits for the WEAKLY augmented TARGET images and compute softmax.
             logits_target = combined_logits[total_source:]
-            logits_target_w = logits_target[: tf.shape(target_w)[0]]
-            y_hat_target_w = tf.nn.softmax(logits_target_w)
+            logits_target_w = logits_target[: keras.ops.shape(target_w)[0]]
+            y_hat_target_w = keras.ops.softmax(logits_target_w)

             # Align the target label distribution to that of the source.
-            expectation_ratio = tf.reduce_mean(y_hat_source_w) / tf.reduce_mean(
+            expectation_ratio = keras.ops.mean(y_hat_source_w) / keras.ops.mean(
                 y_hat_target_w
             )
-            y_tilde_target_w = tf.math.l2_normalize(
-                y_hat_target_w * expectation_ratio, 1
+            y_tilde_target_w = keras.ops.normalize(
+                y_hat_target_w * expectation_ratio, axis=1, order=2
             )

             ## 3. Relative confidence thresholding ##
-            row_wise_max = tf.reduce_max(y_hat_source_w, axis=-1)
-            final_sum = tf.reduce_mean(row_wise_max, 0)
+            row_wise_max = keras.ops.amax(y_hat_source_w, axis=-1)
+            final_sum = keras.ops.mean(row_wise_max, 0)
             c_tau = self.tau * final_sum
-            mask = tf.reduce_max(y_tilde_target_w, axis=-1) >= c_tau
+            mask = keras.ops.amax(y_tilde_target_w, axis=-1) >= c_tau

             ## Compute losses (pay attention to the indexing) ##
             source_loss = compute_loss_source(
                 source_labels,
-                final_source_logits[: tf.shape(source_w)[0]],
-                final_source_logits[tf.shape(source_w)[0] :],
+                final_source_logits[: keras.ops.shape(source_w)[0]],
+                final_source_logits[keras.ops.shape(source_w)[0] :],
             )
             target_loss = compute_loss_target(
-                y_tilde_target_w, logits_target[tf.shape(target_w)[0] :], mask
+                y_tilde_target_w, logits_target[keras.ops.shape(target_w)[0] :], mask
             )

             t = self.compute_mu()  # Compute weight for the target loss
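One behavioral note on this hunk: `tf.math.l2_normalize(x, 1)` becomes `keras.ops.normalize(x, axis=1, order=2)`, which likewise rescales each row to unit L2 norm. A quick equivalence check (illustrative; assumes the TensorFlow backend):

```python
import numpy as np
import tensorflow as tf
import keras

x = np.random.rand(4, 10).astype("float32")

a = tf.math.l2_normalize(x, 1).numpy()
b = keras.ops.convert_to_numpy(keras.ops.normalize(x, axis=1, order=2))

# Differences should sit within float tolerance.
np.testing.assert_allclose(a, b, atol=1e-6)
```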
@@ -556,7 +577,7 @@ def get_network(image_size=32, num_classes=10):

 # Compile the AdaMatch model to yield accuracy.
 adamatch_trained_model = adamatch_trainer.model
-adamatch_trained_model.compile(metrics=keras.metrics.SparseCategoricalAccuracy())
+adamatch_trained_model.compile(metrics=[keras.metrics.SparseCategoricalAccuracy()])

 # Score on the target test set.
 svhn_test = svhn_test.batch(TARGET_BATCH_SIZE).prefetch(AUTO)
@@ -575,7 +596,7 @@ def get_network(image_size=32, num_classes=10):

 # Utility function for preprocessing the source test set.
 def prepare_test_ds_source(image, label):
     image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
-    image = tf.tile(image, [1, 1, 3])
+    image = keras.ops.tile(image, [1, 1, 3])
     return image, label
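For completeness, a hedged sketch of how this utility is consumed downstream (it mirrors the target-test evaluation above; `mnist_x_test`, `mnist_y_test`, `AUTO`, and `TARGET_BATCH_SIZE` come from earlier in the script):

```python
# Illustrative source-test evaluation (assumption: follows the example's flow).
import tensorflow as tf

source_test_ds = tf.data.Dataset.from_tensor_slices((mnist_x_test, mnist_y_test))
source_test_ds = (
    source_test_ds.map(prepare_test_ds_source, num_parallel_calls=AUTO)
    .batch(TARGET_BATCH_SIZE)
    .prefetch(AUTO)
)

_, accuracy = adamatch_trained_model.evaluate(source_test_ds)
print(f"Accuracy on source test set: {accuracy * 100:.2f}%")
```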

