diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py index eee6f31b8e4..e2f1962579d 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py @@ -112,6 +112,19 @@ def __init__( seed=self.seed, ) + def build(self, input_shape): + if self.brightness_factor is not None: + self.random_brightness.build(input_shape) + + if self.contrast_factor is not None: + self.random_contrast.build(input_shape) + + if self.saturation_factor is not None: + self.random_saturation.build(input_shape) + + if self.hue_factor is not None: + self.random_hue.build(input_shape) + def transform_images(self, images, transformation, training=True): if training: if backend_utils.in_tf_graph():