diff --git a/openfl-workspace/tf_2dunet/src/taskrunner.py b/openfl-workspace/tf_2dunet/src/taskrunner.py index b9a1afda58..9277e1396a 100644 --- a/openfl-workspace/tf_2dunet/src/taskrunner.py +++ b/openfl-workspace/tf_2dunet/src/taskrunner.py @@ -26,30 +26,15 @@ def __init__(self, **kwargs): """ super().__init__(**kwargs) - self.create_model(**kwargs) + self.model = self.build_model(self.data_loader.get_feature_shape(), use_upsampling=True, **kwargs) + self.model.summary(print_fn=self.logger.info, line_length=120) self.initialize_tensorkeys_for_functions() - def create_model(self, - training_smoothing=32.0, - validation_smoothing=1.0, - **kwargs): - """Create the TensorFlow 2D U-Net model. - - Args: - training_smoothing (float): (Default=32.0) - validation_smoothing (float): (Default=1.0) - **kwargs: Additional parameters to pass to the function - """ - self.input_shape = self.data_loader.get_feature_shape() - self.model = self.define_model(self.input_shape, use_upsampling=True, **kwargs) - self.model.summary(print_fn=self.logger.info, line_length=120) - - def define_model(self, input_shape, + def build_model(self, input_shape, use_upsampling=False, n_cl_out=1, dropout=0.2, - print_summary=True, activation_function='relu', seed=0xFEEDFACE, depth=5, @@ -64,15 +49,12 @@ def define_model(self, input_shape, use_upsampling (bool): True = use bilinear interpolation; False = use transposed convolution (Default=False) n_cl_out (int): Number of channels in input layer (Default=1) - dropout (float): Dropout percentage (Default=0.2) - print_summary (bool): True = print the model summary (Default = True) - activation_function: The activation function to use after convolutional - layers (Default='relu') + dropout (float): Dropout percentage (Default=0.2)(Default = True) + activation_function: The activation function to use after convolutional layers (Default='relu') seed: random seed (Default=0xFEEDFACE) depth (int): Number of max pooling layers in encoder (Default=5) dropout_at: Layers to perform dropout after (Default=[2,3]) - initial_filters (int): Number of filters in first convolutional - layer (Default=32) + initial_filters (int): Number of filters in first convolutional layer (Default=32) batch_norm (bool): True = use batch normalization (Default=True) **kwargs: Additional parameters to pass to the function @@ -89,7 +71,6 @@ def define_model(self, input_shape, params = { 'activation': activation, - 'data_format': data_format, 'kernel_initializer': keras.initializers.he_uniform(seed=seed), 'kernel_size': (3, 3), 'padding': 'same', @@ -123,11 +104,11 @@ def define_model(self, input_shape, name=f'up{depth + i + 1}', size=(2, 2))(net) else: up = keras.layers.Conv2DTranspose( - name='transConv6', filters=filters, data_format=data_format, + name='transConv6', filters=filters, kernel_size=(2, 2), strides=(2, 2), padding='same')(net) net = keras.layers.concatenate( [up, convb_layers[f'conv{depth - i - 1}b']], - axis=concat_axis + axis=-1 ) net = keras.layers.Conv2D( name=f'conv{depth + i + 1}a', @@ -137,7 +118,7 @@ def define_model(self, input_shape, filters=filters, **params)(net) filters //= 2 net = keras.layers.Conv2D(name='Mask', filters=n_cl_out, - kernel_size=(1, 1), data_format=data_format, + kernel_size=(1, 1), activation='sigmoid')(net) model = keras.models.Model(inputs=[inputs], outputs=[net]) @@ -146,12 +127,12 @@ def define_model(self, input_shape, model.compile( loss=self.dice_coef_loss, optimizer=self.optimizer, - metrics=["acc"], + metrics=["acc"] ) return model - def dice_coef_loss(self, y_true, y_pred): + def dice_coef_loss(self, y_true, y_pred, smooth=1.0): """Dice coefficient loss. Calculate the -log(Dice Coefficient) loss @@ -164,7 +145,7 @@ def dice_coef_loss(self, y_true, y_pred): float: -log(Dice cofficient) metric """ intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3)) - smooth=1.0 + term1 = -tf.math.log(tf.constant(2.0) * intersection + smooth) term2 = tf.math.log(tf.reduce_sum(y_true, axis=(1, 2, 3)) + tf.reduce_sum(y_pred, axis=(1, 2, 3)) + smooth) @@ -175,13 +156,3 @@ def dice_coef_loss(self, y_true, y_pred): loss = term1 + term2 return loss - - -CHANNEL_LAST = True -if CHANNEL_LAST: - concat_axis = -1 - data_format = 'channels_last' -else: - concat_axis = 1 - data_format = 'channels_first' -