Updated Keras and TensorFlow Task Runner and related workspaces. #1174

Merged: 57 commits, Dec 13, 2024
Commits
9dc4dad
keras and tf updated
tanwarsh Nov 26, 2024
9c294e1
Merge branch 'securefederatedai:develop' into keras_tf
tanwarsh Nov 26, 2024
ba9924d
skipped rebuild model for round 0
tanwarsh Nov 27, 2024
5547548
formatting issues fixed
tanwarsh Nov 27, 2024
40a869f
added keras in workflow
tanwarsh Nov 27, 2024
dadedd5
pipeline tensorflowversion check
tanwarsh Nov 27, 2024
3c18818
revert changes
tanwarsh Nov 27, 2024
bd9b6f0
revert changes
tanwarsh Nov 27, 2024
b895e18
removed extra line
tanwarsh Nov 27, 2024
b815535
workspace changes
tanwarsh Nov 27, 2024
fde7a56
removed python 3.8 for taskrunner workflow
tanwarsh Nov 27, 2024
93a6674
updated python version to 3.9
tanwarsh Nov 27, 2024
4e96c92
updated python version to 3.9
tanwarsh Nov 27, 2024
bf950ac
saved model name changes in tests
tanwarsh Nov 27, 2024
78093fc
code change to save model
tanwarsh Nov 27, 2024
a00ade2
keras_nlp workspace changes
tanwarsh Nov 27, 2024
40f465d
keras and tf updated
tanwarsh Nov 26, 2024
e493c3f
skipped rebuild model for round 0
tanwarsh Nov 27, 2024
da97430
formatting issues fixed
tanwarsh Nov 27, 2024
33b6bd2
added keras in workflow
tanwarsh Nov 27, 2024
26fdaaa
pipeline tensorflowversion check
tanwarsh Nov 27, 2024
13761c0
revert changes
tanwarsh Nov 27, 2024
89bcc58
revert changes
tanwarsh Nov 27, 2024
59996a7
removed extra line
tanwarsh Nov 27, 2024
3b0edc2
workspace changes
tanwarsh Nov 27, 2024
76653d6
removed python 3.8 for taskrunner workflow
tanwarsh Nov 27, 2024
7cfda5a
updated python version to 3.9
tanwarsh Nov 27, 2024
b59d186
updated python version to 3.9
tanwarsh Nov 27, 2024
a5c8670
saved model name changes in tests
tanwarsh Nov 27, 2024
b50760e
code change to save model
tanwarsh Nov 27, 2024
692b2c6
keras_nlp workspace changes
tanwarsh Nov 27, 2024
5e7eac6
Merge branch 'keras_tf' of https://github.com/tanwarsh/openfl into ke…
tanwarsh Nov 29, 2024
2f94d92
removed duplicate code
tanwarsh Dec 2, 2024
8c3d507
Merge branch 'develop' into keras_tf
tanwarsh Dec 2, 2024
4f81bb4
removed oython 3.8 from taskrunnner e2e workflow
tanwarsh Dec 2, 2024
6eef7c7
code changes as per comments
tanwarsh Dec 3, 2024
1e06db5
Merge branch 'develop' into keras_tf
tanwarsh Dec 3, 2024
b1aeef1
fix for formatting issue
tanwarsh Dec 3, 2024
6568333
keras cnn with compression code changes
tanwarsh Dec 3, 2024
85d1276
Merge branch 'develop' into keras_tf
tanwarsh Dec 4, 2024
25f091d
remove version changes
tanwarsh Dec 5, 2024
375a4d8
revert version changes
tanwarsh Dec 5, 2024
52d6de4
Merge branch 'develop' into keras_tf
tanwarsh Dec 5, 2024
38009f7
Merge branch 'develop' into keras_tf
tanwarsh Dec 6, 2024
919185c
Merge branch 'develop' into keras_tf
tanwarsh Dec 9, 2024
b710640
Merge branch 'develop' into keras_tf
tanwarsh Dec 10, 2024
3e77a2f
Merge branch 'develop' into keras_tf
tanwarsh Dec 11, 2024
b14d914
keep keras runner for tensorflow workspaces
tanwarsh Dec 11, 2024
94f2bc7
formatting issue
tanwarsh Dec 11, 2024
58363a4
removed keras as ke
tanwarsh Dec 11, 2024
7e379b3
comment changes
tanwarsh Dec 11, 2024
6461a8a
code changes
tanwarsh Dec 11, 2024
35d8106
Merge branch 'develop' into keras_tf
tanwarsh Dec 12, 2024
08e94ff
Merge branch 'develop' into keras_tf
tanwarsh Dec 12, 2024
8b5306b
code changes
tanwarsh Dec 12, 2024
3ad6624
code changes
tanwarsh Dec 13, 2024
d3c73d7
code changes
tanwarsh Dec 13, 2024
2 changes: 1 addition & 1 deletion docs/developer_ref/troubleshooting.rst
@@ -10,7 +10,7 @@
 The following is a list of commonly reported issues in Open Federated Learning (|productName|). If you don't see your issue reported here, please submit a `GitHub issue
 <https://github.com/intel/openfl/issues>`_ or contact us directly on `Slack <https://join.slack.com/t/openfl/shared_invite/zt-ovzbohvn-T5fApk05~YS_iZhjJ5yaTw>`_.

-1. I see the error :code:`Cannot import name TensorFlowDataLoader from openfl.federated`
+1. I see the error :code:`Cannot import name KerasDataLoader from openfl.federated`

 |productName| currently uses conditional imports to attempt to be framework agnostic. If your task runner is derived from `KerasTaskRunner` or `TensorflowTaskRunner`, this error could come up if TensorFlow\*\ was not installed in your collaborator's virtual environment. If you are running a multi-node experiment, we recommend using the :code:`fx workspace export` and :code:`fx workspace import` commands, as this will ensure consistent modules between aggregator and collaborators.
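
A quick way to check for the cause described in this troubleshooting entry is to test, inside the collaborator's virtual environment, whether the framework packages can be located at all. The sketch below uses only the standard library; `framework_available` and `missing_frameworks` are illustrative helper names and are not part of OpenFL.

```python
import importlib.util


def framework_available(module_name):
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


def missing_frameworks(module_names=("tensorflow", "keras")):
    """List the modules from module_names that cannot be found in this environment."""
    return [name for name in module_names if not framework_available(name)]


if __name__ == "__main__":
    missing = missing_frameworks()
    if missing:
        print("Not installed in this environment:", ", ".join(missing))
    else:
        print("tensorflow and keras can both be located.")
```

Running this in each collaborator's environment before starting the federation shows whether the conditional import is likely to fail there.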

@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Intel Corporation
+# Copyright (C) 2020-2024 Intel Corporation
 # Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

 # Each key under 'collaborators' corresponds to a specific collaborator name; the corresponding dictionary has data_name, data_path pairs.
@@ -8,4 +8,4 @@
 # collaborator_name,data_directory_path

 one,1
-two,2
+two,2
@@ -5,9 +5,9 @@ aggregator :
   defaults : plan/defaults/aggregator.yaml
   template : openfl.component.Aggregator
   settings :
-    init_state_path : save/tf_2dunet_brats_init.pbuf
-    last_state_path : save/tf_2dunet_brats_latest.pbuf
-    best_state_path : save/tf_2dunet_brats_best.pbuf
+    init_state_path : save/init.pbuf
+    last_state_path : save/latest.pbuf
+    best_state_path : save/best.pbuf
     rounds_to_train : 10
     db_store_rounds : 2

@@ -20,7 +20,7 @@ collaborator :

 data_loader :
   defaults : plan/defaults/data_loader.yaml
-  template : src.tfbrats_inmemory.TensorFlowBratsInMemory
+  template : src.dataloader.KerasBratsInMemory
   settings :
     batch_size: 64
     percent_train: 0.8
@@ -29,7 +29,7 @@ data_loader :

 task_runner :
   defaults : plan/defaults/task_runner.yaml
-  template : src.tf_2dunet.TensorFlow2DUNet
+  template : src.taskrunner.Keras2DUNet

 network :
   defaults : plan/defaults/network.yaml
@@ -38,7 +38,30 @@ assigner :
   defaults : plan/defaults/assigner.yaml

 tasks :
-  defaults : plan/defaults/tasks_tensorflow.yaml
+  aggregated_model_validation:
+    function : validate_task
+    kwargs :
+      batch_size : 32
+      apply : global
+      metrics :
+        - acc
+
+  locally_tuned_model_validation:
+    function : validate_task
+    kwargs :
+      batch_size : 32
+      apply : local
+      metrics :
+        - acc
+
+  train:
+    function : train_task
+    kwargs :
+      batch_size : 32
+      metrics :
+        - loss
+      epochs : 1
+

 compression_pipeline :
-  defaults : plan/defaults/compression_pipeline.yaml
+  defaults : plan/defaults/compression_pipeline.yaml
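
The inlined `tasks` block pairs each task name with a `function` and a set of `kwargs`. As a rough illustration of how such an entry can be resolved, the sketch below looks up the named method on a runner object and calls it with the listed keyword arguments. `ToyRunner` and `run_task` are hypothetical names invented for this example; the real OpenFL runner methods take further arguments that are left out here.

```python
# The three task entries from the plan, written out as a plain Python dict.
TASKS = {
    "aggregated_model_validation": {
        "function": "validate_task",
        "kwargs": {"batch_size": 32, "apply": "global", "metrics": ["acc"]},
    },
    "locally_tuned_model_validation": {
        "function": "validate_task",
        "kwargs": {"batch_size": 32, "apply": "local", "metrics": ["acc"]},
    },
    "train": {
        "function": "train_task",
        "kwargs": {"batch_size": 32, "metrics": ["loss"], "epochs": 1},
    },
}


class ToyRunner:
    """Hypothetical stand-in for a task runner; NOT the OpenFL KerasTaskRunner."""

    def validate_task(self, batch_size, apply, metrics):
        return f"validate[{apply}] batch_size={batch_size} metrics={metrics}"

    def train_task(self, batch_size, metrics, epochs):
        return f"train batch_size={batch_size} epochs={epochs} metrics={metrics}"


def run_task(runner, tasks, task_name):
    """Look up the method named by 'function' and call it with 'kwargs'."""
    spec = tasks[task_name]
    method = getattr(runner, spec["function"])
    return method(**spec["kwargs"])


if __name__ == "__main__":
    for task_name in TASKS:
        print(task_name, "->", run_task(ToyRunner(), TASKS, task_name))
```

Both validation tasks map to the same `validate_task` function and differ only in the `apply` value (`global` versus `local`).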
@@ -1,3 +1,4 @@
+keras==3.6.0
 nibabel
 setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
-tensorflow==2.13
+tensorflow==2.18.0
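
Since this workspace now pins both `keras` and `tensorflow` to exact versions, it can help to confirm that a collaborator's environment matches the pins. The following is a minimal standard-library sketch; `parse_pins` and `check_installed` are illustrative names, not OpenFL utilities, and only `==` pins are considered.

```python
from importlib import metadata

# Post-change contents of the requirements file shown in this diff.
REQUIREMENTS = """\
keras==3.6.0
nibabel
setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability
tensorflow==2.18.0
"""


def parse_pins(requirements_text):
    """Map package name -> exact version for every 'name==version' line."""
    pins = {}
    for raw_line in requirements_text.splitlines():
        line = raw_line.split("#", 1)[0].strip()  # drop trailing comments
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip().lower()] = version.strip()
    return pins


def check_installed(pins):
    """Return {name: (wanted, found)} for every pin the environment does not satisfy."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed at all
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in check_installed(parse_pins(REQUIREMENTS)).items():
        print(f"{name}: wanted {wanted}, found {found}")
```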
@@ -3,12 +3,12 @@

 """You may copy this file as the starting point of your own model."""

-from openfl.federated import TensorFlowDataLoader
+from openfl.federated import KerasDataLoader
 from .brats_utils import load_from_nifti


-class TensorFlowBratsInMemory(TensorFlowDataLoader):
-    """TensorFlow Data Loader for the BraTS dataset."""
+class KerasBratsInMemory(KerasDataLoader):
+    """Keras Data Loader for the BraTS dataset."""

     def __init__(self, data_path, batch_size, percent_train=0.8, pre_split_shuffle=True, **kwargs):
         """Initialize.
158 changes: 158 additions & 0 deletions openfl-workspace/keras_2dunet/src/taskrunner.py
@@ -0,0 +1,158 @@
+# Copyright (C) 2020-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""You may copy this file as the starting point of your own model."""
+
+import tensorflow as tf
+import keras
+
+from openfl.federated import KerasTaskRunner
+
+
+class Keras2DUNet(KerasTaskRunner):
+    """Keras task runner for a 2D U-Net model.
+
+    Args:
+        **kwargs: Additional parameters to pass to the function
+
+    """
+
+    def __init__(self, **kwargs):
+        """Initialize.
+
+        Args:
+            **kwargs: Additional parameters to pass to the function
+
+        """
+        super().__init__(**kwargs)
+
+        self.model = self.build_model(self.data_loader.get_feature_shape(), use_upsampling=True, **kwargs)
+        self.model.summary(print_fn=self.logger.info, line_length=120)
+        self.initialize_tensorkeys_for_functions()
+
+    def build_model(self, input_shape,
+                    use_upsampling=False,
+                    n_cl_out=1,
+                    dropout=0.2,
+                    activation_function='relu',
+                    seed=0xFEEDFACE,
+                    depth=5,
+                    dropout_at=None,
+                    initial_filters=32,
+                    batch_norm=True,
+                    **kwargs):
+        """Define the Keras model.
+
+        Args:
+            input_shape: input shape of the model
+            use_upsampling (bool): True = use bilinear interpolation;
+                False = use transposed convolution (Default=False)
+            n_cl_out (int): Number of channels in output layer (Default=1)
+            dropout (float): Dropout rate (Default=0.2)
+            activation_function: The activation function to use after convolutional layers (Default='relu')
+            seed: random seed (Default=0xFEEDFACE)
+            depth (int): Number of max pooling layers in encoder (Default=5)
+            dropout_at: Layers to perform dropout after (Default=[2,3])
+            initial_filters (int): Number of filters in first convolutional layer (Default=32)
+            batch_norm (bool): True = use batch normalization (Default=True)
+            **kwargs: Additional parameters to pass to the function
+
+        """
+        if dropout_at is None:
+            dropout_at = [2, 3]
+
+        inputs = keras.layers.Input(shape=input_shape, name='Images')
+
+        if activation_function == 'relu':
+            activation = tf.nn.relu
+        elif activation_function == 'leakyrelu':
+            activation = tf.nn.leaky_relu
+        else:
+            raise ValueError(f'Unsupported activation_function: {activation_function!r}')
+
+        params = {
+            'activation': activation,
+            'kernel_initializer': keras.initializers.he_uniform(seed=seed),
+            'kernel_size': (3, 3),
+            'padding': 'same',
+        }
+
+        convb_layers = {}
+
+        net = inputs
+        filters = initial_filters
+        for i in range(depth):
+            name = f'conv{i + 1}a'
+            net = keras.layers.Conv2D(name=name, filters=filters, **params)(net)
+            if i in dropout_at:
+                net = keras.layers.Dropout(dropout)(net)
+            name = f'conv{i + 1}b'
+            net = keras.layers.Conv2D(name=name, filters=filters, **params)(net)
+            if batch_norm:
+                net = keras.layers.BatchNormalization()(net)
+            convb_layers[name] = net
+            # only pool if not last level
+            if i != depth - 1:
+                name = f'pool{i + 1}'
+                net = keras.layers.MaxPooling2D(name=name, pool_size=(2, 2))(net)
+                filters *= 2
+
+        # do the up levels
+        filters //= 2
+        for i in range(depth - 1):
+            if use_upsampling:
+                up = keras.layers.UpSampling2D(
+                    name=f'up{depth + i + 1}', size=(2, 2))(net)
+            else:
+                # layer names must be unique, so number each transposed convolution
+                up = keras.layers.Conv2DTranspose(
+                    name=f'transConv{depth + i + 1}', filters=filters,
+                    kernel_size=(2, 2), strides=(2, 2), padding='same')(net)
+            net = keras.layers.concatenate(
+                [up, convb_layers[f'conv{depth - i - 1}b']],
+                axis=-1
+            )
+            net = keras.layers.Conv2D(
+                name=f'conv{depth + i + 1}a',
+                filters=filters, **params)(net)
+            net = keras.layers.Conv2D(
+                name=f'conv{depth + i + 1}b',
+                filters=filters, **params)(net)
+            filters //= 2
+        net = keras.layers.Conv2D(name='Mask', filters=n_cl_out,
+                                  kernel_size=(1, 1),
+                                  activation='sigmoid')(net)
+        model = keras.models.Model(inputs=[inputs], outputs=[net])
+
+        self.optimizer = keras.optimizers.RMSprop(1e-2)
+        model.compile(
+            loss=self.dice_coef_loss,
+            optimizer=self.optimizer,
+            metrics=["acc"]
+        )
+
+        return model
+
+    def dice_coef_loss(self, y_true, y_pred, smooth=1.0):
+        """Dice coefficient loss.
+
+        Calculate the -log(Dice Coefficient) loss
+
+        Args:
+            y_true: Ground truth annotation array
+            y_pred: Prediction array from model
+            smooth (float): Laplace smoothing factor (Default=1.0)
+        Returns:
+            float: -log(Dice coefficient) metric
+        """
+        intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3))
+
+        term1 = -tf.math.log(tf.constant(2.0) * intersection + smooth)
+        term2 = tf.math.log(tf.reduce_sum(y_true, axis=(1, 2, 3))
+                            + tf.reduce_sum(y_pred, axis=(1, 2, 3)) + smooth)
+
+        term1 = tf.reduce_mean(term1)
+        term2 = tf.reduce_mean(term2)
+
+        loss = term1 + term2
+
+        return loss
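
To make the arithmetic in `dice_coef_loss` easier to follow, here is a dependency-free sketch of the same quantity on small flattened masks. It relies on the fact that averaging `term1` and `term2` separately and adding them equals the batch mean of -log((2·intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)). The function name `dice_coef_loss_reference` and the list-of-lists input layout are assumptions made for this illustration; the task runner itself operates on 4-D tensors.

```python
import math


def dice_coef_loss_reference(y_true, y_pred, smooth=1.0):
    """Mean over the batch of -log(Dice), for samples given as flat lists of floats."""
    per_sample = []
    for truth, pred in zip(y_true, y_pred):
        intersection = sum(t * p for t, p in zip(truth, pred))
        dice = (2.0 * intersection + smooth) / (sum(truth) + sum(pred) + smooth)
        per_sample.append(-math.log(dice))
    return sum(per_sample) / len(per_sample)


if __name__ == "__main__":
    mask = [1.0, 1.0, 0.0, 0.0]
    inverse = [0.0, 0.0, 1.0, 1.0]
    print(dice_coef_loss_reference([mask], [mask]))     # perfect overlap
    print(dice_coef_loss_reference([mask], [inverse]))  # no overlap
```

With `smooth=1.0`, a perfect prediction gives a Dice value of 1 and therefore a loss of 0, while fully disjoint masks of two pixels each give a Dice value of 1/5 and a loss of log(5).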
14 changes: 9 additions & 5 deletions openfl-workspace/keras_cnn_mnist/plan/plan.yaml
@@ -5,9 +5,9 @@ aggregator :
   defaults : plan/defaults/aggregator.yaml
   template : openfl.component.Aggregator
   settings :
-    init_state_path : save/keras_cnn_mnist_init.pbuf
-    best_state_path : save/keras_cnn_mnist_best.pbuf
-    last_state_path : save/keras_cnn_mnist_last.pbuf
+    init_state_path : save/init.pbuf
+    best_state_path : save/best.pbuf
+    last_state_path : save/last.pbuf
     rounds_to_train : 10

 collaborator :
@@ -19,15 +19,15 @@ collaborator :

 data_loader :
   defaults : plan/defaults/data_loader.yaml
-  template : src.tfmnist_inmemory.TensorFlowMNISTInMemory
+  template : src.dataloader.KerasMNISTInMemory
   settings :
     collaborator_count : 2
     data_group_name : mnist
     batch_size : 256

 task_runner :
   defaults : plan/defaults/task_runner.yaml
-  template : src.keras_cnn.KerasCNN
+  template : src.taskrunner.KerasCNN

 network :
   defaults : plan/defaults/network.yaml
@@ -40,3 +40,7 @@ tasks :

 compression_pipeline :
   defaults : plan/defaults/compression_pipeline.yaml
+  # To use a different compression pipeline, uncomment the following lines
+  # template : openfl.pipelines.KCPipeline
+  # settings :
+  #   n_clusters : 6
5 changes: 3 additions & 2 deletions openfl-workspace/keras_cnn_mnist/requirements.txt
@@ -1,2 +1,3 @@
-numpy==1.23.5
-tensorflow==2.13
+keras==3.6.0
+tensorflow==2.18.0
+
@@ -3,12 +3,12 @@

 """You may copy this file as the starting point of your own model."""

-from openfl.federated import TensorFlowDataLoader
+from openfl.federated import KerasDataLoader
 from .mnist_utils import load_mnist_shard


-class TensorFlowMNISTInMemory(TensorFlowDataLoader):
-    """TensorFlow Data Loader for MNIST Dataset."""
+class KerasMNISTInMemory(KerasDataLoader):
+    """Data Loader for MNIST Dataset."""

     def __init__(self, data_path, batch_size, **kwargs):
         """
@@ -1,13 +1,12 @@
-# Copyright (C) 2020-2021 Intel Corporation
+# Copyright (C) 2020-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

 """You may copy this file as the starting point of your own model."""

-import tensorflow.keras as ke
-from tensorflow.keras import Sequential
-from tensorflow.keras.layers import Conv2D
-from tensorflow.keras.layers import Dense
-from tensorflow.keras.layers import Flatten
+from keras.models import Sequential
+from keras.layers import Conv2D
+from keras.layers import Dense
+from keras.layers import Flatten

 from openfl.federated import KerasTaskRunner

@@ -50,7 +49,7 @@ def build_model(self,
             num_classes (int): The number of classes of the dataset

         Returns:
-            tensorflow.python.keras.engine.sequential.Sequential: The model defined in Keras
+            keras.models.Sequential: The model defined in Keras

         """
         model = Sequential()
@@ -72,14 +71,8 @@ def build_model(self,

         model.add(Dense(num_classes, activation='softmax'))

-        model.compile(loss=ke.losses.categorical_crossentropy,
-                      optimizer=ke.optimizers.legacy.Adam(),
-                      metrics=['accuracy'])
-
-        # initialize the optimizer variables
-        opt_vars = model.optimizer.variables()
-
-        for v in opt_vars:
-            v.initializer.run(session=self.sess)
+        model.compile(loss="categorical_crossentropy",
+                      optimizer="adam",
+                      metrics=["accuracy"])

         return model
5 changes: 0 additions & 5 deletions openfl-workspace/keras_cnn_with_compression/plan/cols.yaml

This file was deleted.

7 changes: 0 additions & 7 deletions openfl-workspace/keras_cnn_with_compression/plan/data.yaml

This file was deleted.
