-
Notifications
You must be signed in to change notification settings - Fork 210
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Fedcurv example using TaskRunner API
Signed-off-by: Ynon Flum <[email protected]>
- Loading branch information
1 parent
44614b3
commit 85a1a3d
Showing
10 changed files
with
434 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Pytorch CNN Histology Dataset Training with Fedcurv aggregation | ||
The example code in this directory is used to train a Convolutional Neural Network using the Colorectal Histology dataset. | ||
It uses the Pytorch framework and OpenFL's TaskTunner API. | ||
The federation aggregates intermediate models using the [Fedcurv](https://arxiv.org/pdf/1910.07796) | ||
aggregation algorithm, which performs well (Compared to [FedAvg](https://arxiv.org/abs/2104.11375)) when the datasets are not independent and identically distributed (IID) among collaborators. | ||
|
||
Note that this example is similar to the one present in the `torch_cnn_histology` directory and is here to demonstrate the usage of a different aggregation algorithm using OpenFL's Taskrunner API. | ||
|
||
The differenece between the two examples lies both in the `PyTorchCNNWithFedCurv` class which is used to define a stateful training method which uses an existing `FedCurv` object, | ||
and in the `plan.yaml` file in which the training task is explicitly defined with a non-default aggregation method - `FedCurvWeightedAverage`. | ||
|
||
## Running an example federation | ||
The following instructions can be used to run the federation: | ||
``` | ||
# Copy the workspace template, create collaborators and aggregator | ||
fx workspace create --template torch_cnn_histology_fedcurv --prefix fedcurv | ||
cd fedcurv fx workspace certify | ||
fx aggregator generate-cert-request | ||
fx aggregator certify --silent | ||
fx plan initialize | ||
fx collaborator create -n collaborator1 -d 1 | ||
fx collaborator generate-cert-request -n collaborator1 | ||
fx collaborator certify -n collaborator1 --silent | ||
fx collaborator create -n collaborator2 -d 2 | ||
fx collaborator generate-cert-request -n collaborator2 | ||
fx collaborator certify -n collaborator2 --silent | ||
# Run aggregator and collaborators | ||
fx aggregator start & | ||
fx collaborator start -n collaborator1 & | ||
fx collaborator start -n collaborator2 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (C) 2020-2021 Intel Corporation | ||
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. | ||
|
||
collaborators: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (C) 2020-2021 Intel Corporation | ||
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. | ||
|
||
one,1 | ||
two,2 |
48 changes: 48 additions & 0 deletions
48
openfl-workspace/torch_cnn_histology_fedcurv/plan/plan.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (C) 2020-2021 Intel Corporation | ||
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. | ||
|
||
aggregator : | ||
defaults : plan/defaults/aggregator.yaml | ||
template : openfl.component.Aggregator | ||
settings : | ||
init_state_path : save/torch_cnn_histology_init.pbuf | ||
best_state_path : save/torch_cnn_histology_best.pbuf | ||
last_state_path : save/torch_cnn_histology_last.pbuf | ||
rounds_to_train : 20 | ||
|
||
collaborator : | ||
defaults : plan/defaults/collaborator.yaml | ||
template : openfl.component.Collaborator | ||
settings : | ||
delta_updates : false | ||
opt_treatment : RESET | ||
|
||
data_loader : | ||
template : src.dataloader.PyTorchHistologyInMemory | ||
settings : | ||
collaborator_count : 2 | ||
data_group_name : histology | ||
batch_size : 32 | ||
|
||
task_runner: | ||
defaults : plan/defaults/task_runner.yaml | ||
template: src.taskrunner.PyTorchCNNWithFedCurv | ||
|
||
network: | ||
defaults: plan/defaults/network.yaml | ||
|
||
tasks: | ||
defaults: plan/defaults/tasks_torch.yaml | ||
train: | ||
function: train_task | ||
aggregation_type: | ||
template: openfl.interface.aggregation_functions.FedCurvWeightedAverage | ||
kwargs: | ||
metrics: | ||
- loss | ||
|
||
assigner: | ||
defaults: plan/defaults/assigner.yaml | ||
|
||
compression_pipeline : | ||
defaults : plan/defaults/compression_pipeline.yaml |
2 changes: 2 additions & 0 deletions
2
openfl-workspace/torch_cnn_histology_fedcurv/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
torch==2.3.1 | ||
torchvision==0.18.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Copyright (C) 2020-2021 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""You may copy this file as the starting point of your own model.""" |
180 changes: 180 additions & 0 deletions
180
openfl-workspace/torch_cnn_histology_fedcurv/src/dataloader.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# Copyright (C) 2020-2021 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""You may copy this file as the starting point of your own model.""" | ||
|
||
from collections.abc import Iterable | ||
from logging import getLogger | ||
from os import makedirs | ||
from pathlib import Path | ||
from urllib.request import urlretrieve | ||
from zipfile import ZipFile | ||
|
||
from openfl.federated import PyTorchDataLoader | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import random_split | ||
from torchvision.datasets import ImageFolder | ||
from torchvision.transforms import ToTensor | ||
from tqdm import tqdm | ||
|
||
from openfl.utilities import validate_file_hash | ||
|
||
logger = getLogger(__name__) | ||
|
||
|
||
class PyTorchHistologyInMemory(PyTorchDataLoader): | ||
"""PyTorch data loader for Histology dataset.""" | ||
|
||
def __init__(self, data_path, batch_size, **kwargs): | ||
"""Instantiate the data object. | ||
Args: | ||
data_path: The file path to the data | ||
batch_size: The batch size of the data loader | ||
**kwargs: Additional arguments, passed to super init | ||
and load_mnist_shard | ||
""" | ||
super().__init__(batch_size, random_seed=0, **kwargs) | ||
|
||
_, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard( | ||
shard_num=int(data_path), **kwargs) | ||
|
||
self.X_train = X_train | ||
self.y_train = y_train | ||
self.X_valid = X_valid | ||
self.y_valid = y_valid | ||
|
||
self.num_classes = num_classes | ||
|
||
|
||
class HistologyDataset(ImageFolder): | ||
"""Colorectal Histology Dataset.""" | ||
|
||
URL = ('https://zenodo.org/record/53169/files/Kather_' | ||
'texture_2016_image_tiles_5000.zip?download=1') | ||
FILENAME = 'Kather_texture_2016_image_tiles_5000.zip' | ||
FOLDER_NAME = 'Kather_texture_2016_image_tiles_5000' | ||
ZIP_SHA384 = ('7d86abe1d04e68b77c055820c2a4c582a1d25d2983e38ab724e' | ||
'ac75affce8b7cb2cbf5ba68848dcfd9d84005d87d6790') | ||
DEFAULT_PATH = Path.cwd().absolute() / 'data' | ||
|
||
def __init__(self, root: Path = DEFAULT_PATH, **kwargs) -> None: | ||
"""Initialize.""" | ||
makedirs(root, exist_ok=True) | ||
filepath = root / HistologyDataset.FILENAME | ||
if not filepath.is_file(): | ||
self.pbar = tqdm(total=None) | ||
urlretrieve(HistologyDataset.URL, filepath, self.report_hook) # nosec | ||
validate_file_hash(filepath, HistologyDataset.ZIP_SHA384) | ||
with ZipFile(filepath, 'r') as f: | ||
f.extractall(root) | ||
|
||
super(HistologyDataset, self).__init__(root / HistologyDataset.FOLDER_NAME, **kwargs) | ||
|
||
def report_hook(self, count, block_size, total_size): | ||
"""Update progressbar.""" | ||
if self.pbar.total is None and total_size: | ||
self.pbar.total = total_size | ||
progress_bytes = count * block_size | ||
self.pbar.update(progress_bytes - self.pbar.n) | ||
|
||
def __getitem__(self, index): | ||
"""Allow getting items by slice index.""" | ||
if isinstance(index, Iterable): | ||
return [super(HistologyDataset, self).__getitem__(i) for i in index] | ||
else: | ||
return super(HistologyDataset, self).__getitem__(index) | ||
|
||
|
||
def one_hot(labels, classes): | ||
""" | ||
One Hot encode a vector. | ||
Args: | ||
labels (list): List of labels to onehot encode | ||
classes (int): Total number of categorical classes | ||
Returns: | ||
np.array: Matrix of one-hot encoded labels | ||
""" | ||
return np.eye(classes)[labels] | ||
|
||
|
||
def _load_raw_datashards(shard_num, collaborator_count, train_split_ratio=0.8): | ||
""" | ||
Load the raw data by shard. | ||
Returns tuples of the dataset shard divided into training and validation. | ||
Args: | ||
shard_num (int): The shard number to use | ||
collaborator_count (int): The number of collaborators in the federation | ||
Returns: | ||
2 tuples: (image, label) of the training, validation dataset | ||
""" | ||
dataset = HistologyDataset(transform=ToTensor()) | ||
n_train = int(train_split_ratio * len(dataset)) | ||
n_valid = len(dataset) - n_train | ||
ds_train, ds_val = random_split( | ||
dataset, lengths=[n_train, n_valid], generator=torch.manual_seed(0)) | ||
|
||
# create the shards | ||
X_train, y_train = list(zip(*ds_train[shard_num::collaborator_count])) | ||
X_train, y_train = np.stack(X_train), np.array(y_train) | ||
|
||
X_valid, y_valid = list(zip(*ds_val[shard_num::collaborator_count])) | ||
X_valid, y_valid = np.stack(X_valid), np.array(y_valid) | ||
|
||
return (X_train, y_train), (X_valid, y_valid) | ||
|
||
|
||
def load_histology_shard(shard_num, collaborator_count, | ||
categorical=False, channels_last=False, **kwargs): | ||
""" | ||
Load the Histology dataset. | ||
Args: | ||
shard_num (int): The shard to use from the dataset | ||
collaborator_count (int): The number of collaborators in the federation | ||
categorical (bool): True = convert the labels to one-hot encoded | ||
vectors (Default = True) | ||
channels_last (bool): True = The input images have the channels | ||
last (Default = True) | ||
**kwargs: Additional parameters to pass to the function | ||
Returns: | ||
list: The input shape | ||
int: The number of classes | ||
numpy.ndarray: The training data | ||
numpy.ndarray: The training labels | ||
numpy.ndarray: The validation data | ||
numpy.ndarray: The validation labels | ||
""" | ||
img_rows, img_cols = 150, 150 | ||
num_classes = 8 | ||
|
||
(X_train, y_train), (X_valid, y_valid) = _load_raw_datashards( | ||
shard_num, collaborator_count) | ||
|
||
if channels_last: | ||
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 3) | ||
X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 3) | ||
input_shape = (img_rows, img_cols, 3) | ||
else: | ||
X_train = X_train.reshape(X_train.shape[0], 3, img_rows, img_cols) | ||
X_valid = X_valid.reshape(X_valid.shape[0], 3, img_rows, img_cols) | ||
input_shape = (3, img_rows, img_cols) | ||
|
||
logger.info(f'Histology > X_train Shape : {X_train.shape}') | ||
logger.info(f'Histology > y_train Shape : {y_train.shape}') | ||
logger.info(f'Histology > Train Samples : {X_train.shape[0]}') | ||
logger.info(f'Histology > Valid Samples : {X_valid.shape[0]}') | ||
|
||
if categorical: | ||
# convert class vectors to binary class matrices | ||
y_train = one_hot(y_train, num_classes) | ||
y_valid = one_hot(y_valid, num_classes) | ||
|
||
return input_shape, num_classes, X_train, y_train, X_valid, y_valid |
Oops, something went wrong.