JAX Taskrunner Workspace with Keras 3 #1334

Merged
merged 20 commits into from
Feb 7, 2025
9 changes: 9 additions & 0 deletions docs/about/features_index/taskrunner.rst
@@ -162,6 +162,15 @@ STEP 1: Create a Workspace
- :code:`tf_cnn_histology`: a workspace with a simple `TensorFlow <http://tensorflow.org>`__ CNN model that will download the `Colorectal Histology <https://zenodo.org/record/53169#.XGZemKwzbmG>`_ dataset and train in a federation.
- :code:`torch/histology`: a workspace with a simple `PyTorch <http://pytorch.org/>`__ CNN model that will download the `Colorectal Histology <https://zenodo.org/record/53169#.XGZemKwzbmG>`_ dataset and train in a federation.
- :code:`torch/mnist`: a workspace with a simple `PyTorch <http://pytorch.org>`__ CNN model that will download the `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset and train in a federation.
- :code:`keras/jax/mnist`: a workspace with a simple `Keras <http://keras.io/>`__ CNN model that will download the `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset and train in a federation with JAX as the backend. Set the :code:`KERAS_BACKEND` environment variable to choose the backend; the available options are "jax", "tensorflow", and "torch". Example:

.. code-block:: shell

    $ export KERAS_BACKEND="jax"

.. note::

    Ensure that :code:`KERAS_BACKEND` is set in the environment where you plan to use OpenFL before executing any :code:`fx` command.

See the complete list of available templates.
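The note above matters because Keras 3 picks its backend when `keras` is first imported, so the variable has to be in place beforehand. A minimal sketch of a pre-flight check follows. `resolve_keras_backend` and `VALID_BACKENDS` are hypothetical names, not part of OpenFL or Keras, and the check is deliberately stricter than Keras itself, which falls back to its configured default when the variable is unset.

```python
import os

# The three backend names listed in the documentation above.
VALID_BACKENDS = ("jax", "tensorflow", "torch")


def resolve_keras_backend(environ=None):
    """Return the backend named by KERAS_BACKEND, or raise if unset/unsupported.

    Hypothetical helper: `environ` defaults to os.environ and can be replaced
    by a plain dict for testing.
    """
    environ = os.environ if environ is None else environ
    backend = environ.get("KERAS_BACKEND")
    if backend not in VALID_BACKENDS:
        raise RuntimeError(
            f"KERAS_BACKEND={backend!r}; expected one of {VALID_BACKENDS}"
        )
    return backend
```

Calling such a check at the top of a launch script, before any `import keras`, would surface a missing or misspelled backend early instead of silently training on a different one.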

2 changes: 2 additions & 0 deletions openfl-workspace/keras/jax/mnist/.workspace
@@ -0,0 +1,2 @@
current_plan_name: default

5 changes: 5 additions & 0 deletions openfl-workspace/keras/jax/mnist/plan/cols.yaml
@@ -0,0 +1,5 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:

7 changes: 7 additions & 0 deletions openfl-workspace/keras/jax/mnist/plan/data.yaml
@@ -0,0 +1,7 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

# collaborator_name,data_directory_path
one,1
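Each non-comment line of `data.yaml` pairs a collaborator name with a data path, and in this workspace the path (`1`) is later cast to `int` and used as the MNIST shard number. The sketch below shows one way such a file could be read. `parse_data_assignments` is a hypothetical helper written for illustration and is not OpenFL's own parser.

```python
def parse_data_assignments(text):
    """Map collaborator names to data paths from `name,path` lines.

    Blank lines and lines starting with `#` are skipped.
    """
    assignments = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        # Split on the first comma only, so a path may itself contain commas.
        name, path = line.split(",", maxsplit=1)
        assignments[name.strip()] = path.strip()
    return assignments


example = "# collaborator_name,data_directory_path\none,1\n"
assignments = parse_data_assignments(example)
```

Here `assignments` maps `"one"` to the string `"1"`; converting that string to a shard index is left to the data loader.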


2 changes: 2 additions & 0 deletions openfl-workspace/keras/jax/mnist/plan/defaults
@@ -0,0 +1,2 @@
../../workspace/plan/defaults

42 changes: 42 additions & 0 deletions openfl-workspace/keras/jax/mnist/plan/plan.yaml
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

aggregator :
  defaults : plan/defaults/aggregator.yaml
  template : openfl.component.Aggregator
  settings :
    init_state_path : save/init.pbuf
    best_state_path : save/best.pbuf
    last_state_path : save/last.pbuf
    rounds_to_train : 10

collaborator :
  defaults : plan/defaults/collaborator.yaml
  template : openfl.component.Collaborator
  settings :
    delta_updates : false
    opt_treatment : RESET

data_loader :
  defaults : plan/defaults/data_loader.yaml
  template : src.dataloader.JAXMNISTInMemory
  settings :
    collaborator_count : 2
    data_group_name : mnist
    batch_size : 256

task_runner :
  defaults : plan/defaults/task_runner.yaml
  template : src.taskrunner.JAXCNN

network :
  defaults : plan/defaults/network.yaml

assigner :
  defaults : plan/defaults/assigner.yaml

tasks :
  defaults : plan/defaults/tasks_keras.yaml

compression_pipeline :
  defaults : plan/defaults/compression_pipeline.yaml
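The `template` entries in the plan are dotted Python paths such as `src.dataloader.JAXMNISTInMemory`. As a rough illustration of how a dotted path can be turned into a class, here is a sketch using only the standard library. `resolve_template` is a hypothetical name, and OpenFL's actual plan loader may work differently.

```python
import importlib


def resolve_template(dotted_path):
    """Import `package.module.Name` and return the object called `Name`."""
    module_name, _, attr_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted path, got {dotted_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Demonstrated with a standard-library class so the sketch runs anywhere.
ordered_dict_cls = resolve_template("collections.OrderedDict")
```

Under this scheme, a typo in a `template` value would surface as an `ImportError` or `AttributeError` when the plan is loaded, not when the YAML is parsed.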
3 changes: 3 additions & 0 deletions openfl-workspace/keras/jax/mnist/requirements.txt
@@ -0,0 +1,3 @@
jax==0.5.0
keras==3.8.0
tensorflow==2.18.0
3 changes: 3 additions & 0 deletions openfl-workspace/keras/jax/mnist/src/__init__.py
@@ -0,0 +1,3 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""You may copy this file as the starting point of your own model."""
42 changes: 42 additions & 0 deletions openfl-workspace/keras/jax/mnist/src/dataloader.py
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from openfl.federated import KerasDataLoader
from .mnist_utils import load_mnist_shard


class JAXMNISTInMemory(KerasDataLoader):
    """Data Loader for MNIST Dataset."""

    def __init__(self, data_path, batch_size, **kwargs):
        """
        Initialize.

        Args:
            data_path: The data shard number used by the collaborator
                (must be representable as `int`)
            batch_size (int): The batch size for the data loader
            **kwargs: Additional arguments, passed to super init and load_mnist_shard
        """
        super().__init__(batch_size, **kwargs)

        try:
            shard_num = int(data_path)
        except (TypeError, ValueError) as exc:
            # Build the message with an f-string: extra positional arguments
            # to ValueError are not interpolated into a `%s` placeholder.
            raise ValueError(
                f"Expected `{data_path}` to be representable as `int`, as it refers to "
                "the data shard number used by the collaborator."
            ) from exc

        _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
            shard_num=shard_num, **kwargs
        )

        self.X_train = X_train
        self.y_train = y_train
        self.X_valid = X_valid
        self.y_valid = y_valid

        self.num_classes = num_classes
107 changes: 107 additions & 0 deletions openfl-workspace/keras/jax/mnist/src/mnist_utils.py
@@ -0,0 +1,107 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from logging import getLogger

import numpy as np
import keras

logger = getLogger(__name__)


def one_hot(labels, classes):
    """
    One Hot encode a vector.

    Args:
        labels (list): List of labels to onehot encode
        classes (int): Total number of categorical classes

    Returns:
        np.array: Matrix of one-hot encoded labels
    """
    return np.eye(classes)[labels]


def _load_raw_datashards(shard_num, collaborator_count):
    """
    Load the raw data by shard.

    Returns tuples of the dataset shard divided into training and validation.

    Args:
        shard_num (int): The shard number to use
        collaborator_count (int): The number of collaborators in the federation

    Returns:
        2 tuples: (image, label) of the training, validation dataset
    """
    (X_train_tot, y_train_tot), (X_valid_tot, y_valid_tot) = keras.datasets.mnist.load_data()

    # create the shards
    shard_num = int(shard_num)
    X_train = X_train_tot[shard_num::collaborator_count]
    y_train = y_train_tot[shard_num::collaborator_count]

    X_valid = X_valid_tot[shard_num::collaborator_count]
    y_valid = y_valid_tot[shard_num::collaborator_count]

    return (X_train, y_train), (X_valid, y_valid)


def load_mnist_shard(shard_num, collaborator_count, categorical=True,
                     channels_last=True, **kwargs):
    """
    Load the MNIST dataset.

    Args:
        shard_num (int): The shard to use from the dataset
        collaborator_count (int): The number of collaborators in the federation
        categorical (bool): True = convert the labels to one-hot encoded
            vectors (Default = True)
        channels_last (bool): True = The input images have the channels
            last (Default = True)
        **kwargs: Additional parameters to pass to the function

    Returns:
        tuple: The input shape
        int: The number of classes
        numpy.ndarray: The training data
        numpy.ndarray: The training labels
        numpy.ndarray: The validation data
        numpy.ndarray: The validation labels
    """
    img_rows, img_cols = 28, 28
    num_classes = 10

    (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
        shard_num, collaborator_count
    )

    if channels_last:
        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
        X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)
    else:
        X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
        X_valid = X_valid.reshape(X_valid.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)

    # scale pixel values from [0, 255] to [0, 1]
    X_train = X_train.astype('float32')
    X_valid = X_valid.astype('float32')
    X_train /= 255
    X_valid /= 255

    logger.info(f'MNIST > X_train Shape : {X_train.shape}')
    logger.info(f'MNIST > y_train Shape : {y_train.shape}')
    logger.info(f'MNIST > Train Samples : {X_train.shape[0]}')
    logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}')

    if categorical:
        # convert class vectors to binary class matrices
        y_train = one_hot(y_train, num_classes)
        y_valid = one_hot(y_valid, num_classes)

    return input_shape, num_classes, X_train, y_train, X_valid, y_valid
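The sharding in `_load_raw_datashards` relies on the strided slice `[shard_num::collaborator_count]`. The sketch below reproduces that slice on a plain Python list (NumPy arrays slice the same way along the first axis) to show that the shards are disjoint and together cover the data, as long as `shard_num` is less than `collaborator_count`. The `shard` helper and the toy `samples` list are illustrative only.

```python
def shard(items, shard_num, collaborator_count):
    """Return every `collaborator_count`-th item, starting at index `shard_num`."""
    return items[shard_num::collaborator_count]


samples = list(range(10))  # stand-in for ten dataset rows
shards = [shard(samples, i, 2) for i in range(2)]

# A shard number outside range(collaborator_count) still returns data,
# but that data overlaps with another collaborator's shard.
out_of_range = shard(samples, 2, 2)
```

With `collaborator_count : 2` in the plan, the valid shard numbers are therefore `0` and `1`; an out-of-range value would not fail, it would quietly duplicate part of another shard.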
120 changes: 120 additions & 0 deletions openfl-workspace/keras/jax/mnist/src/model.py
@@ -0,0 +1,120 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""
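import jax
import keras


class CNNModel(keras.Model):
    """Keras model with custom stateless train and test steps for the JAX backend."""

    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        """Run a stateless forward pass and return the loss plus auxiliary outputs."""
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.compute_loss(x, y, y_pred)
        # The second element is auxiliary data for jax.value_and_grad(has_aux=True).
        return loss, (y_pred, non_trainable_variables)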

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

    def test_step(self, state, data):
        # Unpack the data.
        x, y = data
        (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = state

        # Compute predictions and loss.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=False,
        )
        loss = self.compute_loss(x, y, y_pred)

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            new_metrics_vars,
        )
        return logs, state
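Both `train_step` and `test_step` receive the metric state as one flat list and carve it into per-metric chunks using a running offset. The sketch below isolates that bookkeeping in plain Python. `FakeMetric` and `split_metric_state` are hypothetical stand-ins written for illustration; they are not Keras classes, and the strings in `flat` merely label where each variable would sit.

```python
class FakeMetric:
    """Hypothetical stand-in for a Keras metric that owns `size` state variables."""

    def __init__(self, name, size):
        self.name = name
        self.variables = [0.0] * size


def split_metric_state(metrics, flat_vars):
    """Slice a flat list of metric variables into one chunk per metric.

    Mirrors the offset arithmetic in train_step/test_step: each metric takes
    the next len(metric.variables) entries of the flat list.
    """
    chunks = {}
    offset = 0
    for metric in metrics:
        end = offset + len(metric.variables)
        chunks[metric.name] = flat_vars[offset:end]
        offset = end
    return chunks


metrics = [FakeMetric("loss", 2), FakeMetric("accuracy", 2)]
flat = ["loss_total", "loss_count", "acc_total", "acc_count"]
chunks = split_metric_state(metrics, flat)
```

The scheme depends on the flat list being ordered the same way as `self.metrics` and on each stateless update returning as many variables as it received, which is why the steps rebuild `new_metrics_vars` in the same loop order.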