JAX Taskrunner Workspace with Keras 3 (#1334)
* created workspace
* code changes (repeated intermediate commits)
* code changes to remove setting keras backend by code
* updated docs

---------

Signed-off-by: yes <[email protected]>
Showing 14 changed files with 418 additions and 14 deletions.
.workspace
@@ -0,0 +1,2 @@
current_plan_name: default
plan/cols.yaml
@@ -0,0 +1,5 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:
plan/data.yaml
@@ -0,0 +1,7 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

# collaborator_name,data_directory_path
one,1
plan/defaults
@@ -0,0 +1,2 @@
../../workspace/plan/defaults
plan/plan.yaml
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

aggregator :
  defaults : plan/defaults/aggregator.yaml
  template : openfl.component.Aggregator
  settings :
    init_state_path : save/init.pbuf
    best_state_path : save/best.pbuf
    last_state_path : save/last.pbuf
    rounds_to_train : 10

collaborator :
  defaults : plan/defaults/collaborator.yaml
  template : openfl.component.Collaborator
  settings :
    delta_updates : false
    opt_treatment : RESET

data_loader :
  defaults : plan/defaults/data_loader.yaml
  template : src.dataloader.JAXMNISTInMemory
  settings :
    collaborator_count : 2
    data_group_name : mnist
    batch_size : 256

task_runner :
  defaults : plan/defaults/task_runner.yaml
  template : src.taskrunner.JAXCNN

network :
  defaults : plan/defaults/network.yaml

assigner :
  defaults : plan/defaults/assigner.yaml

tasks :
  defaults : plan/defaults/tasks_keras.yaml

compression_pipeline :
  defaults : plan/defaults/compression_pipeline.yaml
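The `template` entries in this plan are dotted import paths that OpenFL resolves to classes shipped with the workspace, such as `src.dataloader.JAXMNISTInMemory` and `src.taskrunner.JAXCNN` defined in the files below. A minimal sketch of that kind of dotted-path resolution, purely illustrative and not OpenFL's actual plan loader:

# Illustrative sketch only (not OpenFL's plan parser): resolving a dotted template
# string such as "src.dataloader.JAXMNISTInMemory" to a Python class.
import importlib

def resolve_template(template: str):
    module_path, class_name = template.rsplit(".", 1)  # "src.dataloader", "JAXMNISTInMemory"
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# With the workspace root on sys.path, this would return the data loader class
# defined in src/dataloader.py below:
# loader_cls = resolve_template("src.dataloader.JAXMNISTInMemory")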
requirements.txt
@@ -0,0 +1,3 @@
jax==0.5.0
keras==3.8.0
tensorflow==2.18.0
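The commit message mentions removing the in-code setting of the Keras backend; with Keras 3, the backend is normally chosen through the `KERAS_BACKEND` environment variable before `keras` is first imported. A minimal sketch of that assumption (not part of the diff):

# Assumption: the JAX backend is selected via environment variable rather than in code.
import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before the first `import keras`

import keras

print(keras.backend.backend())  # expected to print "jax"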
src/__init__.py
@@ -0,0 +1,3 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""You may copy this file as the starting point of your own model."""
src/dataloader.py
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from openfl.federated import KerasDataLoader
from .mnist_utils import load_mnist_shard


class JAXMNISTInMemory(KerasDataLoader):
    """Data Loader for MNIST Dataset."""

    def __init__(self, data_path, batch_size, **kwargs):
        """
        Initialize.

        Args:
            data_path: File path for the dataset
            batch_size (int): The batch size for the data loader
            **kwargs: Additional arguments, passed to super init and load_mnist_shard
        """
        super().__init__(batch_size, **kwargs)

        try:
            int(data_path)
        except (TypeError, ValueError):
            raise ValueError(
                f"Expected `{data_path}` to be representable as `int`, as it refers to "
                "the data shard number used by the collaborator."
            )

        _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
            shard_num=int(data_path), **kwargs
        )

        self.X_train = X_train
        self.y_train = y_train
        self.X_valid = X_valid
        self.y_valid = y_valid

        self.num_classes = num_classes
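For orientation, a hedged sketch (not part of the commit) of how this loader could be constructed directly with the same values the plan supplies; in a real federation the Plan builds it from plan/plan.yaml and plan/data.yaml:

# Illustrative only: mirrors the plan settings (batch_size=256, two collaborators) and the
# data.yaml entry "one,1", where the data path is the shard number.
from src.dataloader import JAXMNISTInMemory

loader = JAXMNISTInMemory(
    data_path="1",         # shard number for this collaborator
    batch_size=256,
    collaborator_count=2,  # forwarded to load_mnist_shard via **kwargs
)
print(loader.X_train.shape, loader.num_classes)  # e.g. (30000, 28, 28, 1) 10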
src/mnist_utils.py
@@ -0,0 +1,107 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from logging import getLogger

import numpy as np
import keras

logger = getLogger(__name__)


def one_hot(labels, classes):
    """
    One Hot encode a vector.

    Args:
        labels (list): List of labels to onehot encode
        classes (int): Total number of categorical classes

    Returns:
        np.array: Matrix of one-hot encoded labels
    """
    return np.eye(classes)[labels]


def _load_raw_datashards(shard_num, collaborator_count):
    """
    Load the raw data by shard.

    Returns tuples of the dataset shard divided into training and validation.

    Args:
        shard_num (int): The shard number to use
        collaborator_count (int): The number of collaborators in the federation

    Returns:
        2 tuples: (image, label) of the training, validation dataset
    """
    (X_train_tot, y_train_tot), (X_valid_tot, y_valid_tot) = keras.datasets.mnist.load_data()

    # create the shards
    shard_num = int(shard_num)
    X_train = X_train_tot[shard_num::collaborator_count]
    y_train = y_train_tot[shard_num::collaborator_count]

    X_valid = X_valid_tot[shard_num::collaborator_count]
    y_valid = y_valid_tot[shard_num::collaborator_count]

    return (X_train, y_train), (X_valid, y_valid)


def load_mnist_shard(shard_num, collaborator_count, categorical=True,
                     channels_last=True, **kwargs):
    """
    Load the MNIST dataset.

    Args:
        shard_num (int): The shard to use from the dataset
        collaborator_count (int): The number of collaborators in the federation
        categorical (bool): True = convert the labels to one-hot encoded
            vectors (Default = True)
        channels_last (bool): True = The input images have the channels
            last (Default = True)
        **kwargs: Additional parameters to pass to the function

    Returns:
        list: The input shape
        int: The number of classes
        numpy.ndarray: The training data
        numpy.ndarray: The training labels
        numpy.ndarray: The validation data
        numpy.ndarray: The validation labels
    """
    img_rows, img_cols = 28, 28
    num_classes = 10

    (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
        shard_num, collaborator_count
    )

    if channels_last:
        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
        X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)
    else:
        X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
        X_valid = X_valid.reshape(X_valid.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)

    X_train = X_train.astype('float32')
    X_valid = X_valid.astype('float32')
    X_train /= 255
    X_valid /= 255

    logger.info(f'MNIST > X_train Shape : {X_train.shape}')
    logger.info(f'MNIST > y_train Shape : {y_train.shape}')
    logger.info(f'MNIST > Train Samples : {X_train.shape[0]}')
    logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}')

    if categorical:
        # convert class vectors to binary class matrices
        y_train = one_hot(y_train, num_classes)
        y_valid = one_hot(y_valid, num_classes)

    return input_shape, num_classes, X_train, y_train, X_valid, y_valid
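As a quick reference, a usage sketch (not from the diff) showing what `load_mnist_shard` returns for the two-collaborator split configured in the plan:

# Illustrative only: shard 0 of a 2-collaborator federation.
from src.mnist_utils import load_mnist_shard

input_shape, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
    shard_num=0, collaborator_count=2
)
print(input_shape)    # (28, 28, 1) with the default channels_last=True
print(num_classes)    # 10
print(X_train.shape)  # (30000, 28, 28, 1): every 2nd of MNIST's 60000 training images
print(y_train.shape)  # (30000, 10) after one-hot encoding (categorical=True)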
src/taskrunner.py
@@ -0,0 +1,120 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""
import jax
import keras


class CNNModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.compute_loss(x, y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

    def test_step(self, state, data):
        # Unpack the data.
        x, y = data
        (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = state

        # Compute predictions and loss.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=False,
        )
        loss = self.compute_loss(x, y, y_pred)

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            new_metrics_vars,
        )
        return logs, state
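The diff for this file appears only partially above (the hunk header declares 120 new lines). As a rough illustration of how a `keras.Model` subclass like `CNNModel` is typically built and compiled, here is a hedged sketch; the layer sizes and hyperparameters are illustrative assumptions, not the commit's actual `JAXCNN` task runner:

# Illustrative sketch only; the real JAXCNN task runner in this commit defines its own model.
import keras

from src.taskrunner import CNNModel

# Functional-style construction works for a keras.Model subclass without a custom __init__.
inputs = keras.Input(shape=(28, 28, 1))
x = keras.layers.Conv2D(16, kernel_size=3, activation="relu")(inputs)
x = keras.layers.MaxPooling2D(pool_size=2)(x)
x = keras.layers.Flatten()(x)
outputs = keras.layers.Dense(10, activation="softmax")(x)

model = CNNModel(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
# With KERAS_BACKEND=jax, model.fit(...) dispatches to the stateless train_step defined above.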