Add Fedcurv example using TaskRunner API
Signed-off-by: Ynon Flum <[email protected]>
ynonflumintel committed Dec 16, 2024
1 parent 44614b3 commit 85a1a3d
Showing 10 changed files with 434 additions and 3 deletions.
34 changes: 34 additions & 0 deletions openfl-workspace/torch_cnn_histology_fedcurv/README.md
@@ -0,0 +1,34 @@
# PyTorch CNN Histology Dataset Training with FedCurv Aggregation
The example code in this directory trains a Convolutional Neural Network on the Colorectal Histology dataset.
It uses the PyTorch framework and OpenFL's TaskRunner API.
The federation aggregates intermediate models using the [FedCurv](https://arxiv.org/pdf/1910.07796)
aggregation algorithm, which is intended to perform better than [FedAvg](https://arxiv.org/abs/2104.11375) when the collaborators' datasets are not independent and identically distributed (non-IID).

Note that this example is similar to the one in the `torch_cnn_histology` directory. It is here to demonstrate how to use a different aggregation algorithm with OpenFL's TaskRunner API.

The two examples differ in two places: the `PyTorchCNNWithFedCurv` class, which defines a stateful training method that uses an existing `FedCurv` object,
and the `plan.yaml` file, in which the training task is explicitly defined with a non-default aggregation method, `FedCurvWeightedAverage`.
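
To give a feel for why the training method has to be stateful, here is a minimal sketch of the penalty term described in the FedCurv paper: each collaborator adds to its local loss a quadratic term that pulls its weights toward the other collaborators' previous weights, scaled by their diagonal Fisher information. This is plain Python written for illustration only; the function name `fedcurv_penalty`, its arguments, and the `lam` default are assumptions and are not taken from OpenFL's `FedCurv` implementation.

```python
def fedcurv_penalty(theta, others, lam=1.0):
    """Illustrative FedCurv-style penalty (hypothetical helper, not OpenFL's API).

    theta:  list of floats, the local model weights being trained.
    others: list of (fisher_diag, theta_prev) pairs, one per *other*
            collaborator, taken from the previous round.
    lam:    penalty strength (assumed hyperparameter name).
    """
    total = 0.0
    for fisher_diag, theta_prev in others:
        for w, f, w_prev in zip(theta, fisher_diag, theta_prev):
            # Weights that matter more to the other collaborator (larger
            # Fisher value) are penalised more for drifting away.
            total += f * (w - w_prev) ** 2
    return lam * total


# Toy numbers: two weights, one other collaborator.
local_weights = [1.0, 2.0]
other_state = [([0.5, 2.0], [0.0, 1.0])]
penalty = fedcurv_penalty(local_weights, other_state)
print(penalty)  # 2.5
```

The penalty is zero when the local weights equal every other collaborator's previous weights, so it only affects training once the local model starts to drift.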

## Running an example federation
The following instructions can be used to run the federation:
```
# Copy the workspace template, create collaborators and aggregator
fx workspace create --template torch_cnn_histology_fedcurv --prefix fedcurv
cd fedcurv
fx workspace certify
fx aggregator generate-cert-request
fx aggregator certify --silent
fx plan initialize
fx collaborator create -n collaborator1 -d 1
fx collaborator generate-cert-request -n collaborator1
fx collaborator certify -n collaborator1 --silent
fx collaborator create -n collaborator2 -d 2
fx collaborator generate-cert-request -n collaborator2
fx collaborator certify -n collaborator2 --silent
# Run aggregator and collaborators
fx aggregator start &
fx collaborator start -n collaborator1 &
fx collaborator start -n collaborator2
```
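
The `-d 1` and `-d 2` arguments above are the data paths that `PyTorchHistologyInMemory` converts to shard numbers with `int(data_path)`; the loader then keeps every `collaborator_count`-th sample starting at that offset (`ds_train[shard_num::collaborator_count]`). A minimal sketch of that slicing on a toy list of sample indices follows. The helper name `shard_indices` and the toy size of 10 are illustrative and not part of the workspace.

```python
def shard_indices(n_samples, shard_num, collaborator_count):
    """Return the sample indices picked by a strided slice (hypothetical helper)."""
    return list(range(n_samples))[shard_num::collaborator_count]


# Matches data_loader.settings.collaborator_count in plan.yaml.
collaborator_count = 2

shard_1 = shard_indices(10, 1, collaborator_count)  # collaborator created with -d 1
shard_2 = shard_indices(10, 2, collaborator_count)  # collaborator created with -d 2

print(shard_1)  # [1, 3, 5, 7, 9]
print(shard_2)  # [2, 4, 6, 8]
```

On this toy example the two shards are disjoint, and index 0 is not assigned to either collaborator.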
4 changes: 4 additions & 0 deletions openfl-workspace/torch_cnn_histology_fedcurv/plan/cols.yaml
@@ -0,0 +1,4 @@
# Copyright (C) 2020-2021 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:
5 changes: 5 additions & 0 deletions openfl-workspace/torch_cnn_histology_fedcurv/plan/data.yaml
@@ -0,0 +1,5 @@
# Copyright (C) 2020-2021 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

one,1
two,2
48 changes: 48 additions & 0 deletions openfl-workspace/torch_cnn_histology_fedcurv/plan/plan.yaml
@@ -0,0 +1,48 @@
# Copyright (C) 2020-2021 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

aggregator :
  defaults : plan/defaults/aggregator.yaml
  template : openfl.component.Aggregator
  settings :
    init_state_path : save/torch_cnn_histology_init.pbuf
    best_state_path : save/torch_cnn_histology_best.pbuf
    last_state_path : save/torch_cnn_histology_last.pbuf
    rounds_to_train : 20

collaborator :
  defaults : plan/defaults/collaborator.yaml
  template : openfl.component.Collaborator
  settings :
    delta_updates : false
    opt_treatment : RESET

data_loader :
  template : src.dataloader.PyTorchHistologyInMemory
  settings :
    collaborator_count : 2
    data_group_name : histology
    batch_size : 32

task_runner:
  defaults : plan/defaults/task_runner.yaml
  template: src.taskrunner.PyTorchCNNWithFedCurv

network:
  defaults: plan/defaults/network.yaml

tasks:
  defaults: plan/defaults/tasks_torch.yaml
  train:
    function: train_task
    aggregation_type:
      template: openfl.interface.aggregation_functions.FedCurvWeightedAverage
    kwargs:
      metrics:
      - loss

assigner:
  defaults: plan/defaults/assigner.yaml

compression_pipeline :
  defaults : plan/defaults/compression_pipeline.yaml
2 changes: 2 additions & 0 deletions openfl-workspace/torch_cnn_histology_fedcurv/requirements.txt
@@ -0,0 +1,2 @@
torch==2.3.1
torchvision==0.18.1
3 changes: 3 additions & 0 deletions openfl-workspace/torch_cnn_histology_fedcurv/src/__init__.py
@@ -0,0 +1,3 @@
# Copyright (C) 2020-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""You may copy this file as the starting point of your own model."""
180 changes: 180 additions & 0 deletions openfl-workspace/torch_cnn_histology_fedcurv/src/dataloader.py
@@ -0,0 +1,180 @@
# Copyright (C) 2020-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from collections.abc import Iterable
from logging import getLogger
from os import makedirs
from pathlib import Path
from urllib.request import urlretrieve
from zipfile import ZipFile

from openfl.federated import PyTorchDataLoader
import numpy as np
import torch
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from tqdm import tqdm

from openfl.utilities import validate_file_hash

logger = getLogger(__name__)


class PyTorchHistologyInMemory(PyTorchDataLoader):
    """PyTorch data loader for Histology dataset."""

    def __init__(self, data_path, batch_size, **kwargs):
        """Instantiate the data object.

        Args:
            data_path: The shard number for this collaborator (parsed with int())
            batch_size: The batch size of the data loader
            **kwargs: Additional arguments, passed to super init
                and load_histology_shard
        """
        super().__init__(batch_size, random_seed=0, **kwargs)

        _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
            shard_num=int(data_path), **kwargs)

        self.X_train = X_train
        self.y_train = y_train
        self.X_valid = X_valid
        self.y_valid = y_valid

        self.num_classes = num_classes


class HistologyDataset(ImageFolder):
    """Colorectal Histology Dataset."""

    URL = ('https://zenodo.org/record/53169/files/Kather_'
           'texture_2016_image_tiles_5000.zip?download=1')
    FILENAME = 'Kather_texture_2016_image_tiles_5000.zip'
    FOLDER_NAME = 'Kather_texture_2016_image_tiles_5000'
    ZIP_SHA384 = ('7d86abe1d04e68b77c055820c2a4c582a1d25d2983e38ab724e'
                  'ac75affce8b7cb2cbf5ba68848dcfd9d84005d87d6790')
    DEFAULT_PATH = Path.cwd().absolute() / 'data'

    def __init__(self, root: Path = DEFAULT_PATH, **kwargs) -> None:
        """Initialize."""
        makedirs(root, exist_ok=True)
        filepath = root / HistologyDataset.FILENAME
        if not filepath.is_file():
            # Download, verify and unpack the archive only on first use.
            self.pbar = tqdm(total=None)
            urlretrieve(HistologyDataset.URL, filepath, self.report_hook)  # nosec
            validate_file_hash(filepath, HistologyDataset.ZIP_SHA384)
            with ZipFile(filepath, 'r') as f:
                f.extractall(root)

        super(HistologyDataset, self).__init__(root / HistologyDataset.FOLDER_NAME, **kwargs)

    def report_hook(self, count, block_size, total_size):
        """Update progressbar."""
        if self.pbar.total is None and total_size:
            self.pbar.total = total_size
        progress_bytes = count * block_size
        self.pbar.update(progress_bytes - self.pbar.n)

    def __getitem__(self, index):
        """Allow getting items by slice index."""
        if isinstance(index, Iterable):
            return [super(HistologyDataset, self).__getitem__(i) for i in index]
        else:
            return super(HistologyDataset, self).__getitem__(index)


def one_hot(labels, classes):
    """
    One Hot encode a vector.

    Args:
        labels (list): List of labels to onehot encode
        classes (int): Total number of categorical classes

    Returns:
        np.array: Matrix of one-hot encoded labels
    """
    return np.eye(classes)[labels]


def _load_raw_datashards(shard_num, collaborator_count, train_split_ratio=0.8):
    """
    Load the raw data by shard.

    Returns tuples of the dataset shard divided into training and validation.

    Args:
        shard_num (int): The shard number to use
        collaborator_count (int): The number of collaborators in the federation
        train_split_ratio (float): Fraction of the dataset used for training
            (Default = 0.8)

    Returns:
        2 tuples: (image, label) of the training, validation dataset
    """
    dataset = HistologyDataset(transform=ToTensor())
    n_train = int(train_split_ratio * len(dataset))
    n_valid = len(dataset) - n_train
    ds_train, ds_val = random_split(
        dataset, lengths=[n_train, n_valid], generator=torch.manual_seed(0))

    # create the shards
    X_train, y_train = list(zip(*ds_train[shard_num::collaborator_count]))
    X_train, y_train = np.stack(X_train), np.array(y_train)

    X_valid, y_valid = list(zip(*ds_val[shard_num::collaborator_count]))
    X_valid, y_valid = np.stack(X_valid), np.array(y_valid)

    return (X_train, y_train), (X_valid, y_valid)


def load_histology_shard(shard_num, collaborator_count,
                         categorical=False, channels_last=False, **kwargs):
    """
    Load the Histology dataset.

    Args:
        shard_num (int): The shard to use from the dataset
        collaborator_count (int): The number of collaborators in the federation
        categorical (bool): True = convert the labels to one-hot encoded
            vectors (Default = False)
        channels_last (bool): True = The input images have the channels
            last (Default = False)
        **kwargs: Additional parameters to pass to the function

    Returns:
        list: The input shape
        int: The number of classes
        numpy.ndarray: The training data
        numpy.ndarray: The training labels
        numpy.ndarray: The validation data
        numpy.ndarray: The validation labels
    """
    img_rows, img_cols = 150, 150
    num_classes = 8

    (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
        shard_num, collaborator_count)

    if channels_last:
        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 3)
        X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 3)
        input_shape = (img_rows, img_cols, 3)
    else:
        X_train = X_train.reshape(X_train.shape[0], 3, img_rows, img_cols)
        X_valid = X_valid.reshape(X_valid.shape[0], 3, img_rows, img_cols)
        input_shape = (3, img_rows, img_cols)

    logger.info(f'Histology > X_train Shape : {X_train.shape}')
    logger.info(f'Histology > y_train Shape : {y_train.shape}')
    logger.info(f'Histology > Train Samples : {X_train.shape[0]}')
    logger.info(f'Histology > Valid Samples : {X_valid.shape[0]}')

    if categorical:
        # convert class vectors to binary class matrices
        y_train = one_hot(y_train, num_classes)
        y_valid = one_hot(y_valid, num_classes)

    return input_shape, num_classes, X_train, y_train, X_valid, y_valid