diff --git a/openfl-workspace/tf_2dunet/README.md b/openfl-workspace/tf_2dunet/README.md index 12dab8fc2e8..049cd5122a9 100644 --- a/openfl-workspace/tf_2dunet/README.md +++ b/openfl-workspace/tf_2dunet/README.md @@ -14,18 +14,27 @@ To use a `tree` command, you have to install it first: `sudo apt-get install tre - `HGG`: glioblastoma scans - `LGG`: lower grade glioma scans -Let's pick `HGG`: `export SUBFOLDER=HGG`. The learning rate has been already tuned for this task, so you don't have to change it. If you pick `LGG`, all the next steps will be the same. +Let's pick `HGG`: `export SUBFOLDER=MICCAI_BraTS_2019_Data_Training/HGG`. The learning rate has been already tuned for this task, so you don't have to change it. If you pick `LGG`, all the next steps will be the same. 3) In order for each collaborator to use separate slice of data, we split main folder into `n` subfolders: ```bash +#!/bin/bash cd $DATA_PATH/$SUBFOLDER -i=0; -for f in *; -do - d=dir_$(printf $((i%n))); # change n to number of data slices (number of collaborators in federation) - mkdir -p $d; - mv "$f" $d; - let i++; + +n=2 # Set this to the number of directories you want to create + +# Get a list of all files and shuffle them +files=($(ls | shuf)) + +# Create the target directories if they don't exist +for ((i=0; i=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/openfl-workspace/tf_2dunet/src/brats_utils.py b/openfl-workspace/tf_2dunet/src/brats_utils.py deleted file mode 100644 index 653e26cbcac..00000000000 --- a/openfl-workspace/tf_2dunet/src/brats_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -"""You may copy this file as the starting point of your own model.""" - -import logging -import os - -import numpy as np - -from .nii_reader import nii_reader - -logger = logging.getLogger(__name__) - - -def train_val_split(features, labels, percent_train, shuffle): - """Train/validation splot of the BraTS dataset. - - Splits incoming feature and labels into training and validation. The value - of shuffle determines whether shuffling occurs before the split is performed. - - Args: - features: The input images - labels: The ground truth labels - percent_train (float): The percentage of the dataset that is training. - shuffle (bool): True = shuffle the dataset before the split - - Returns: - train_features: The input images for the training dataset - train_labels: The ground truth labels for the training dataset - val_features: The input images for the validation dataset - val_labels: The ground truth labels for the validation dataset - """ - - def split(lst, idx): - """Split a Python list into 2 lists. 
- - Args: - lst: The Python list to split - idx: The index where to split the list into 2 parts - - Returns: - Two lists - - """ - if idx < 0 or idx > len(lst): - raise ValueError('split was out of expected range.') - return lst[:idx], lst[idx:] - - nb_features = len(features) - nb_labels = len(labels) - if nb_features != nb_labels: - raise RuntimeError('Number of features and labels do not match.') - if shuffle: - new_order = np.random.permutation(np.arange(nb_features)) - features = features[new_order] - labels = labels[new_order] - split_idx = int(percent_train * nb_features) - train_features, val_features = split(lst=features, idx=split_idx) - train_labels, val_labels = split(lst=labels, idx=split_idx) - return train_features, train_labels, val_features, val_labels - - -def load_from_nifti(parent_dir, - percent_train, - shuffle, - channels_last=True, - task='whole_tumor', - **kwargs): - """Load the BraTS dataset from the NiFTI file format. - - Loads data from the parent directory (NIfTI files for whole brains are - assumed to be contained in subdirectories of the parent directory). - Performs a split of the data into training and validation, and the value - of shuffle determined whether shuffling is performed before this split - occurs - both split and shuffle are done in a way to - keep whole brains intact. The kwargs are passed to nii_reader. - - Args: - parent_dir: The parent directory for the BraTS data - percent_train (float): The percentage of the data to make the training dataset - shuffle (bool): True means shuffle the dataset order before the split - channels_last (bool): Input tensor uses channels as last dimension (Default is True) - task: Prediction task (Default is 'whole_tumor' prediction) - **kwargs: Variable arguments to pass to the function - - Returns: - train_features: The input images for the training dataset - train_labels: The ground truth labels for the training dataset - val_features: The input images for the validation dataset - val_labels: The ground truth labels for the validation dataset - - """ - path = os.path.join(parent_dir) - subdirs = os.listdir(path) - subdirs.sort() - if not subdirs: - raise SystemError(f'''{parent_dir} does not contain subdirectories. -Please make sure you have BraTS dataset downloaded -and located in data directory for this collaborator. 
- ''') - subdir_paths = [os.path.join(path, subdir) for subdir in subdirs] - - imgs_all = [] - msks_all = [] - for brain_path in subdir_paths: - these_imgs, these_msks = nii_reader( - brain_path=brain_path, - task=task, - channels_last=channels_last, - **kwargs - ) - # the needed files where not present if a tuple of None is returned - if these_imgs is None: - logger.debug(f'Brain subdirectory: {brain_path} did not contain the needed files.') - else: - imgs_all.append(these_imgs) - msks_all.append(these_msks) - - # converting to arrays to allow for numpy indexing used during split - imgs_all = np.array(imgs_all) - msks_all = np.array(msks_all) - - # note here that each is a list of 155 slices per brain, and so the - # split keeps brains intact - imgs_all_train, msks_all_train, imgs_all_val, msks_all_val = train_val_split( - features=imgs_all, - labels=msks_all, - percent_train=percent_train, - shuffle=shuffle - ) - # now concatenate the lists - imgs_train = np.concatenate(imgs_all_train, axis=0) - msks_train = np.concatenate(msks_all_train, axis=0) - imgs_val = np.concatenate(imgs_all_val, axis=0) - msks_val = np.concatenate(msks_all_val, axis=0) - - return imgs_train, msks_train, imgs_val, msks_val diff --git a/openfl-workspace/tf_2dunet/src/nii_reader.py b/openfl-workspace/tf_2dunet/src/dataloader.py similarity index 63% rename from openfl-workspace/tf_2dunet/src/nii_reader.py rename to openfl-workspace/tf_2dunet/src/dataloader.py index ba90a644b1f..11b442c1627 100644 --- a/openfl-workspace/tf_2dunet/src/nii_reader.py +++ b/openfl-workspace/tf_2dunet/src/dataloader.py @@ -1,13 +1,174 @@ -# Copyright (C) 2020-2021 Intel Corporation +# Copyright (C) 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """You may copy this file as the starting point of your own model.""" import os +import logging import nibabel as nib import numpy as np import numpy.ma as ma +from openfl.federated import TensorFlowDataLoader + + +logger = logging.getLogger(__name__) + + +class TensorFlowBratsInMemory(TensorFlowDataLoader): + """TensorFlow Data Loader for the BraTS dataset.""" + + def __init__(self, data_path, batch_size, percent_train=0.8, pre_split_shuffle=True, num_classes=1, + **kwargs): + """Initialize. + + Args: + data_path: The file path for the BraTS dataset + batch_size (int): The batch size to use + percent_train (float): The percentage of the data to use for training (Default=0.8) + pre_split_shuffle (bool): True= shuffle the dataset before + performing the train/validate split (Default=True) + **kwargs: Additional arguments, passed to super init and load_from_nifti + + Returns: + Data loader with BraTS data + """ + super().__init__(batch_size, **kwargs) + + X_train, y_train, X_valid, y_valid = load_from_nifti(parent_dir=data_path, + percent_train=percent_train, + shuffle=pre_split_shuffle, + **kwargs) + self.X_train = X_train + self.y_train = y_train + self.X_valid = X_valid + self.y_valid = y_valid + self.num_classes = num_classes + + +def train_val_split(features, labels, percent_train, shuffle): + """Train/validation splot of the BraTS dataset. + + Splits incoming feature and labels into training and validation. The value + of shuffle determines whether shuffling occurs before the split is performed. + + Args: + features: The input images + labels: The ground truth labels + percent_train (float): The percentage of the dataset that is training. 
+ shuffle (bool): True = shuffle the dataset before the split + + Returns: + train_features: The input images for the training dataset + train_labels: The ground truth labels for the training dataset + val_features: The input images for the validation dataset + val_labels: The ground truth labels for the validation dataset + """ + + def split(lst, idx): + """Split a Python list into 2 lists. + + Args: + lst: The Python list to split + idx: The index where to split the list into 2 parts + + Returns: + Two lists + + """ + if idx < 0 or idx > len(lst): + raise ValueError('split was out of expected range.') + return lst[:idx], lst[idx:] + + nb_features = len(features) + nb_labels = len(labels) + if nb_features != nb_labels: + raise RuntimeError('Number of features and labels do not match.') + if shuffle: + new_order = np.random.permutation(np.arange(nb_features)) + features = features[new_order] + labels = labels[new_order] + split_idx = int(percent_train * nb_features) + train_features, val_features = split(lst=features, idx=split_idx) + train_labels, val_labels = split(lst=labels, idx=split_idx) + return train_features, train_labels, val_features, val_labels + + +def load_from_nifti(parent_dir, + percent_train, + shuffle, + channels_last=True, + task='whole_tumor', + **kwargs): + """Load the BraTS dataset from the NiFTI file format. + + Loads data from the parent directory (NIfTI files for whole brains are + assumed to be contained in subdirectories of the parent directory). + Performs a split of the data into training and validation, and the value + of shuffle determined whether shuffling is performed before this split + occurs - both split and shuffle are done in a way to + keep whole brains intact. The kwargs are passed to nii_reader. + + Args: + parent_dir: The parent directory for the BraTS data + percent_train (float): The percentage of the data to make the training dataset + shuffle (bool): True means shuffle the dataset order before the split + channels_last (bool): Input tensor uses channels as last dimension (Default is True) + task: Prediction task (Default is 'whole_tumor' prediction) + **kwargs: Variable arguments to pass to the function + + Returns: + train_features: The input images for the training dataset + train_labels: The ground truth labels for the training dataset + val_features: The input images for the validation dataset + val_labels: The ground truth labels for the validation dataset + + """ + path = os.path.join(parent_dir) + subdirs = os.listdir(path) + subdirs.sort() + if not subdirs: + raise SystemError(f'''{parent_dir} does not contain subdirectories. +Please make sure you have BraTS dataset downloaded +and located in data directory for this collaborator. 
+ ''') + subdir_paths = [os.path.join(path, subdir) for subdir in subdirs] + + imgs_all = [] + msks_all = [] + for brain_path in subdir_paths: + these_imgs, these_msks = nii_reader( + brain_path=brain_path, + task=task, + channels_last=channels_last, + **kwargs + ) + # the needed files where not present if a tuple of None is returned + if these_imgs is None: + logger.debug(f'Brain subdirectory: {brain_path} did not contain the needed files.') + else: + imgs_all.append(these_imgs) + msks_all.append(these_msks) + + # converting to arrays to allow for numpy indexing used during split + imgs_all = np.array(imgs_all) + msks_all = np.array(msks_all) + + # note here that each is a list of 155 slices per brain, and so the + # split keeps brains intact + imgs_all_train, msks_all_train, imgs_all_val, msks_all_val = train_val_split( + features=imgs_all, + labels=msks_all, + percent_train=percent_train, + shuffle=shuffle + ) + # now concatenate the lists + imgs_train = np.concatenate(imgs_all_train, axis=0) + msks_train = np.concatenate(msks_all_train, axis=0) + imgs_val = np.concatenate(imgs_all_val, axis=0) + msks_val = np.concatenate(msks_all_val, axis=0) + + return imgs_train, msks_train, imgs_val, msks_val def parse_segments(seg, msk_modes): diff --git a/openfl-workspace/tf_2dunet/src/taskrunner.py b/openfl-workspace/tf_2dunet/src/taskrunner.py new file mode 100644 index 00000000000..1d3baa690db --- /dev/null +++ b/openfl-workspace/tf_2dunet/src/taskrunner.py @@ -0,0 +1,257 @@ +# Copyright (C) 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +import numpy as np +import tensorflow as tf + +from openfl.utilities import Metric +from openfl.federated import TensorFlowTaskRunner + +class TensorFlow2DUNet(TensorFlowTaskRunner): + """Initialize. + + Args: + **kwargs: Additional parameters to pass to the function + + """ + + def __init__(self, initial_filters=16, + depth=5, + batch_norm=True, + use_upsampling=False, + **kwargs): + """Initialize. + + Args: + **kwargs: Additional parameters to pass to the function + + """ + super().__init__(**kwargs) + + self.model = self.create_model( + input_shape=self.feature_shape, + n_cl_out=self.data_loader.num_classes, + initial_filters=initial_filters, + use_upsampling=use_upsampling, + depth=depth, + batch_norm=batch_norm, + **kwargs + ) + self.initialize_tensorkeys_for_functions() + + self.model.summary(print_fn=self.logger.info, line_length=120) + + def create_model(self, + input_shape, + n_cl_out=1, + use_upsampling=False, + dropout=0.2, + print_summary=True, + seed=816, + depth=5, + dropout_at=(2, 3), + initial_filters=16, + batch_norm=True, + **kwargs): + """Create the TensorFlow 3D U-Net CNN model. + + Args: + input_shape (list): input shape of the data + n_cl_out (int): Number of output classes in label (Default=1) + **kwargs: Additional parameters to pass to the function + + """ + + model = build_model(input_shape, + n_cl_out=n_cl_out, + use_upsampling=use_upsampling, + dropout=dropout, + print_summary=print_summary, + seed=seed, + depth=depth, + dropout_at=dropout_at, + initial_filters=initial_filters, + batch_norm=batch_norm) + + model.compile( + loss=dice_loss, + optimizer=tf.keras.optimizers.Adam(), + metrics=[dice_coef, soft_dice_coef], + ) + + return model + + def train_(self, batch_generator, metrics: list = None, **kwargs): + """Train single epoch. + + Override this function for custom training. + + Args: + batch_generator: Generator of training batches. 
+ Each batch is a tuple of N train images and N train labels + where N is the batch size of the DataLoader of the current TaskRunner instance. + + epochs: Number of epochs to train. + metrics: Names of metrics to save. + """ + history = self.model.fit(batch_generator, + verbose=1, + **kwargs) + results = [] + for metric in metrics: + value = np.mean([history.history[metric]]) + results.append(Metric(name=metric, value=np.array(value))) + return results + + +def dice_coef(target, prediction, axis=(1, 2), smooth=0.0001): + """ + Sorenson Dice. + + Returns + ------- + dice coefficient (float) + """ + prediction = tf.round(prediction) # Round to 0 or 1 + + intersection = tf.reduce_sum(target * prediction, axis=axis) + union = tf.reduce_sum(target + prediction, axis=axis) + numerator = tf.constant(2.) * intersection + smooth + denominator = union + smooth + coef = numerator / denominator + + return tf.reduce_mean(coef) + + +def soft_dice_coef(target, prediction, axis=(1, 2), smooth=0.0001): + """ + Soft Sorenson Dice. + + Does not round the predictions to either 0 or 1. + + Returns + ------- + soft dice coefficient (float) + """ + intersection = tf.reduce_sum(target * prediction, axis=axis) + union = tf.reduce_sum(target + prediction, axis=axis) + numerator = tf.constant(2.) * intersection + smooth + denominator = union + smooth + coef = numerator / denominator + + return tf.reduce_mean(coef) + + +def dice_loss(target, prediction, axis=(1, 2), smooth=0.0001): + """ + Sorenson (Soft) Dice loss. + + Using -log(Dice) as the loss since it is better behaved. + Also, the log allows avoidance of the division which + can help prevent underflow when the numbers are very small. + + Returns + ------- + dice loss (float) + """ + intersection = tf.reduce_sum(prediction * target, axis=axis) + p = tf.reduce_sum(prediction, axis=axis) + t = tf.reduce_sum(target, axis=axis) + numerator = tf.reduce_mean(intersection + smooth) + denominator = tf.reduce_mean(t + p + smooth) + dice_loss = -tf.math.log(2. * numerator) + tf.math.log(denominator) + + return dice_loss + + +def build_model(input_shape, + n_cl_out=1, + use_upsampling=False, + dropout=0.2, + seed=816, + depth=5, + dropout_at=(2, 3), + initial_filters=16, + batch_norm=True, + **kwargs): + """Build the TensorFlow model. 
+ + Args: + input_tensor: input shape ot the model + use_upsampling (bool): True = use bilinear interpolation; + False = use transposed convolution (Default=False) + n_cl_out (int): Number of channels in output layer (Default=1) + dropout (float): Dropout percentage (Default=0.2) + print_summary (bool): True = print the model summary (Default = True) + seed: random seed (Default=816) + depth (int): Number of max pooling layers in encoder (Default=5) + dropout_at: Layers to perform dropout after (Default=[2,3]) + initial_filters (int): Number of filters in first convolutional + layer (Default=16) + batch_norm (bool): True = use batch normalization (Default=True) + **kwargs: Additional parameters to pass to the function + """ + if (input_shape[0] % (2**depth)) > 0: + raise ValueError(f'Crop dimension must be a multiple of 2^(depth of U-Net) = {2**depth}') + + inputs = tf.keras.layers.Input(input_shape, name='brats_mr_image') + + activation = tf.keras.activations.relu + + params = {'kernel_size': (3, 3), 'activation': activation, + 'padding': 'same', + 'kernel_initializer': tf.keras.initializers.he_uniform(seed=seed)} + + convb_layers = {} + + net = inputs + filters = initial_filters + for i in range(depth): + name = f'conv{i + 1}a' + net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net) + if i in dropout_at: + net = tf.keras.layers.Dropout(dropout)(net) + name = f'conv{i + 1}b' + net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net) + if batch_norm: + net = tf.keras.layers.BatchNormalization()(net) + convb_layers[name] = net + # only pool if not last level + if i != depth - 1: + name = f'pool{i + 1}' + net = tf.keras.layers.MaxPooling2D(name=name, pool_size=(2, 2))(net) + filters *= 2 + + # do the up levels + filters //= 2 + for i in range(depth - 1): + if use_upsampling: + up = tf.keras.layers.UpSampling2D( + name=f'up{depth + i + 1}', size=(2, 2))(net) + else: + up = tf.keras.layers.Conv2DTranspose(name=f'transConv{depth + i + 1}', + filters=filters, + kernel_size=(2, 2), + strides=(2, 2), + padding='same')(net) + net = tf.keras.layers.concatenate( + [up, convb_layers[f'conv{depth - i - 1}b']], + axis=-1 + ) + net = tf.keras.layers.Conv2D( + name=f'conv{depth + i + 1}a', + filters=filters, **params)(net) + net = tf.keras.layers.Conv2D( + name=f'conv{depth + i + 1}b', + filters=filters, **params)(net) + filters //= 2 + + net = tf.keras.layers.Conv2D(name='prediction', filters=n_cl_out, + kernel_size=(1, 1), + activation='sigmoid')(net) + + model = tf.keras.models.Model(inputs=[inputs], outputs=[net]) + + return model \ No newline at end of file diff --git a/openfl-workspace/tf_2dunet/src/tf_2dunet.py b/openfl-workspace/tf_2dunet/src/tf_2dunet.py deleted file mode 100644 index 54c2ae2a896..00000000000 --- a/openfl-workspace/tf_2dunet/src/tf_2dunet.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""You may copy this file as the starting point of your own model.""" - -import tensorflow.compat.v1 as tf - -from openfl.federated import TensorFlowTaskRunnerV1 - -tf.disable_v2_behavior() - - -class TensorFlow2DUNet(TensorFlowTaskRunnerV1): - """Initialize. - - Args: - **kwargs: Additional parameters to pass to the function - - """ - - def __init__(self, **kwargs): - """Initialize. 
- - Args: - **kwargs: Additional parameters to pass to the function - - """ - super().__init__(**kwargs) - - self.create_model(**kwargs) - self.initialize_tensorkeys_for_functions() - - def create_model(self, - training_smoothing=32.0, - validation_smoothing=1.0, - **kwargs): - """Create the TensorFlow 2D U-Net model. - - Args: - training_smoothing (float): (Default=32.0) - validation_smoothing (float): (Default=1.0) - **kwargs: Additional parameters to pass to the function - - """ - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - config.intra_op_parallelism_threads = 112 - config.inter_op_parallelism_threads = 1 - self.sess = tf.Session(config=config) - - self.X = tf.placeholder(tf.float32, self.input_shape) - self.y = tf.placeholder(tf.float32, self.input_shape) - self.output = define_model(self.X, use_upsampling=True, **kwargs) - - self.loss = dice_coef_loss(self.y, self.output, smooth=training_smoothing) - self.loss_name = dice_coef_loss.__name__ - self.validation_metric = dice_coef( - self.y, self.output, smooth=validation_smoothing) - self.validation_metric_name = dice_coef.__name__ - - self.global_step = tf.train.get_or_create_global_step() - - self.tvars = tf.trainable_variables() - - self.optimizer = tf.train.RMSPropOptimizer(1e-2) - - self.gvs = self.optimizer.compute_gradients(self.loss, self.tvars) - self.train_step = self.optimizer.apply_gradients(self.gvs, - global_step=self.global_step) - - self.opt_vars = self.optimizer.variables() - - # FIXME: Do we really need to share the opt_vars? - # Two opt_vars for one tvar: gradient and square sum for RMSprop. - self.fl_vars = self.tvars + self.opt_vars - - self.initialize_globals() - - -def dice_coef(y_true, y_pred, smooth=1.0, **kwargs): - """Dice coefficient. - - Calculate the Dice Coefficient - - Args: - y_true: Ground truth annotation array - y_pred: Prediction array from model - smooth (float): Laplace smoothing factor (Default=1.0) - **kwargs: Additional parameters to pass to the function - - Returns: - float: Dice cofficient metric - - """ - intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3]) - coef = ( - (tf.constant(2.) * intersection + tf.constant(smooth)) - / (tf.reduce_sum(y_true, axis=[1, 2, 3]) - + tf.reduce_sum(y_pred, axis=[1, 2, 3]) + tf.constant(smooth)) - ) - return tf.reduce_mean(coef) - - -def dice_coef_loss(y_true, y_pred, smooth=1.0, **kwargs): - """Dice coefficient loss. 
- - Calculate the -log(Dice Coefficient) loss - - Args: - y_true: Ground truth annotation array - y_pred: Prediction array from model - smooth (float): Laplace smoothing factor (Default=1.0) - **kwargs: Additional parameters to pass to the function - - Returns: - float: -log(Dice cofficient) metric - - """ - intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3)) - - term1 = -tf.log(tf.constant(2.0) * intersection + smooth) - term2 = tf.log(tf.reduce_sum(y_true, axis=(1, 2, 3)) - + tf.reduce_sum(y_pred, axis=(1, 2, 3)) + smooth) - - term1 = tf.reduce_mean(term1) - term2 = tf.reduce_mean(term2) - - loss = term1 + term2 - - return loss - - -CHANNEL_LAST = True -if CHANNEL_LAST: - concat_axis = -1 - data_format = 'channels_last' -else: - concat_axis = 1 - data_format = 'channels_first' - -tf.keras.backend.set_image_data_format(data_format) - - -def define_model(input_tensor, - use_upsampling=False, - n_cl_out=1, - dropout=0.2, - print_summary=True, - activation_function='relu', - seed=0xFEEDFACE, - depth=5, - dropout_at=None, - initial_filters=32, - batch_norm=True, - **kwargs): - """Define the TensorFlow model. - - Args: - input_tensor: input shape ot the model - use_upsampling (bool): True = use bilinear interpolation; - False = use transposed convolution (Default=False) - n_cl_out (int): Number of channels in input layer (Default=1) - dropout (float): Dropout percentage (Default=0.2) - print_summary (bool): True = print the model summary (Default = True) - activation_function: The activation function to use after convolutional - layers (Default='relu') - seed: random seed (Default=0xFEEDFACE) - depth (int): Number of max pooling layers in encoder (Default=5) - dropout_at: Layers to perform dropout after (Default=[2,3]) - initial_filters (int): Number of filters in first convolutional - layer (Default=32) - batch_norm (bool): True = use batch normalization (Default=True) - **kwargs: Additional parameters to pass to the function - - """ - if dropout_at is None: - dropout_at = [2, 3] - # Set keras learning phase to train - tf.keras.backend.set_learning_phase(True) - - # Don't initialize variables on the fly - tf.keras.backend.manual_variable_initialization(False) - - inputs = tf.keras.layers.Input(tensor=input_tensor, name='Images') - - if activation_function == 'relu': - activation = tf.nn.relu - elif activation_function == 'leakyrelu': - activation = tf.nn.leaky_relu - - params = { - 'activation': activation, - 'data_format': data_format, - 'kernel_initializer': tf.keras.initializers.he_uniform(seed=seed), - 'kernel_size': (3, 3), - 'padding': 'same', - } - - convb_layers = {} - - net = inputs - filters = initial_filters - for i in range(depth): - name = f'conv{i + 1}a' - net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net) - if i in dropout_at: - net = tf.keras.layers.Dropout(dropout)(net) - name = f'conv{i + 1}b' - net = tf.keras.layers.Conv2D(name=name, filters=filters, **params)(net) - if batch_norm: - net = tf.keras.layers.BatchNormalization()(net) - convb_layers[name] = net - # only pool if not last level - if i != depth - 1: - name = f'pool{i + 1}' - net = tf.keras.layers.MaxPooling2D(name=name, pool_size=(2, 2))(net) - filters *= 2 - - # do the up levels - filters //= 2 - for i in range(depth - 1): - if use_upsampling: - up = tf.keras.layers.UpSampling2D( - name=f'up{depth + i + 1}', size=(2, 2))(net) - else: - up = tf.keras.layers.Conv2DTranspose( - name='transConv6', filters=filters, data_format=data_format, - kernel_size=(2, 2), strides=(2, 2), 
padding='same')(net) - net = tf.keras.layers.concatenate( - [up, convb_layers[f'conv{depth - i - 1}b']], - axis=concat_axis - ) - net = tf.keras.layers.Conv2D( - name=f'conv{depth + i + 1}a', - filters=filters, **params)(net) - net = tf.keras.layers.Conv2D( - name=f'conv{depth + i + 1}b', - filters=filters, **params)(net) - filters //= 2 - - net = tf.keras.layers.Conv2D(name='Mask', filters=n_cl_out, - kernel_size=(1, 1), data_format=data_format, - activation='sigmoid')(net) - - model = tf.keras.models.Model(inputs=[inputs], outputs=[net]) - - if print_summary: - print(model.summary()) - - return net diff --git a/openfl-workspace/tf_2dunet/src/tfbrats_inmemory.py b/openfl-workspace/tf_2dunet/src/tfbrats_inmemory.py deleted file mode 100644 index 49b4484fc2e..00000000000 --- a/openfl-workspace/tf_2dunet/src/tfbrats_inmemory.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""You may copy this file as the starting point of your own model.""" - -from openfl.federated import TensorFlowDataLoader -from .brats_utils import load_from_nifti - - -class TensorFlowBratsInMemory(TensorFlowDataLoader): - """TensorFlow Data Loader for the BraTS dataset.""" - - def __init__(self, data_path, batch_size, percent_train=0.8, pre_split_shuffle=True, **kwargs): - """Initialize. - - Args: - data_path: The file path for the BraTS dataset - batch_size (int): The batch size to use - percent_train (float): The percentage of the data to use for training (Default=0.8) - pre_split_shuffle (bool): True= shuffle the dataset before - performing the train/validate split (Default=True) - **kwargs: Additional arguments, passed to super init and load_from_nifti - - Returns: - Data loader with BraTS data - """ - super().__init__(batch_size, **kwargs) - - X_train, y_train, X_valid, y_valid = load_from_nifti(parent_dir=data_path, - percent_train=percent_train, - shuffle=pre_split_shuffle, - **kwargs) - self.X_train = X_train - self.y_train = y_train - self.X_valid = X_valid - self.y_valid = y_valid diff --git a/openfl-workspace/tf_cnn_mnist/src/taskrunner.py b/openfl-workspace/tf_cnn_mnist/src/taskrunner.py index da618fbb5fd..ea11110edd1 100644 --- a/openfl-workspace/tf_cnn_mnist/src/taskrunner.py +++ b/openfl-workspace/tf_cnn_mnist/src/taskrunner.py @@ -83,18 +83,7 @@ def train_(self, batch_generator, metrics: list = None, **kwargs): epochs: Number of epochs to train. metrics: Names of metrics to save. """ - if metrics is None: - metrics = [] - - model_metrics_names = self.model.metrics_names - - for param in metrics: - if param not in model_metrics_names: - raise ValueError( - f'TensorFlowTaskRunner does not support specifying new metrics. 
' - f'Param_metrics = {metrics}, model_metrics_names = {model_metrics_names}' - ) - + history = self.model.fit(batch_generator, verbose=1, **kwargs) diff --git a/openfl/federated/__init__.py b/openfl/federated/__init__.py index 54cd35515f4..7849172bf45 100644 --- a/openfl/federated/__init__.py +++ b/openfl/federated/__init__.py @@ -9,7 +9,7 @@ from .data import DataLoader # NOQA if importlib.util.find_spec('tensorflow'): - from .task import TensorFlowTaskRunner, TensorFlowTaskRunnerV1, KerasTaskRunner, FederatedModel # NOQA + from .task import TensorFlowTaskRunner, KerasTaskRunner, FederatedModel # NOQA from .data import TensorFlowDataLoader, KerasDataLoader, FederatedDataSet # NOQA if importlib.util.find_spec('torch'): from .task import PyTorchTaskRunner, FederatedModel # NOQA diff --git a/openfl/federated/task/__init__.py b/openfl/federated/task/__init__.py index fa6628d0472..a0837db6872 100644 --- a/openfl/federated/task/__init__.py +++ b/openfl/federated/task/__init__.py @@ -16,7 +16,7 @@ from .runner import TaskRunner # NOQA if importlib.util.find_spec('tensorflow'): - from .runner_tf import TensorFlowTaskRunner, TensorFlowTaskRunnerV1 # NOQA + from .runner_tf import TensorFlowTaskRunner # NOQA from .runner_keras import KerasTaskRunner # NOQA from .fl_model import FederatedModel # NOQA if importlib.util.find_spec('torch'): diff --git a/openfl/federated/task/runner_tf.py b/openfl/federated/task/runner_tf.py index e13522666bc..56b61c5b649 100644 --- a/openfl/federated/task/runner_tf.py +++ b/openfl/federated/task/runner_tf.py @@ -9,9 +9,6 @@ from openfl.utilities.split import split_tensor_dict_for_holdouts from .runner import TaskRunner -import tensorflow.compat.v1 -from tqdm import tqdm - class TensorFlowTaskRunner(TaskRunner): """The base model for Keras models in the federation.""" @@ -543,437 +540,4 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False): self.required_tensorkeys_for_function['validate_task']['apply=global'] += [ TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) for tensor_name in local_model_dict_val - ] - - -class TensorFlowTaskRunnerV1(TaskRunner): - """ - Base class for TensorFlow models in the Federated Learning solution. - - child classes should have __init__ function signature (self, data, kwargs), - and should overwrite at least the following while defining the model - """ - - def __init__(self, **kwargs): - """ - Initialize. 
- - Args: - **kwargs: Additional parameters to pass to the function - """ - tensorflow.compat.v1.disable_v2_behavior() - - super().__init__(**kwargs) - - self.assign_ops = None - self.placeholders = None - - self.tvar_assign_ops = None - self.tvar_placeholders = None - - # construct the shape needed for the input features - self.input_shape = (None,) + self.data_loader.get_feature_shape() - - # Required tensorkeys for all public functions in TensorFlowTaskRunner - self.required_tensorkeys_for_function = {} - - # Required tensorkeys for all public functions in TensorFlowTaskRunner - self.required_tensorkeys_for_function = {} - - # tensorflow session - self.sess = None - # input featrures to the model - self.X = None - # input labels to the model - self.y = None - # optimizer train step operation - self.train_step = None - # model loss function - self.loss = None - # model output tensor - self.output = None - # function used to validate the model outputs against labels - self.validation_metric = None - # tensorflow trainable variables - self.tvars = None - # self.optimizer.variables() once self.optimizer is defined - self.opt_vars = None - # self.tvars + self.opt_vars - self.fl_vars = None - - def rebuild_model(self, round_num, input_tensor_dict, validation=False): - """ - Parse tensor names and update weights of model. Handles the optimizer treatment. - - Returns: - None - """ - if self.opt_treatment == 'RESET': - self.reset_opt_vars() - self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - elif (round_num > 0 and self.opt_treatment == 'CONTINUE_GLOBAL' - and not validation): - self.set_tensor_dict(input_tensor_dict, with_opt_vars=True) - else: - self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) - - def train_task(self, col_name, round_num, input_tensor_dict, - epochs=1, use_tqdm=False, **kwargs): - """ - Perform the training. - - Is expected to perform draws randomly, without replacement until data is exausted. Then - data is replaced and shuffled and draws continue. 
- - Args: - use_tqdm (bool): True = use tqdm to print a progress - bar (Default=False) - epochs (int): Number of epochs to train - Returns: - float: loss metric - """ - batch_size = self.data_loader.batch_size - - if kwargs['batch_size']: - batch_size = kwargs['batch_size'] - - # rebuild model with updated weights - self.rebuild_model(round_num, input_tensor_dict) - - tensorflow.compat.v1.keras.backend.set_learning_phase(True) - losses = [] - - for epoch in range(epochs): - self.logger.info(f'Run {epoch} epoch of {round_num} round') - # get iterator for batch draws (shuffling happens here) - gen = self.data_loader.get_train_loader(batch_size) - if use_tqdm: - gen = tqdm.tqdm(gen, desc='training epoch') - - for (X, y) in gen: - losses.append(self.train_batch(X, y)) - - # Output metric tensors (scalar) - origin = col_name - tags = ('trained',) - output_metric_dict = { - TensorKey( - self.loss_name, origin, round_num, True, ('metric',) - ): np.array(np.mean(losses)) - } - - # output model tensors (Doesn't include TensorKey) - output_model_dict = self.get_tensor_dict(with_opt_vars=True) - global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) - - # Create global tensorkeys - global_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in global_model_dict.items() - } - # Create tensorkeys that should stay local - local_tensorkey_model_dict = { - TensorKey(tensor_name, origin, round_num, False, tags): - nparray for tensor_name, nparray in local_model_dict.items() - } - # The train/validate aggregated function of the next round will - # look for the updated model parameters. - # This ensures they will be resolved locally - next_local_tensorkey_model_dict = { - TensorKey( - tensor_name, origin, round_num + 1, False, ('model',) - ): nparray for tensor_name, nparray in local_model_dict.items()} - - global_tensor_dict = { - **output_metric_dict, - **global_tensorkey_model_dict - } - local_tensor_dict = { - **local_tensorkey_model_dict, - **next_local_tensorkey_model_dict - } - - # Update the required tensors if they need to be pulled from - # the aggregator - # TODO this logic can break if different collaborators have different - # roles between rounds. - # For example, if a collaborator only performs validation in the first - # round but training in the second, it has no way of knowing the - # optimizer state tensor names to request from the aggregator because - # these are only created after training occurs. A work around could - # involve doing a single epoch of training on random data to get the - # optimizer names, and then throwing away the model. - if self.opt_treatment == 'CONTINUE_GLOBAL': - self.initialize_tensorkeys_for_functions(with_opt_vars=True) - - return global_tensor_dict, local_tensor_dict - - def train_batch(self, X, y): - """ - Train the model on a single batch. - - Args: - X: Input to the model - y: Ground truth label to the model - - Returns: - float: loss metric - """ - feed_dict = {self.X: X, self.y: y} - - # run the train step and return the loss - _, loss = self.sess.run([self.train_step, self.loss], feed_dict=feed_dict) - - return loss - - def validate_task(self, col_name, round_num, - input_tensor_dict, use_tqdm=False, **kwargs): - """ - Run validation. 
- - Returns: - dict: {: } - """ - batch_size = self.data_loader.batch_size - - if kwargs['batch_size']: - batch_size = kwargs['batch_size'] - - self.rebuild_model(round_num, input_tensor_dict, validation=True) - - tensorflow.compat.v1.keras.backend.set_learning_phase(False) - - score = 0 - - gen = self.data_loader.get_valid_loader(batch_size) - if use_tqdm: - gen = tqdm.tqdm(gen, desc='validating') - - for X, y in gen: - weight = X.shape[0] / self.data_loader.get_valid_data_size() - _, s = self.validate_(X, y) - score += s * weight - - origin = col_name - suffix = 'validate' - if kwargs['apply'] == 'local': - suffix += '_local' - else: - suffix += '_agg' - tags = ('metric', suffix) - output_tensor_dict = { - TensorKey( - self.validation_metric_name, origin, round_num, True, tags - ): np.array(score)} - - # return empty dict for local metrics - return output_tensor_dict, {} - - def validate_(self, X, y): - """Validate the model on a single local batch. - - Args: - X: Input to the model - y: Ground truth label to the model - - Returns: - float: loss metric - - """ - feed_dict = {self.X: X, self.y: y} - - return self.sess.run( - [self.output, self.validation_metric], feed_dict=feed_dict) - - def get_tensor_dict(self, with_opt_vars=True): - """Get the dictionary weights. - - Get the weights from the tensor - - Args: - with_opt_vars (bool): Specify if we also want to get the variables - of the optimizer - - Returns: - dict: The weight dictionary {: } - - """ - if with_opt_vars is True: - variables = self.fl_vars - else: - variables = self.tvars - - # FIXME: do this in one call? - return {var.name: val for var, val in zip( - variables, self.sess.run(variables))} - - def set_tensor_dict(self, tensor_dict, with_opt_vars): - """Set the tensor dictionary. - - Set the model weights with a tensor - dictionary: {: }. - - Args: - tensor_dict (dict): The model weights dictionary - with_opt_vars (bool): Specify if we also want to set the variables - of the optimizer - - Returns: - None - """ - if with_opt_vars: - self.assign_ops, self.placeholders = tf_set_tensor_dict( - tensor_dict, self.sess, self.fl_vars, - self.assign_ops, self.placeholders - ) - else: - self.tvar_assign_ops, self.tvar_placeholders = tf_set_tensor_dict( - tensor_dict, - self.sess, - self.tvars, - self.tvar_assign_ops, - self.tvar_placeholders - ) - - def reset_opt_vars(self): - """Reinitialize the optimizer variables.""" - for v in self.opt_vars: - v.initializer.run(session=self.sess) - - def initialize_globals(self): - """Initialize Global Variables. - - Initialize all global variables - - Returns: - None - """ - self.sess.run(tensorflow.compat.v1.global_variables_initializer()) - - def _get_weights_names(self, with_opt_vars=True): - """Get the weights. - - Args: - with_opt_vars (bool): Specify if we also want to get the variables - of the optimizer. - - Returns: - list : The weight names list - """ - if with_opt_vars is True: - variables = self.fl_vars - else: - variables = self.tvars - - return [var.name for var in variables] - - def get_required_tensorkeys_for_function(self, func_name, **kwargs): - """ - Get the required tensors for specified function that could be called as part of a task. - - By default, this is just all of the layers and optimizer of the model. 
- - Returns: - list : [TensorKey] - """ - if func_name == 'validate': - local_model = 'apply=' + str(kwargs['apply']) - return self.required_tensorkeys_for_function[func_name][local_model] - else: - return self.required_tensorkeys_for_function[func_name] - - def initialize_tensorkeys_for_functions(self, with_opt_vars=False): - """ - Set the required tensors for all publicly accessible methods \ - that could be called as part of a task. - - By default, this is just all of the layers and optimizer of the model. - Custom tensors should be added to this function - - """ - # TODO there should be a way to programmatically iterate through - # all of the methods in the class and declare the tensors. - # For now this is done manually - - output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars) - global_model_dict, local_model_dict = split_tensor_dict_for_holdouts( - self.logger, output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) - if not with_opt_vars: - global_model_dict_val = global_model_dict - local_model_dict_val = local_model_dict - else: - output_model_dict = self.get_tensor_dict(with_opt_vars=False) - global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts( - self.logger, - output_model_dict, - **self.tensor_dict_split_fn_kwargs - ) - - self.required_tensorkeys_for_function['train_task'] = [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) - for tensor_name in global_model_dict] - self.required_tensorkeys_for_function['train_task'] += [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) - for tensor_name in local_model_dict] - - # Validation may be performed on local or aggregated (global) - # model, so there is an extra lookup dimension for kwargs - self.required_tensorkeys_for_function['validate_task'] = {} - # TODO This is not stateless. The optimizer will not be - self.required_tensorkeys_for_function['validate_task']['apply=local'] = [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('trained',)) - for tensor_name in { - **global_model_dict_val, - **local_model_dict_val - } - ] - self.required_tensorkeys_for_function['validate']['apply=global'] = [ - TensorKey(tensor_name, 'GLOBAL', 0, False, ('model',)) - for tensor_name in global_model_dict_val - ] - self.required_tensorkeys_for_function['validate']['apply=global'] += [ - TensorKey(tensor_name, 'LOCAL', 0, False, ('model',)) - for tensor_name in local_model_dict_val - ] - - -# FIXME: what's a nicer construct than this? ugly interface. Perhaps we -# get an object with an assumed interface that lets is set/get these? -# Note that this will return the assign_ops and placeholder nodes it uses -# if called with None, it will create them. -# to avoid inflating the graph, caller should keep these and pass them back -# What if we want to set a different group of vars in the middle? -# It is good if it is the subset of the original variables. -def tf_set_tensor_dict(tensor_dict, session, variables, - assign_ops=None, placeholders=None): - """Tensorflow set tensor dictionary. 
- - Args: - tensor_dict: Dictionary of tensors - session: TensorFlow session - variables: TensorFlow variables - assign_ops: TensorFlow operations (Default=None) - placeholders: TensorFlow placeholders (Default=None) - - Returns: - assign_ops, placeholders - - """ - if placeholders is None: - placeholders = { - v.name: tensorflow.compat.v1.placeholder(v.dtype, shape=v.shape) for v in variables - } - if assign_ops is None: - assign_ops = { - v.name: tensorflow.compat.v1.assign(v, placeholders[v.name]) for v in variables - } - - for k, v in tensor_dict.items(): - session.run(assign_ops[k], feed_dict={placeholders[k]: v}) - - return assign_ops, placeholders + ] \ No newline at end of file
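
Editor's note on the README hunk above: the new shell snippet is truncated in this patch text (it breaks off inside the `for ((i=0; i` loop), so the data-splitting step cannot be followed from the diff alone. As a point of reference only, and not a reconstruction of the missing script, here is a minimal Python sketch of the behaviour the README describes: shuffle the per-case folders under `$DATA_PATH/$SUBFOLDER` and distribute them across `n` collaborator directories named `dir_0` .. `dir_{n-1}`. The value `n=2` and the directory naming follow the surrounding text; everything else is illustrative.

```python
import os
import random
import shutil

# Shuffle the per-case folders and spread them round-robin across n
# collaborator directories, mirroring the split the README describes.
data_dir = os.path.join(os.environ['DATA_PATH'], os.environ['SUBFOLDER'])
n = 2  # number of data slices (one per collaborator in the federation)

cases = sorted(
    d for d in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, d)) and not d.startswith('dir_')
)
random.shuffle(cases)

for i, case in enumerate(cases):
    target = os.path.join(data_dir, f'dir_{i % n}')
    os.makedirs(target, exist_ok=True)
    shutil.move(os.path.join(data_dir, case), target)
```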
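
The consolidated `src/dataloader.py` keeps whole brains intact when splitting: `load_from_nifti` stacks 155 slices per brain and `train_val_split` divides the data at the brain level before the slices are concatenated. A toy check of that behaviour, assuming OpenFL, TensorFlow, and nibabel are installed and the snippet is run from the `tf_2dunet` workspace root so that `src` is importable; the array shapes are illustrative.

```python
import numpy as np

from src.dataloader import train_val_split

# Ten fake "brains" of 155 slices each; labels are random binary masks.
n_brains, slices_per_brain = 10, 155
features = np.random.rand(n_brains, slices_per_brain, 16, 16, 1).astype(np.float32)
labels = (np.random.rand(n_brains, slices_per_brain, 16, 16, 1) > 0.5).astype(np.float32)

X_train, y_train, X_valid, y_valid = train_val_split(
    features, labels, percent_train=0.8, shuffle=True)

print(X_train.shape)  # (8, 155, 16, 16, 1): eight whole brains for training
print(X_valid.shape)  # (2, 155, 16, 16, 1): two whole brains for validation
```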
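
The Dice metrics added in `src/taskrunner.py` (`dice_coef`, `soft_dice_coef`, and the `-log(Dice)` loss) behave as expected on toy tensors. A minimal sanity check, under the same workspace-root import assumption as above:

```python
import numpy as np
import tensorflow as tf

from src.taskrunner import dice_coef, soft_dice_coef, dice_loss

# Toy batch: two 4x4 single-channel masks with the top half labelled as tumor.
target = np.zeros((2, 4, 4, 1), dtype=np.float32)
target[:, :2, :, :] = 1.0
target = tf.constant(target)

perfect = tf.identity(target)   # exact prediction
empty = tf.zeros_like(target)   # predicts no tumor at all

print(float(dice_coef(target, perfect)))             # ~1.0
print(float(dice_coef(target, empty)))               # ~0.0
print(float(soft_dice_coef(target, 0.5 * perfect)))  # ~0.67, soft predictions are penalized
print(float(dice_loss(target, perfect)))             # ~0.0, since the loss is -log(Dice)
```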
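
Because `build_model` is now a module-level function, the U-Net can be exercised outside a federation as well. A standalone sketch that compiles the model with the same Dice loss and metrics `create_model` uses; the 128x128 single-channel input and the random tensors are purely illustrative (the spatial size must be a multiple of 2^depth = 32), and the workspace-root import assumption applies here too.

```python
import numpy as np
import tensorflow as tf

from src.taskrunner import build_model, dice_coef, soft_dice_coef, dice_loss

# Build and compile the 2D U-Net the same way create_model() does.
model = build_model(input_shape=(128, 128, 1))  # 128 is a multiple of 2**depth = 32
model.compile(
    loss=dice_loss,
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[dice_coef, soft_dice_coef],
)

# One training step on random tensors, just to confirm the graph is wired end to end.
x = np.random.rand(4, 128, 128, 1).astype(np.float32)
y = (np.random.rand(4, 128, 128, 1) > 0.5).astype(np.float32)
model.fit(x, y, batch_size=2, epochs=1, verbose=1)
model.summary(line_length=120)
```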