Move interactive API examples to openfl-contrib #9

106 changes: 106 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/README.md
# Federated FLAX CIFAR-10 CNN Tutorial

### 1. About dataset

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

Define the parameter below in the `envoy_config.yaml` file to shard the dataset across participants/envoys:
- `rank_worldsize`
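
For illustration, here is a minimal sketch (mirroring the logic in `cifar10_shard_descriptor.py` further below) of how a value such as `rank_worldsize: 1, 2` (rank 1 of 2 envoys) maps to a slice of the training split:

```python
# Hypothetical example: envoy 1 of a 2-envoy federation.
rank_worldsize = '1, 2'
rank, worldsize = (int(num) for num in rank_worldsize.split(','))

n_train = 50000  # CIFAR-10 training images
shard_size = n_train // worldsize
train_segment = f'train[{shard_size * (rank - 1)}:{shard_size * rank}]'
print(train_segment)  # -> train[0:25000]
```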

### 2. About model

A simple multi-layer CNN is used, with XLA-compiled, autograd-based parameter updates. The model definition is provided in the notebook; a rough sketch is shown below.
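
As a hedged illustration only (the exact layers and sizes are defined in the notebook and may differ from what is assumed here), such a Flax model could look like:

```python
from flax import linen as nn


class CNN(nn.Module):
    """Illustrative multi-layer CNN for 32x32x3 CIFAR-10 images."""

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Conv(features=32, kernel_size=(3, 3))(x))
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(nn.Conv(features=64, kernel_size=(3, 3))(x))
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.relu(nn.Dense(features=256)(x))
        return nn.Dense(features=10)(x)  # logits for the 10 CIFAR-10 classes
```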

### 3. Notebook Overview

1. Class `CustomTrainState` - subclasses `flax.training.train_state.TrainState` and adds:
    - Variable `opt_vars` to keep track of generic optimizer variables.
    - Method `update_state` to update the OpenFL `ModelInterface`-registered state with the `new_state` returned within the `TaskInterface`-registered training loop.

2. Method `create_train_state`: creates a new `TrainState` by encapsulating the model layer definitions, random model parameters, and the optax optimizer state.

3. Method `apply_model` (a `@jax.jit`-decorated function): takes a `TrainState`, images, and labels as parameters, and computes and returns the gradients, loss, and accuracy. These gradients are applied to a given state in the `update_model` method (also `@jax.jit`-decorated), which returns a new `TrainState` instance; both methods are sketched below.
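
A minimal sketch of these pieces, assuming an optax Adam optimizer and reusing the illustrative `CNN` module from section 2 (hyperparameters and exact field handling may differ from the notebook):

```python
from typing import Any

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state


class CustomTrainState(train_state.TrainState):
    """TrainState extended with generic optimizer variables and an update hook."""

    opt_vars: Any = None  # assumption: generic optimizer variables tracked per round

    def update_state(self, new_state):
        # Sync the ModelInterface-registered state with the state returned by
        # the TaskInterface-registered training loop.
        return self.replace(params=new_state.params,
                            opt_state=new_state.opt_state,
                            opt_vars=new_state.opt_vars)


def create_train_state(rng, learning_rate=1e-3):
    """Bundle model definition, random parameters, and optax optimizer state."""
    model = CNN()  # the Flax module sketched in section 2
    params = model.init(rng, jnp.ones([1, 32, 32, 3]))['params']
    tx = optax.adam(learning_rate)
    return CustomTrainState.create(apply_fn=model.apply, params=params, tx=tx)


@jax.jit
def apply_model(state, images, labels):
    """Compute gradients, loss, and accuracy for one batch."""
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
    """Apply the gradients and return a new TrainState instance."""
    return state.apply_gradients(grads=grads)
```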

### 4. How to run this tutorial (without TLS and locally as a simulation):

0. Pre-requisites:

- Nvidia Driver >= 495.29.05
- CUDA >= 11.1.105
- cuDNN >= 8

Activate a virtual environment (Python 3.8.10) and install the packages from `requirements.txt`.

Set the variable `DEFAULT_DEVICE` to `'CPU'` or `'GPU'` in `start_envoy.sh` and in the notebook to control the execution platform.

```sh
cd Flax_CNN_CIFAR
pip install -r requirements.txt
```

1. Run director:

```sh
cd director
./start_director.sh
```

2. Run envoy:

```sh
cd envoy
./start_envoy.sh "envoy_identifier" envoy_config.yaml
```

Optional: start a second envoy:

- Copy the `envoy` folder to another location, point the copied config at the second shard (e.g. `rank_worldsize: 2, 2`), and follow the same process as above:

```sh
./start_envoy.sh "envoy_identifier_2" envoy_config_2.yaml
```

3. Run the `FLAX_CIFAR10_CNN.ipynb` Jupyter notebook:

```sh
cd workspace
jupyter lab FLAX_CIFAR10_CNN.ipynb
```

4. Visualization:

```sh
tensorboard --logdir logs/
```


### 5. Known issues

1. #### `CUDA_ERROR_OUT_OF_MEMORY` exception - JAX XLA pre-allocates 90% of the GPU at start

- Set `XLA_PYTHON_CLIENT_PREALLOCATE=false` to start with a small memory footprint:
```
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
```
or

- Set the flag below to restrict the maximum GPU allocation to 50%:
```
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.5
```


2. #### TensorFlow pre-allocates 90% of the GPU (potential OOM errors)

- Set `TF_FORCE_GPU_ALLOW_GROWTH=true` to start with a small memory footprint:
```
%env TF_FORCE_GPU_ALLOW_GROWTH=true
```

3. #### DNN library not found error

- Make sure the jaxlib (CUDA build), NVIDIA driver, CUDA, and cuDNN versions are mutually compatible, as per the documentation (a quick runtime check is sketched after the references below).
- Reference:
- CUDA and cuDNN Compatibility Matrix: https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html
- Official JAX Compatible CUDA Releases: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
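
A quick way to verify the installed JAX stack (a small check, assumed to be run inside the tutorial's virtual environment):

```python
import jax
import jaxlib

print('jax:', jax.__version__)
print('jaxlib:', jaxlib.__version__)
print('devices:', jax.devices())  # lists CUDA devices when the GPU setup is compatible
```
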
6 changes: 6 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/director/director_config.yaml
settings:
  listen_host: localhost
  listen_port: 50055
  sample_shape: ['32', '32', '3']  # [height, width, channels]
  target_shape: ['1']
  envoy_health_check_period: 5  # in seconds
4 changes: 4 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/director/start_director.sh
#!/bin/bash
set -e

fx director start --disable-tls -c director_config.yaml
111 changes: 111 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/envoy/cifar10_shard_descriptor.py
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""CIFAR10 Shard Descriptor (using `TFDS` API)"""
import logging
from typing import List, Tuple

import jax.numpy as jnp
import tensorflow as tf
import tensorflow_datasets as tfds

from openfl.interface.interactive_api.shard_descriptor import ShardDescriptor

logger = logging.getLogger(__name__)


class CIFAR10ShardDescriptor(ShardDescriptor):
"""
CIFAR10 Shard Descriptor

This example is based on `tfds` data loader.
Note that the ingestion of any model/task requires an iterable dataloader.
Hence, it is possible to utilize these pipelines without explicit need of a
new interface.
"""

def __init__(
self,
rank_worldsize: str = '1, 1',
**kwargs
) -> None:
"""Download/Prepare CIFAR10 dataset"""
self.rank, self.worldsize = tuple(int(num) for num in rank_worldsize.split(','))

# Load dataset
train_ds, valid_ds = self._download_and_prepare_dataset(self.rank, self.worldsize)

# Set attributes
self._sample_shape = train_ds['image'].shape[1:]
self._target_shape = tf.expand_dims(train_ds['label'], -1).shape[1:]

self.splits = {
'train': train_ds,
'valid': valid_ds
}

    def _download_and_prepare_dataset(self, rank: int, worldsize: int) -> Tuple[dict, dict]:
"""
Download, Cache CIFAR10 and prepare `tfds` builder.

Provide `rank` and `worldsize` to virtually split dataset across shards
uniquely for each client for simulation purposes.

Returns:
Tuple (train_dict, test_dict) of dictionary with JAX DeviceArray (image and label)
dict['image'] -> DeviceArray float32
dict['label'] -> DeviceArray int32
{'image' : DeviceArray(...), 'label' : DeviceArray(...)}

"""

dataset_builder = tfds.builder('cifar10')
dataset_builder.download_and_prepare()

datasets = dataset_builder.as_dataset()

train_shard_size = int(len(datasets['train']) / worldsize)
test_shard_size = int(len(datasets['test']) / worldsize)

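        # Each rank gets a contiguous, non-overlapping slice of the train/test
        # splits (ranks are 1-based, hence the `rank - 1` offset).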
self.train_segment = f'train[{train_shard_size * (rank - 1)}:{train_shard_size * rank}]'
self.test_segment = f'test[{test_shard_size * (rank - 1)}:{test_shard_size * rank}]'
train_dataset = dataset_builder.as_dataset(split=self.train_segment, batch_size=-1)
test_dataset = dataset_builder.as_dataset(split=self.test_segment, batch_size=-1)
train_ds = tfds.as_numpy(train_dataset)
test_ds = tfds.as_numpy(test_dataset)

train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
train_ds['label'] = jnp.int32(train_ds['label'])
test_ds['label'] = jnp.int32(test_ds['label'])

return train_ds, test_ds

def get_shard_dataset_types(self) -> List[str]:
"""Get available split names"""
return list(self.splits)

    def get_split(self, name: str) -> dict:
"""Return a shard dataset by type."""
if name not in self.splits:
raise Exception(f'Split name `{name}` not found.'
f' Expected one of {list(self.splits.keys())}')
return self.splits[name]

@property
def sample_shape(self) -> List[str]:
"""Return the sample shape info."""
return list(map(str, self._sample_shape))

@property
def target_shape(self) -> List[str]:
"""Return the target shape info."""
return list(map(str, self._target_shape))

@property
def dataset_description(self) -> str:
"""Return the dataset description."""
n_train = len(self.splits['train']['label'])
n_test = len(self.splits['valid']['label'])
return (f'CIFAR10 dataset, Shard Segments {self.train_segment}/{self.test_segment}, '
f'rank/world {self.rank}/{self.worldsize}.'
f'\n num_samples [Train/Valid]: [{n_train}/{n_test}]')
9 changes: 9 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/envoy/envoy_config.yaml
params:
  cuda_devices: []

optional_plugin_components: {}

shard_descriptor:
  template: cifar10_shard_descriptor.CIFAR10ShardDescriptor
  params:
    rank_worldsize: 1, 2
17 changes: 17 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/envoy/start_envoy.sh
#!/bin/bash
set -e
ENVOY_NAME=$1
ENVOY_CONF=$2

DEFAULT_DEVICE='CPU'

if [[ $DEFAULT_DEVICE == 'CPU' ]]
then
export JAX_PLATFORMS="cpu" # Force XLA to use CPU
export CUDA_VISIBLE_DEVICES='-1' # Force TF to use CPU
else
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export TF_FORCE_GPU_ALLOW_GROWTH=true
fi

fx envoy start -n "$ENVOY_NAME" --disable-tls --envoy-config-path "$ENVOY_CONF" -dh localhost -dp 50055
5 changes: 5 additions & 0 deletions openfl_contrib_tutorials/interactive_api/Flax_CNN_CIFAR/requirements.txt
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax
jaxlib
tensorflow==2.13
tensorflow-datasets==4.6.0