Skip to content

Commit

Permalink
Update rand_augment init
Browse files Browse the repository at this point in the history
  • Loading branch information
shashaka committed Jan 2, 2025
1 parent 08a64df commit 288f0f2
Showing 1 changed file with 34 additions and 64 deletions.
98 changes: 34 additions & 64 deletions keras/src/layers/preprocessing/image_preprocessing/rand_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class RandAugment(BaseImagePreprocessingLayer):
"""

_AUGMENT_LAYERS = ["random_shear", "random_translation", "random_rotation", "random_brightness",
"random_color_degeneration", "random_contrast", "random_sharpness",
"random_posterization", "solarization", "auto_contrast", "equalization"]
_AUGMENT_LAYERS = ["Identity", "random_shear", "random_translation", "random_rotation",
"random_brightness", "random_color_degeneration", "random_contrast",
"random_sharpness", "random_posterization", "solarization", "auto_contrast", "equalization"]

def __init__(
self,
Expand All @@ -54,91 +54,68 @@ def __init__(

augmentation_space = self._augmentation_space(self.num_magnitude_bins)

op_index = self.backend.random.uniform((2,),
minval=0, maxval=len(augmentation_space["ShearX"]),
seed=self.seed)
random_indices = self.backend.random.randint([11], 0, self.num_magnitude_bins, seed=self.seed)
self.random_shear = layers.RandomShear(
x_factor=float(augmentation_space["ShearX"][int(op_index[0])]),
y_factor=float(augmentation_space["ShearY"][int(op_index[1])]),
x_factor=float(augmentation_space["ShearX"][int(random_indices[0])]),
y_factor=float(augmentation_space["ShearY"][int(random_indices[1])]),
seed=self.seed,
data_format=data_format,
)

op_index = self.backend.random.uniform((2,),
minval=0, maxval=len(augmentation_space["TranslateX"]),
seed=self.seed)
self.random_translation = layers.RandomTranslation(
height_factor=float(augmentation_space["TranslateX"][int(op_index[0])]),
width_factor=float(augmentation_space["TranslateY"][int(op_index[1])]),
height_factor=float(augmentation_space["TranslateX"][int(random_indices[2])]),
width_factor=float(augmentation_space["TranslateY"][int(random_indices[3])]),
seed=self.seed,
data_format=data_format,
)

op_index = self.backend.random.uniform((1,),
minval=0, maxval=len(augmentation_space["Rotate"]),
seed=self.seed)
self.random_rotation = layers.RandomRotation(
factor=float(augmentation_space["Rotate"][int(op_index[0])]),
factor=float(augmentation_space["Rotate"][int(random_indices[4])]),
seed=self.seed,
data_format=data_format,
)

op_index = self.backend.random.uniform((1,),
minval=0, maxval=len(augmentation_space["Brightness"]),
seed=self.seed)
self.random_brightness = layers.RandomBrightness(
factor=float(augmentation_space["Brightness"][int(op_index[0])]),
factor=float(augmentation_space["Brightness"][int(random_indices[5])]),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

op_index = self.backend.random.uniform((1,),
minval=0, maxval=len(augmentation_space["Color"]),
seed=self.seed)
self.random_color_degeneration = layers.RandomColorDegeneration(
factor=float(augmentation_space["Color"][int(op_index[0])]),
factor=float(augmentation_space["Color"][int(random_indices[6])]),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

op_index = self.backend.random.uniform((1,),
minval=0, maxval=len(augmentation_space["Contrast"]),
seed=self.seed)
self.random_contrast = layers.RandomContrast(
factor=float(augmentation_space["Contrast"][int(op_index[0])]),
factor=float(augmentation_space["Contrast"][int(random_indices[7])]),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

op_index = self.backend.random.uniform((1,),
minval=0, maxval=len(augmentation_space["Sharpness"]),
seed=self.seed)
self.random_sharpness = layers.RandomSharpness(
factor=float(augmentation_space["Sharpness"][int(op_index[0])]),
factor=float(augmentation_space["Sharpness"][int(random_indices[8])]),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

op_index = self.backend.random.uniform((1,),
minval=0, maxval=len(augmentation_space["Posterize"]),
seed=self.seed)
self.random_posterization = layers.RandomPosterization(
factor=int(augmentation_space["Posterize"][int(op_index[0])]),
self.solarization = layers.Solarization(
addition_factor=int(augmentation_space["Solarize"][int(random_indices[9])]),
threshold_factor=int(augmentation_space["Solarize"][int(random_indices[10])]),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

op_index = self.backend.random.uniform((2,),
minval=0, maxval=len(augmentation_space["Solarize"]),
seed=self.seed)
self.solarization = layers.Solarization(
addition_factor=int(augmentation_space["Solarize"][int(op_index[0])]),
threshold_factor=int(augmentation_space["Solarize"][int(op_index[1])]),
random_indices = self.backend.random.uniform((1,),
minval=0, maxval=len(augmentation_space["Posterize"]),
seed=self.seed)
self.random_posterization = layers.RandomPosterization(
factor=int(augmentation_space["Posterize"][int(random_indices[0])]),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
Expand All @@ -156,9 +133,6 @@ def __init__(

def _augmentation_space(self, num_bins):
return {
"Identity": self.backend.convert_to_tensor(
0.0, dtype=self.compute_dtype
),
"ShearX": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"ShearY": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"TranslateX": self.backend.numpy.linspace(-1, 1, num_bins),
Expand All @@ -168,11 +142,8 @@ def _augmentation_space(self, num_bins):
"Color": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"Contrast": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"Sharpness": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"Posterize": 8. - (self.backend.numpy.arange(num_bins, dtype='float32') / ((num_bins - 1.) / 4)),
"Solarize": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"Equalize": self.backend.convert_to_tensor(
0.0, dtype=self.compute_dtype
),
"Posterize": 8. - (self.backend.numpy.arange(num_bins, dtype='float32') / ((num_bins - 1.) / 4)),
}

def get_random_transformation(self, data, training=True, seed=None):
Expand All @@ -181,22 +152,21 @@ def get_random_transformation(self, data, training=True, seed=None):

if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")
self.random_shear.backend.set_backend("tensorflow")
self.random_translation.backend.set_backend("tensorflow")
self.random_rotation.backend.set_backend("tensorflow")
self.random_brightness.backend.set_backend("tensorflow")
self.random_color_degeneration.backend.set_backend("tensorflow")
self.random_contrast.backend.set_backend("tensorflow")
self.random_sharpness.backend.set_backend("tensorflow")
self.random_posterization.backend.set_backend("tensorflow")
self.solarization.backend.set_backend("tensorflow")
self.auto_contrast.backend.set_backend("tensorflow")
self.equalization.backend.set_backend("tensorflow")

for layer_name in self._AUGMENT_LAYERS:
if layer_name == "Identity":
continue
augmentation_layer = getattr(self, layer_name)
augmentation_layer.backend.set_backend("tensorflow")

transformation = {}
random_index = self.backend.random.randint((1,), 0, len(self._AUGMENT_LAYERS), seed=self.seed)
for layer_idx in random_index:
random_indices = self.backend.random.shuffle(
self.backend.numpy.arange(len(self._AUGMENT_LAYERS)),
seed=self.seed)[:self.num_ops]
for layer_idx in random_indices:
layer_name = self._AUGMENT_LAYERS[layer_idx]
if layer_name == "Identity":
continue
augmentation_layer = getattr(self, layer_name)
transformation[layer_name] = augmentation_layer.get_random_transformation(data,
training=training,
Expand Down

0 comments on commit 288f0f2

Please sign in to comment.