-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move interactive API examples to openfl-contrib
As we deprecate the interactive API by a more generalized form with the workflow API, we decided to keep the examples in this repo.
- Loading branch information
Showing
274 changed files
with
25,045 additions
and
0 deletions.
There are no files selected for viewing
106 changes: 106 additions & 0 deletions
106
openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Federated FLAX CIFAR-10 CNN Tutorial | ||
|
||
### 1. About dataset | ||
|
||
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. | ||
|
||
Define the below param in envoy.yaml config to shard the dataset across participants/envoy. | ||
- rank_worldsize | ||
|
||
### 2. About model | ||
|
||
A simple multi-layer CNN is used with XLA compiled and Auto-grad based parameter updates. | ||
Definition provided in the notebook. | ||
|
||
### 3. Notebook Overview | ||
|
||
1. Class `CustomTrainState` - Subclasses `flax.training.TrainState` | ||
- Variable `opt_vars` to keep track of generic optimizer variables. | ||
- Method `update_state` to update the OpenFL `ModelInterface` registered state with the new_state returned within the `TaskInterface` registered training loop. | ||
|
||
2. Method `create_train_state`: Creates a new `TrainState` by encapsulating model layer definitions, random model parameters, and optax optimizer state. | ||
|
||
3. Method `apply_model` (`@jax.jit` decorated function): It takes a TrainState, images, and labels as parameters. It computes and returns the gradients, loss, and accuracy. These gradients are applied to a given state in the `update_model` method (`@jax.jit` decorated function) and a new TrainState instance is returned. | ||
|
||
### 4. How to run this tutorial (without TLS and locally as a simulation): | ||
|
||
0. Pre-requisites: | ||
|
||
- Nvidia Driver >= 495.29.05 | ||
- CUDA >= 11.1.105 | ||
- cuDNN >= 8 | ||
|
||
Activate virtual environment (Python - 3.8.10) and install packages from requirements.txt | ||
|
||
Set the variable `DEFAULT_DEVICE to 'CPU' or 'GPU'` in `start_envoy.sh` and notebook to enforce/control the execution platform. | ||
|
||
```sh | ||
cd Flax_CNN_CIFAR | ||
pip install -r requirements.txt | ||
``` | ||
|
||
1. Run director: | ||
|
||
```sh | ||
cd director | ||
./start_director.sh | ||
``` | ||
|
||
2. Run envoy: | ||
|
||
```sh | ||
cd envoy | ||
./start_envoy.sh "envoy_identifier" envoy_config.yaml | ||
``` | ||
|
||
Optional: start second envoy: | ||
|
||
- Copy `envoy` folder to another place and follow the same process as above: | ||
|
||
```sh | ||
./start_envoy.sh "envoy_identifier_2" envoy_config_2.yaml | ||
``` | ||
|
||
3. Run `FLAX_CIFAR10_CNN.ipynb` jupyter notebook: | ||
|
||
```sh | ||
cd workspace | ||
jupyter lab FLAX_CIFAR10_CNN.ipynb | ||
``` | ||
|
||
4. Visualization: | ||
|
||
``` | ||
tensorboard --logdir logs/ | ||
``` | ||
|
||
|
||
### 5. Known issues | ||
|
||
1. #### CUDA_ERROR_OUT_OF_MEMORY Exception - JAX XLA pre-allocates 90% of the GPU at start | ||
|
||
- set XLA_PYTHON_CLIENT_PREALLOCATE to start with a small memory footprint. | ||
``` | ||
%env XLA_PYTHON_CLIENT_PREALLOCATE=false | ||
``` | ||
OR | ||
|
||
- Below flag to restrict max GPU allocation to 50% | ||
``` | ||
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.5 | ||
``` | ||
|
||
|
||
2. #### Tensorflow pre-allocates 90% of the GPU (Potential OOM Errors). | ||
|
||
- set TF_FORCE_GPU_ALLOW_GROWTH to start with a small memory footprint. | ||
``` | ||
%env TF_FORCE_GPU_ALLOW_GROWTH=true | ||
``` | ||
|
||
3. #### DNN library Not found error | ||
|
||
- Make sure the jaxlib(cuda version), Nvidia Driver, CUDA and cuDNN versions are specific, relevant and compatible as per the documentation. | ||
- Reference: | ||
- CUDA and cuDNN Compatibility Matrix: https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html | ||
- Official JAX Compatible CUDA Releases: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html |
6 changes: 6 additions & 0 deletions
6
openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/director/director_config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
settings: | ||
listen_host: localhost | ||
listen_port: 50055 | ||
sample_shape: ['32', '32', '3'] # [[shape], channel] | ||
target_shape: ['1'] | ||
envoy_health_check_period: 5 # in seconds |
4 changes: 4 additions & 0 deletions
4
openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/director/start_director.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#!/bin/bash | ||
set -e | ||
|
||
fx director start --disable-tls -c director_config.yaml |
111 changes: 111 additions & 0 deletions
111
openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/envoy/cifar10_shard_descriptor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright (C) 2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""CIFAR10 Shard Descriptor (using `TFDS` API)""" | ||
import jax.numpy as jnp | ||
import logging | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
|
||
from typing import List, Tuple | ||
from openfl.interface.interactive_api.shard_descriptor import ShardDescriptor | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CIFAR10ShardDescriptor(ShardDescriptor): | ||
""" | ||
CIFAR10 Shard Descriptor | ||
This example is based on `tfds` data loader. | ||
Note that the ingestion of any model/task requires an iterable dataloader. | ||
Hence, it is possible to utilize these pipelines without explicit need of a | ||
new interface. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
rank_worldsize: str = '1, 1', | ||
**kwargs | ||
) -> None: | ||
"""Download/Prepare CIFAR10 dataset""" | ||
self.rank, self.worldsize = tuple(int(num) for num in rank_worldsize.split(',')) | ||
|
||
# Load dataset | ||
train_ds, valid_ds = self._download_and_prepare_dataset(self.rank, self.worldsize) | ||
|
||
# Set attributes | ||
self._sample_shape = train_ds['image'].shape[1:] | ||
self._target_shape = tf.expand_dims(train_ds['label'], -1).shape[1:] | ||
|
||
self.splits = { | ||
'train': train_ds, | ||
'valid': valid_ds | ||
} | ||
|
||
def _download_and_prepare_dataset(self, rank: int, worldsize: int) -> Tuple[dict]: | ||
""" | ||
Download, Cache CIFAR10 and prepare `tfds` builder. | ||
Provide `rank` and `worldsize` to virtually split dataset across shards | ||
uniquely for each client for simulation purposes. | ||
Returns: | ||
Tuple (train_dict, test_dict) of dictionary with JAX DeviceArray (image and label) | ||
dict['image'] -> DeviceArray float32 | ||
dict['label'] -> DeviceArray int32 | ||
{'image' : DeviceArray(...), 'label' : DeviceArray(...)} | ||
""" | ||
|
||
dataset_builder = tfds.builder('cifar10') | ||
dataset_builder.download_and_prepare() | ||
|
||
datasets = dataset_builder.as_dataset() | ||
|
||
train_shard_size = int(len(datasets['train']) / worldsize) | ||
test_shard_size = int(len(datasets['test']) / worldsize) | ||
|
||
self.train_segment = f'train[{train_shard_size * (rank - 1)}:{train_shard_size * rank}]' | ||
self.test_segment = f'test[{test_shard_size * (rank - 1)}:{test_shard_size * rank}]' | ||
train_dataset = dataset_builder.as_dataset(split=self.train_segment, batch_size=-1) | ||
test_dataset = dataset_builder.as_dataset(split=self.test_segment, batch_size=-1) | ||
train_ds = tfds.as_numpy(train_dataset) | ||
test_ds = tfds.as_numpy(test_dataset) | ||
|
||
train_ds['image'] = jnp.float32(train_ds['image']) / 255. | ||
test_ds['image'] = jnp.float32(test_ds['image']) / 255. | ||
train_ds['label'] = jnp.int32(train_ds['label']) | ||
test_ds['label'] = jnp.int32(test_ds['label']) | ||
|
||
return train_ds, test_ds | ||
|
||
def get_shard_dataset_types(self) -> List[str]: | ||
"""Get available split names""" | ||
return list(self.splits) | ||
|
||
def get_split(self, name: str) -> tf.data.Dataset: | ||
"""Return a shard dataset by type.""" | ||
if name not in self.splits: | ||
raise Exception(f'Split name `{name}` not found.' | ||
f' Expected one of {list(self.splits.keys())}') | ||
return self.splits[name] | ||
|
||
@property | ||
def sample_shape(self) -> List[str]: | ||
"""Return the sample shape info.""" | ||
return list(map(str, self._sample_shape)) | ||
|
||
@property | ||
def target_shape(self) -> List[str]: | ||
"""Return the target shape info.""" | ||
return list(map(str, self._target_shape)) | ||
|
||
@property | ||
def dataset_description(self) -> str: | ||
"""Return the dataset description.""" | ||
n_train = len(self.splits['train']['label']) | ||
n_test = len(self.splits['valid']['label']) | ||
return (f'CIFAR10 dataset, Shard Segments {self.train_segment}/{self.test_segment}, ' | ||
f'rank/world {self.rank}/{self.worldsize}.' | ||
f'\n num_samples [Train/Valid]: [{n_train}/{n_test}]') |
9 changes: 9 additions & 0 deletions
9
openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/envoy/envoy_config_1.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
params: | ||
cuda_devices: [] | ||
|
||
optional_plugin_components: {} | ||
|
||
shard_descriptor: | ||
template: cifar10_shard_descriptor.CIFAR10ShardDescriptor | ||
params: | ||
rank_worldsize: 1, 2 |
17 changes: 17 additions & 0 deletions
17
openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/envoy/start_envoy.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/bin/bash | ||
set -e | ||
ENVOY_NAME=$1 | ||
ENVOY_CONF=$2 | ||
|
||
DEFAULT_DEVICE='CPU' | ||
|
||
if [[ $DEFAULT_DEVICE == 'CPU' ]] | ||
then | ||
export JAX_PLATFORMS="cpu" # Force XLA to use CPU | ||
export CUDA_VISIBLE_DEVICES='-1' # Force TF to use CPU | ||
else | ||
export XLA_PYTHON_CLIENT_PREALLOCATE=false | ||
export TF_FORCE_GPU_ALLOW_GROWTH=true | ||
fi | ||
|
||
fx envoy start -n "$ENVOY_NAME" --disable-tls --envoy-config-path "$ENVOY_CONF" -dh localhost -dp 50055 |
5 changes: 5 additions & 0 deletions
5
openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html | ||
jax | ||
jaxlib | ||
tensorflow==2.13 | ||
tensorflow-datasets==4.6.0 |
Oops, something went wrong.