Skip to content

Commit 08ca8d6

Browse files
authored
Add VectorizedBaseImageAugmentation layer (#1373)
* implement vectorized base image augmentation layer * Implement vectorized RandomContrast layer * KPL performance * Random contrast vectorized * Vectorized contrast * Fix vectorized base layer * Fix vectorized base layer * Add vectorized grayscale layer * Remove random contrast * Remove random contrast * test_preserves_ragged_status_Grayscale * test_preserves_ragged_status_Grayscale * Fix * Fix masks * Fix masks * rename to 'batched' * Fix docstrings * Fix docstrings * Remove ragged method * Begin ragged image support * Begin ragged image support * Begin ragged image support * Begin ragged image support * Begin ragged image support * Performance benchmark * Reformat * Vectorized grayscale * Fix ragged test case * Fix ragged test case * Fix ragged test case * Fix ragged test case * Fix ragged test case
1 parent f289d90 commit 08ca8d6

6 files changed

+903
-40
lines changed

benchmarks/vectorized_grayscale.py

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import time
15+
16+
import matplotlib.pyplot as plt
17+
import tensorflow as tf
18+
import tensorflow.keras as keras
19+
20+
from keras_cv.layers import Grayscale
21+
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
22+
BaseImageAugmentationLayer,
23+
)
24+
25+
26+
class OldGrayscale(BaseImageAugmentationLayer):
    """Grayscale layer built on `BaseImageAugmentationLayer`.

    This copy implements the per-image `augment_image` hook and is benchmarked
    in this file against `keras_cv.layers.Grayscale`.

    Input images should have values in the range of [0, 255].

    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format
    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format

    Args:
        output_channels: Number of color channels present in the output image.
            Must be 1 or 3. An RGB image with shape (..., height, width, 3)
            will have the following shapes after the `Grayscale` operation:
                a. (..., height, width, 1) if output_channels = 1
                b. (..., height, width, 3) if output_channels = 3.

    Raises:
        ValueError: if `output_channels` is neither 1 nor 3.

    Usage:
    ```python
    (images, labels), _ = tf.keras.datasets.cifar10.load_data()
    to_grayscale = keras_cv.layers.preprocessing.Grayscale()
    augmented_images = to_grayscale(images)
    ```
    """

    def __init__(self, output_channels=1, **kwargs):
        super().__init__(**kwargs)
        # Validate at construction time so a bad value fails here instead of
        # later inside `augment_image`; this also stores `self.output_channels`.
        self._check_input_params(output_channels)
        # This layer may raise an error when running on GPU using auto_vectorize
        self.auto_vectorize = False

    def compute_image_signature(self, images):
        """Return the spec of a single output image.

        Required because of the `output_channels` argument: the last dimension
        of the output is `self.output_channels`, while dimensions 1 and 2 of
        `images` are carried over unchanged.
        """
        output_shape = images.shape[1:3] + [self.output_channels]
        if isinstance(images, tf.RaggedTensor):
            return tf.RaggedTensorSpec(
                shape=output_shape,
                ragged_rank=1,
                dtype=self.compute_dtype,
            )
        return tf.TensorSpec(output_shape, self.compute_dtype)

    def _check_input_params(self, output_channels):
        """Store `output_channels` after checking that it is 1 or 3."""
        if output_channels not in [1, 3]:
            raise ValueError(
                "Received invalid argument output_channels. "
                f"output_channels must be in 1 or 3. Got {output_channels}"
            )
        self.output_channels = output_channels

    def augment_image(self, image, transformation=None, **kwargs):
        """Convert `image` to grayscale with `self.output_channels` channels."""
        grayscale = tf.image.rgb_to_grayscale(image)
        if self.output_channels == 1:
            return grayscale
        elif self.output_channels == 3:
            return tf.image.grayscale_to_rgb(grayscale)
        else:
            # Only reachable if `output_channels` was reassigned after init.
            raise ValueError("Unsupported value for `output_channels`.")

    def augment_bounding_boxes(self, bounding_boxes, **kwargs):
        """Return `bounding_boxes` unchanged."""
        return bounding_boxes

    def augment_label(self, label, transformation=None, **kwargs):
        """Return `label` unchanged."""
        return label

    def augment_segmentation_mask(self, segmentation_mask, transformation, **kwargs):
        """Return `segmentation_mask` unchanged."""
        return segmentation_mask

    def get_config(self):
        """Return the base layer config extended with `output_channels`."""
        config = {
            "output_channels": self.output_channels,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
102+
103+
104+
def _time_calls(fn, data, sizes, label):
    """Time one call of `fn` on the first `n` items of `data` for each `n`.

    Each size is run once untimed as a warmup before the timed call.
    Returns the list of elapsed seconds, in the same order as `sizes`.
    """
    runtimes = []
    print(f"Timing {label}")
    for n_images in sizes:
        batch = data[:n_images]

        # warmup
        fn(batch)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        fn(batch)
        t1 = time.perf_counter()

        runtimes.append(t1 - t0)
        print(f"Runtime for {label}, n_images={n_images}: {t1-t0}")
    return runtimes


def _plot_results(sizes, timings):
    """Plot one runtime curve per entry of `timings` against `sizes`."""
    plt.figure()
    for key in timings:
        plt.plot(sizes, timings[key], label=key)
    plt.xlabel("Number images")
    plt.ylabel("Runtime (seconds)")
    plt.legend()
    plt.show()


(x_train, _), _ = keras.datasets.cifar10.load_data()
x_train = x_train.astype(float)

num_images = [1000, 2000, 5000, 10000]

results = {}

for aug in [Grayscale, OldGrayscale]:
    # Eager mode: call the layer directly.
    c = aug.__name__
    layer = aug()
    results[c] = _time_calls(layer, x_train, num_images, c)

    # Graph mode: a fresh layer instance wrapped in a tf.function.
    c = aug.__name__ + " Graph Mode"
    layer = aug()

    @tf.function()
    def apply_aug(inputs):
        return layer(inputs)

    results[c] = _time_calls(apply_aug, x_train, num_images, c)

_plot_results(num_images, results)

# So we can actually see more relevant margins
del results["OldGrayscale"]

_plot_results(num_images, results)

keras_cv/layers/preprocessing/grayscale.py

+21-24
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414

1515
import tensorflow as tf
1616

17-
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
18-
BaseImageAugmentationLayer,
17+
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import (
18+
VectorizedBaseImageAugmentationLayer,
1919
)
2020

2121

2222
@tf.keras.utils.register_keras_serializable(package="keras_cv")
23-
class Grayscale(BaseImageAugmentationLayer):
23+
class Grayscale(VectorizedBaseImageAugmentationLayer):
2424
"""Grayscale is a preprocessing layer that transforms RGB images to Grayscale images.
2525
Input images should have values in the range of [0, 255].
2626
@@ -50,21 +50,7 @@ class Grayscale(BaseImageAugmentationLayer):
5050
def __init__(self, output_channels=1, **kwargs):
5151
super().__init__(**kwargs)
5252
self.output_channels = output_channels
53-
# This layer may raise an error when running on GPU using auto_vectorize
54-
self.auto_vectorize = False
55-
56-
def compute_image_signature(self, images):
57-
# required because of the `output_channels` argument
58-
if isinstance(images, tf.RaggedTensor):
59-
ragged_spec = tf.RaggedTensorSpec(
60-
shape=images.shape[1:3] + [self.output_channels],
61-
ragged_rank=1,
62-
dtype=self.compute_dtype,
63-
)
64-
return ragged_spec
65-
return tf.TensorSpec(
66-
images.shape[1:3] + [self.output_channels], self.compute_dtype
67-
)
53+
self._check_input_params(output_channels)
6854

6955
def _check_input_params(self, output_channels):
7056
if output_channels not in [1, 3]:
@@ -74,8 +60,19 @@ def _check_input_params(self, output_channels):
7460
)
7561
self.output_channels = output_channels
7662

77-
def augment_image(self, image, transformation=None, **kwargs):
78-
grayscale = tf.image.rgb_to_grayscale(image)
63+
def compute_ragged_image_signature(self, images):
    """Build the `tf.RaggedTensorSpec` describing one ragged output image.

    Dimensions 1 and 2 of `images` are kept and the trailing dimension is
    set to `self.output_channels`.
    """
    output_shape = images.shape[1:3] + (self.output_channels,)
    return tf.RaggedTensorSpec(
        shape=output_shape,
        ragged_rank=1,
        dtype=self.compute_dtype,
    )
70+
71+
def augment_ragged_image(self, image, transformation, **kwargs):
    """Delegate to `augment_images`, passing `transformation` as `transformations`."""
    return self.augment_images(
        image,
        transformations=transformation,
        **kwargs,
    )
73+
74+
def augment_images(self, images, transformations=None, **kwargs):
75+
grayscale = tf.image.rgb_to_grayscale(images)
7976
if self.output_channels == 1:
8077
return grayscale
8178
elif self.output_channels == 3:
@@ -86,11 +83,11 @@ def augment_image(self, image, transformation=None, **kwargs):
8683
def augment_bounding_boxes(self, bounding_boxes, **kwargs):
    """Pass `bounding_boxes` through untouched."""
    unchanged_boxes = bounding_boxes
    return unchanged_boxes
8885

89-
def augment_label(self, label, transformation=None, **kwargs):
90-
return label
86+
def augment_labels(self, labels, transformations=None, **kwargs):
    """Pass `labels` through untouched; `transformations` is ignored."""
    unchanged_labels = labels
    return unchanged_labels
9188

92-
def augment_segmentation_mask(self, segmentation_mask, transformation, **kwargs):
93-
return segmentation_mask
89+
def augment_segmentation_masks(self, segmentation_masks, transformations, **kwargs):
    """Pass `segmentation_masks` through untouched; `transformations` is ignored."""
    unchanged_masks = segmentation_masks
    return unchanged_masks
9491

9592
def get_config(self):
9693
config = {

keras_cv/layers/preprocessing/grayscale_test.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
class GrayscaleTest(tf.test.TestCase):
2020
def test_return_shapes(self):
21-
xs = tf.ones((2, 512, 512, 3))
21+
xs = tf.ones((2, 52, 24, 3))
2222

2323
layer = preprocessing.Grayscale(
2424
output_channels=1,
@@ -30,12 +30,12 @@ def test_return_shapes(self):
3030
)
3131
xs2 = layer(xs, training=True)
3232

33-
self.assertEqual(xs1.shape, [2, 512, 512, 1])
34-
self.assertEqual(xs2.shape, [2, 512, 512, 3])
33+
self.assertEqual(xs1.shape, [2, 52, 24, 1])
34+
self.assertEqual(xs2.shape, [2, 52, 24, 3])
3535

3636
def test_in_tf_function(self):
3737
xs = tf.cast(
38-
tf.stack([2 * tf.ones((100, 100, 3)), tf.ones((100, 100, 3))], axis=0),
38+
tf.stack([2 * tf.ones((10, 10, 3)), tf.ones((10, 10, 3))], axis=0),
3939
tf.float32,
4040
)
4141

@@ -61,12 +61,12 @@ def augment(x):
6161

6262
xs2 = augment(xs)
6363

64-
self.assertEqual(xs1.shape, [2, 100, 100, 1])
65-
self.assertEqual(xs2.shape, [2, 100, 100, 3])
64+
self.assertEqual(xs1.shape, [2, 10, 10, 1])
65+
self.assertEqual(xs2.shape, [2, 10, 10, 3])
6666

6767
def test_non_square_image(self):
6868
xs = tf.cast(
69-
tf.stack([2 * tf.ones((512, 1024, 3)), tf.ones((512, 1024, 3))], axis=0),
69+
tf.stack([2 * tf.ones((52, 24, 3)), tf.ones((52, 24, 3))], axis=0),
7070
tf.float32,
7171
)
7272

@@ -80,12 +80,12 @@ def test_non_square_image(self):
8080
)
8181
xs2 = layer(xs, training=True)
8282

83-
self.assertEqual(xs1.shape, [2, 512, 1024, 1])
84-
self.assertEqual(xs2.shape, [2, 512, 1024, 3])
83+
self.assertEqual(xs1.shape, [2, 52, 24, 1])
84+
self.assertEqual(xs2.shape, [2, 52, 24, 3])
8585

8686
def test_in_single_image(self):
8787
xs = tf.cast(
88-
tf.ones((512, 512, 3)),
88+
tf.ones((52, 24, 3)),
8989
dtype=tf.float32,
9090
)
9191

@@ -99,5 +99,5 @@ def test_in_single_image(self):
9999
)
100100
xs2 = layer(xs, training=True)
101101

102-
self.assertEqual(xs1.shape, [512, 512, 1])
103-
self.assertEqual(xs2.shape, [512, 512, 3])
102+
self.assertEqual(xs1.shape, [52, 24, 1])
103+
self.assertEqual(xs2.shape, [52, 24, 3])

keras_cv/layers/preprocessing/ragged_image_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ def test_preserves_ragged_status(self, layer_cls, init_args):
126126
layer = layer_cls(**init_args)
127127
inputs = tf.ragged.stack(
128128
[
129-
tf.ones((512, 512, 3)),
130-
tf.ones((600, 300, 3)),
129+
tf.ones((5, 5, 3)),
130+
tf.ones((8, 8, 3)),
131131
]
132132
)
133133
outputs = layer(inputs)
@@ -138,8 +138,8 @@ def test_converts_ragged_to_dense(self, layer_cls, init_args):
138138
layer = layer_cls(**init_args)
139139
inputs = tf.ragged.stack(
140140
[
141-
tf.ones((512, 512, 3)),
142-
tf.ones((600, 300, 3)),
141+
tf.ones((5, 5, 3)),
142+
tf.ones((8, 8, 3)),
143143
]
144144
)
145145
outputs = layer(inputs)

0 commit comments

Comments
 (0)