Skip to content

Commit

Permalink
JAX Taskrunner Workspace with Keras 3 (#1334)
Browse files Browse the repository at this point in the history
* created workspace

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* code changes to remove setting keras backend by code

Signed-off-by: yes <[email protected]>

* code changes

Signed-off-by: yes <[email protected]>

* updated docs

Signed-off-by: yes <[email protected]>

---------

Signed-off-by: yes <[email protected]>
  • Loading branch information
tanwarsh authored Feb 7, 2025
1 parent 01e67fc commit 7566a30
Show file tree
Hide file tree
Showing 14 changed files with 418 additions and 14 deletions.
9 changes: 9 additions & 0 deletions docs/about/features_index/taskrunner.rst
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ STEP 1: Create a Workspace
- :code:`tf_cnn_histology`: a workspace with a simple `TensorFlow <http://tensorflow.org>`__ CNN model that will download the `Colorectal Histology <https://zenodo.org/record/53169#.XGZemKwzbmG>`_ dataset and train in a federation.
- :code:`keras/histology`: a workspace with a simple `PyTorch <http://pytorch.org/>`__ CNN model that will download the `Colorectal Histology <https://zenodo.org/record/53169#.XGZemKwzbmG>`_ dataset and train in a federation.
- :code:`torch/mnist`: a workspace with a simple `PyTorch <http://pytorch.org>`__ CNN model that will download the `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset and train in a federation.
- :code:`keras/jax/mnist`: a workspace with a simple `Keras <http://keras.io/>`__ CNN model that will download the `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset and train in a federation with jax as backend. You can export the environment variable KERAS_BACKEND to configure your backend. Available backend options are: "jax", "tensorflow", "torch". Example:

.. code-block:: shell
$ export KERAS_BACKEND="jax"
.. note::

Please ensure KERAS_BACKEND is set in the environment where you plan on using OpenFL before executing any fx command.

See the complete list of available templates.

Expand Down
2 changes: 2 additions & 0 deletions openfl-workspace/keras/jax/mnist/.workspace
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
current_plan_name: default

5 changes: 5 additions & 0 deletions openfl-workspace/keras/jax/mnist/plan/cols.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:

7 changes: 7 additions & 0 deletions openfl-workspace/keras/jax/mnist/plan/data.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

# collaborator_name,data_directory_path
one,1


2 changes: 2 additions & 0 deletions openfl-workspace/keras/jax/mnist/plan/defaults
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
../../workspace/plan/defaults

42 changes: 42 additions & 0 deletions openfl-workspace/keras/jax/mnist/plan/plan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

aggregator :
defaults : plan/defaults/aggregator.yaml
template : openfl.component.Aggregator
settings :
init_state_path : save/init.pbuf
best_state_path : save/best.pbuf
last_state_path : save/last.pbuf
rounds_to_train : 10

collaborator :
defaults : plan/defaults/collaborator.yaml
template : openfl.component.Collaborator
settings :
delta_updates : false
opt_treatment : RESET

data_loader :
defaults : plan/defaults/data_loader.yaml
template : src.dataloader.JAXMNISTInMemory
settings :
collaborator_count : 2
data_group_name : mnist
batch_size : 256

task_runner :
defaults : plan/defaults/task_runner.yaml
template : src.taskrunner.JAXCNN

network :
defaults : plan/defaults/network.yaml

assigner :
defaults : plan/defaults/assigner.yaml

tasks :
defaults : plan/defaults/tasks_keras.yaml

compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml
3 changes: 3 additions & 0 deletions openfl-workspace/keras/jax/mnist/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
jax==0.5.0
keras==3.8.0
tensorflow==2.18.0
3 changes: 3 additions & 0 deletions openfl-workspace/keras/jax/mnist/src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""You may copy this file as the starting point of your own model."""
42 changes: 42 additions & 0 deletions openfl-workspace/keras/jax/mnist/src/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from openfl.federated import KerasDataLoader
from .mnist_utils import load_mnist_shard


class JAXMNISTInMemory(KerasDataLoader):
"""Data Loader for MNIST Dataset."""

def __init__(self, data_path, batch_size, **kwargs):
"""
Initialize.
Args:
data_path: File path for the dataset
batch_size (int): The batch size for the data loader
**kwargs: Additional arguments, passed to super init and load_mnist_shard
"""
super().__init__(batch_size, **kwargs)

try:
int(data_path)
except:
raise ValueError(
"Expected `%s` to be representable as `int`, as it refers to the data shard " +
"number used by the collaborator.",
data_path
)

_, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
shard_num=int(data_path), **kwargs
)

self.X_train = X_train
self.y_train = y_train
self.X_valid = X_valid
self.y_valid = y_valid

self.num_classes = num_classes
107 changes: 107 additions & 0 deletions openfl-workspace/keras/jax/mnist/src/mnist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from logging import getLogger

import numpy as np
import keras

logger = getLogger(__name__)


def one_hot(labels, classes):
"""
One Hot encode a vector.
Args:
labels (list): List of labels to onehot encode
classes (int): Total number of categorical classes
Returns:
np.array: Matrix of one-hot encoded labels
"""
return np.eye(classes)[labels]


def _load_raw_datashards(shard_num, collaborator_count):
"""
Load the raw data by shard.
Returns tuples of the dataset shard divided into training and validation.
Args:
shard_num (int): The shard number to use
collaborator_count (int): The number of collaborators in the federation
Returns:
2 tuples: (image, label) of the training, validation dataset
"""
(X_train_tot, y_train_tot), (X_valid_tot, y_valid_tot) = keras.datasets.mnist.load_data()

# create the shards
shard_num = int(shard_num)
X_train = X_train_tot[shard_num::collaborator_count]
y_train = y_train_tot[shard_num::collaborator_count]

X_valid = X_valid_tot[shard_num::collaborator_count]
y_valid = y_valid_tot[shard_num::collaborator_count]

return (X_train, y_train), (X_valid, y_valid)


def load_mnist_shard(shard_num, collaborator_count, categorical=True,
channels_last=True, **kwargs):
"""
Load the MNIST dataset.
Args:
shard_num (int): The shard to use from the dataset
collaborator_count (int): The number of collaborators in the federation
categorical (bool): True = convert the labels to one-hot encoded
vectors (Default = True)
channels_last (bool): True = The input images have the channels
last (Default = True)
**kwargs: Additional parameters to pass to the function
Returns:
list: The input shape
int: The number of classes
numpy.ndarray: The training data
numpy.ndarray: The training labels
numpy.ndarray: The validation data
numpy.ndarray: The validation labels
"""
img_rows, img_cols = 28, 28
num_classes = 10

(X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
shard_num, collaborator_count
)

if channels_last:
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
else:
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_valid = X_valid.reshape(X_valid.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)

X_train = X_train.astype('float32')
X_valid = X_valid.astype('float32')
X_train /= 255
X_valid /= 255

logger.info(f'MNIST > X_train Shape : {X_train.shape}')
logger.info(f'MNIST > y_train Shape : {y_train.shape}')
logger.info(f'MNIST > Train Samples : {X_train.shape[0]}')
logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}')

if categorical:
# convert class vectors to binary class matrices
y_train = one_hot(y_train, num_classes)
y_valid = one_hot(y_valid, num_classes)

return input_shape, num_classes, X_train, y_train, X_valid, y_valid
120 changes: 120 additions & 0 deletions openfl-workspace/keras/jax/mnist/src/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""
import jax
import keras

class CNNModel(keras.Model):
def compute_loss_and_updates(
self,
trainable_variables,
non_trainable_variables,
x,
y,
training=False,
):
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables,
non_trainable_variables,
x,
training=training,
)
loss = self.compute_loss(x, y, y_pred)
return loss, (y_pred, non_trainable_variables)

def train_step(self, state, data):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
) = state
x, y = data

# Get the gradient function.
grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

# Compute the gradients.
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
trainable_variables,
non_trainable_variables,
x,
y,
training=True,
)

# Update trainable variables and optimizer variables.
(
trainable_variables,
optimizer_variables,
) = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)

# Update metrics.
new_metrics_vars = []
logs = {}
for metric in self.metrics:
this_metric_vars = metrics_variables[
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
]
if metric.name == "loss":
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
else:
this_metric_vars = metric.stateless_update_state(
this_metric_vars, y, y_pred
)
logs[metric.name] = metric.stateless_result(this_metric_vars)
new_metrics_vars += this_metric_vars

# Return metric logs and updated state variables.
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
new_metrics_vars,
)
return logs, state

def test_step(self, state, data):
# Unpack the data.
x, y = data
(
trainable_variables,
non_trainable_variables,
metrics_variables,
) = state

# Compute predictions and loss.
y_pred, non_trainable_variables = self.stateless_call(
trainable_variables,
non_trainable_variables,
x,
training=False,
)
loss = self.compute_loss(x, y, y_pred)

# Update metrics.
new_metrics_vars = []
logs = {}
for metric in self.metrics:
this_metric_vars = metrics_variables[
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
]
if metric.name == "loss":
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
else:
this_metric_vars = metric.stateless_update_state(
this_metric_vars, y, y_pred
)
logs[metric.name] = metric.stateless_result(this_metric_vars)
new_metrics_vars += this_metric_vars

# Return metric logs and updated state variables.
state = (
trainable_variables,
non_trainable_variables,
new_metrics_vars,
)
return logs, state
Loading

0 comments on commit 7566a30

Please sign in to comment.