code changes
Signed-off-by: yes <[email protected]>
tanwarsh committed Dec 12, 2024
1 parent 4d1d98a commit f2839c2
Showing 1 changed file with 12 additions and 41 deletions.
53 changes: 12 additions & 41 deletions openfl-workspace/tf_2dunet/src/taskrunner.py
@@ -26,30 +26,15 @@ def __init__(self, **kwargs):
"""
super().__init__(**kwargs)

self.create_model(**kwargs)
self.model = self.build_model(self.data_loader.get_feature_shape(), use_upsampling=True, **kwargs)
self.model.summary(print_fn=self.logger.info, line_length=120)
self.initialize_tensorkeys_for_functions()

def create_model(self,
training_smoothing=32.0,
validation_smoothing=1.0,
**kwargs):
"""Create the TensorFlow 2D U-Net model.
Args:
training_smoothing (float): (Default=32.0)
validation_smoothing (float): (Default=1.0)
**kwargs: Additional parameters to pass to the function

"""
self.input_shape = self.data_loader.get_feature_shape()
self.model = self.define_model(self.input_shape, use_upsampling=True, **kwargs)
self.model.summary(print_fn=self.logger.info, line_length=120)

def define_model(self, input_shape,
def build_model(self, input_shape,
use_upsampling=False,
n_cl_out=1,
dropout=0.2,
print_summary=True,
activation_function='relu',
seed=0xFEEDFACE,
depth=5,
@@ -64,15 +49,12 @@ def define_model(self, input_shape,
use_upsampling (bool): True = use bilinear interpolation;
False = use transposed convolution (Default=False)
n_cl_out (int): Number of channels in input layer (Default=1)
dropout (float): Dropout percentage (Default=0.2)
print_summary (bool): True = print the model summary (Default = True)
activation_function: The activation function to use after convolutional
layers (Default='relu')
dropout (float): Dropout percentage (Default=0.2)
activation_function: The activation function to use after convolutional layers (Default='relu')
seed: random seed (Default=0xFEEDFACE)
depth (int): Number of max pooling layers in encoder (Default=5)
dropout_at: Layers to perform dropout after (Default=[2,3])
initial_filters (int): Number of filters in first convolutional
layer (Default=32)
initial_filters (int): Number of filters in first convolutional layer (Default=32)
batch_norm (bool): True = use batch normalization (Default=True)
**kwargs: Additional parameters to pass to the function
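For reference, a hedged sketch of a call that exercises the documented parameters above; the runner instance and the input shape are illustrative assumptions, not values taken from this workspace.

# Illustrative only: `runner` stands in for an initialized task runner from
# this file, and the input shape is a placeholder rather than the dataset's.
model = runner.build_model(
    input_shape=(128, 128, 1),
    use_upsampling=True,       # bilinear upsampling instead of Conv2DTranspose
    n_cl_out=1,                # single-channel sigmoid mask
    dropout=0.2,
    activation_function='relu',
    seed=0xFEEDFACE,
    depth=5,
    dropout_at=[2, 3],
    initial_filters=32,
    batch_norm=True,
)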
@@ -89,7 +71,6 @@ def define_model(self, input_shape,

params = {
'activation': activation,
'data_format': data_format,
'kernel_initializer': keras.initializers.he_uniform(seed=seed),
'kernel_size': (3, 3),
'padding': 'same',
@@ -123,11 +104,11 @@ def define_model(self, input_shape,
name=f'up{depth + i + 1}', size=(2, 2))(net)
else:
up = keras.layers.Conv2DTranspose(
name='transConv6', filters=filters, data_format=data_format,
name='transConv6', filters=filters,
kernel_size=(2, 2), strides=(2, 2), padding='same')(net)
net = keras.layers.concatenate(
[up, convb_layers[f'conv{depth - i - 1}b']],
axis=concat_axis
axis=-1
)
net = keras.layers.Conv2D(
name=f'conv{depth + i + 1}a',
@@ -137,7 +118,7 @@ def define_model(self, input_shape,
filters=filters, **params)(net)
filters //= 2
net = keras.layers.Conv2D(name='Mask', filters=n_cl_out,
kernel_size=(1, 1), data_format=data_format,
kernel_size=(1, 1),
activation='sigmoid')(net)
model = keras.models.Model(inputs=[inputs], outputs=[net])
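An aside on the dropped data_format arguments: without them the layers fall back to Keras' default channels_last layout, which is exactly what the hard-coded axis=-1 concatenation assumes. A minimal sketch under that assumption (shapes are illustrative):

import tensorflow as tf

# Two channels_last feature maps: (batch, height, width, channels).
up = tf.zeros((1, 64, 64, 32))
skip = tf.zeros((1, 64, 64, 32))

# axis=-1 is the channel axis in channels_last layout, so the skip
# connection stacks feature channels rather than spatial dimensions.
merged = tf.keras.layers.concatenate([up, skip], axis=-1)
print(merged.shape)  # (1, 64, 64, 64)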

@@ -146,12 +127,12 @@ def define_model(self, input_shape,
model.compile(
loss=self.dice_coef_loss,
optimizer=self.optimizer,
metrics=["acc"],
metrics=["acc"]
)

return model

def dice_coef_loss(self, y_true, y_pred):
def dice_coef_loss(self, y_true, y_pred, smooth=1.0):
"""Dice coefficient loss.
Calculate the -log(Dice Coefficient) loss
@@ -164,7 +145,7 @@ def dice_coef_loss(self, y_true, y_pred):
float: -log(Dice coefficient) metric
"""
intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3))
smooth=1.0

term1 = -tf.math.log(tf.constant(2.0) * intersection + smooth)
term2 = tf.math.log(tf.reduce_sum(y_true, axis=(1, 2, 3))
+ tf.reduce_sum(y_pred, axis=(1, 2, 3)) + smooth)
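To make the arithmetic concrete, a small worked example of the loss defined above, assuming channels_last masks with values in [0, 1] (a sanity check, not code from this commit):

import tensorflow as tf

# Toy masks of shape (batch, height, width, channels) = (1, 2, 2, 1).
y_true = tf.constant([[[[1.0], [1.0]], [[0.0], [0.0]]]])
y_pred = tf.constant([[[[0.9], [0.8]], [[0.1], [0.0]]]])

smooth = 1.0
intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3))   # 1.7
totals = (tf.reduce_sum(y_true, axis=(1, 2, 3))
          + tf.reduce_sum(y_pred, axis=(1, 2, 3)))              # 3.8
loss = (-tf.math.log(2.0 * intersection + smooth)
        + tf.math.log(totals + smooth))
# (2 * 1.7 + 1) / (3.8 + 1) ≈ 0.917, so loss ≈ -log(0.917) ≈ 0.087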
@@ -175,13 +156,3 @@ def dice_coef_loss(self, y_true, y_pred):
loss = term1 + term2

return loss


CHANNEL_LAST = True
if CHANNEL_LAST:
concat_axis = -1
data_format = 'channels_last'
else:
concat_axis = 1
data_format = 'channels_first'
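With the module-level CHANNEL_LAST switch removed, the model relies on Keras' global image data format, which defaults to channels_last. A hedged way to confirm, or hypothetically override, that setting:

import tensorflow as tf

# Defaults to 'channels_last' unless keras.json or an earlier call changed it.
print(tf.keras.backend.image_data_format())

# Hypothetical override, not something this change does; the axis=-1
# concatenation above assumes channels_last, so overriding would break it.
# tf.keras.backend.set_image_data_format('channels_first')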
