diff --git a/openfl-workspace/torch_cnn_histology_fedcurv/README.md b/openfl-workspace/torch_cnn_histology_fedcurv/README.md
new file mode 100644
index 00000000000..de6f18e42d8
--- /dev/null
+++ b/openfl-workspace/torch_cnn_histology_fedcurv/README.md
@@ -0,0 +1,35 @@
+# PyTorch CNN Histology Dataset Training with FedCurv Aggregation
+The example code in this directory trains a convolutional neural network on the Colorectal Histology dataset.
+It uses the PyTorch framework and OpenFL's TaskRunner API.
+The federation aggregates intermediate models using the [FedCurv](https://arxiv.org/pdf/1910.07796)
+aggregation algorithm, which performs well compared to [FedAvg](https://arxiv.org/abs/2104.11375) when the datasets are not independent and identically distributed (IID) among collaborators.
+
+Note that this example is similar to the one in the `torch_cnn_histology` directory; it is provided to demonstrate the use of a different aggregation algorithm with OpenFL's TaskRunner API.
+
+The difference between the two examples lies both in the `PyTorchCNNWithFedCurv` class, which defines a stateful training method built around an existing `FedCurv` object,
+and in the `plan.yaml` file, where the training task is explicitly defined with a non-default aggregation method - `FedCurvWeightedAverage`.
+
+## Running an example federation
+The following instructions can be used to run the federation:
+```
+# Copy the workspace template, create collaborators and aggregator
+fx workspace create --template torch_cnn_histology_fedcurv --prefix fedcurv
+cd fedcurv
+fx workspace certify
+fx aggregator generate-cert-request
+fx aggregator certify --silent
+fx plan initialize
+
+fx collaborator create -n collaborator1 -d 1
+fx collaborator generate-cert-request -n collaborator1
+fx collaborator certify -n collaborator1 --silent
+
+fx collaborator create -n collaborator2 -d 2
+fx collaborator generate-cert-request -n collaborator2
+fx collaborator certify -n collaborator2 --silent
+
+# Run aggregator and collaborators
+fx aggregator start &
+fx collaborator start -n collaborator1 &
+fx collaborator start -n collaborator2
+```
\ No newline at end of file
diff --git a/openfl-workspace/torch_cnn_histology_fedcurv/plan/cols.yaml b/openfl-workspace/torch_cnn_histology_fedcurv/plan/cols.yaml
new file mode 100644
index 00000000000..a6d1b1b922c
--- /dev/null
+++ b/openfl-workspace/torch_cnn_histology_fedcurv/plan/cols.yaml
@@ -0,0 +1,4 @@
+# Copyright (C) 2020-2021 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
+
+collaborators:
\ No newline at end of file
diff --git a/openfl-workspace/torch_cnn_histology_fedcurv/plan/data.yaml b/openfl-workspace/torch_cnn_histology_fedcurv/plan/data.yaml
new file mode 100644
index 00000000000..8e59641e485
--- /dev/null
+++ b/openfl-workspace/torch_cnn_histology_fedcurv/plan/data.yaml
@@ -0,0 +1,5 @@
+# Copyright (C) 2020-2021 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
+
+one,1
+two,2
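A note on `data.yaml` above: each row maps a collaborator name to the `data_path` string handed to the workspace's data loader, which this template interprets as a shard number. A minimal sketch of that flow, assuming the workspace defaults above (running it would download the dataset on first use):

```python
# Minimal sketch, assuming the workspace defaults above: the row "one,1"
# means the collaborator named "one" builds its loader with data_path="1",
# which src/dataloader.py converts to shard number 1.
from src.dataloader import PyTorchHistologyInMemory

loader = PyTorchHistologyInMemory(data_path="1", batch_size=32,
                                  collaborator_count=2)
print(loader.get_feature_shape())  # (3, 150, 150) for this dataset
```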
diff --git a/openfl-workspace/torch_cnn_histology_fedcurv/plan/plan.yaml b/openfl-workspace/torch_cnn_histology_fedcurv/plan/plan.yaml
new file mode 100644
index 00000000000..d4c4184baea
--- /dev/null
+++ b/openfl-workspace/torch_cnn_histology_fedcurv/plan/plan.yaml
@@ -0,0 +1,48 @@
+# Copyright (C) 2020-2021 Intel Corporation
+# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.
+
+aggregator :
+  defaults : plan/defaults/aggregator.yaml
+  template : openfl.component.Aggregator
+  settings :
+    init_state_path : save/torch_cnn_histology_init.pbuf
+    best_state_path : save/torch_cnn_histology_best.pbuf
+    last_state_path : save/torch_cnn_histology_last.pbuf
+    rounds_to_train : 20
+
+collaborator :
+  defaults : plan/defaults/collaborator.yaml
+  template : openfl.component.Collaborator
+  settings :
+    delta_updates : false
+    opt_treatment : RESET
+
+data_loader :
+  template : src.dataloader.PyTorchHistologyInMemory
+  settings :
+    collaborator_count : 2
+    data_group_name : histology
+    batch_size : 32
+
+task_runner:
+  defaults : plan/defaults/task_runner.yaml
+  template: src.taskrunner.PyTorchCNNWithFedCurv
+
+network:
+  defaults: plan/defaults/network.yaml
+
+tasks:
+  defaults: plan/defaults/tasks_torch.yaml
+  train:
+    function: train_task
+    aggregation_type:
+      template: openfl.interface.aggregation_functions.FedCurvWeightedAverage
+    kwargs:
+      metrics:
+      - loss
+
+assigner:
+  defaults: plan/defaults/assigner.yaml
+
+compression_pipeline :
+  defaults : plan/defaults/compression_pipeline.yaml
diff --git a/openfl-workspace/torch_cnn_histology_fedcurv/requirements.txt b/openfl-workspace/torch_cnn_histology_fedcurv/requirements.txt
new file mode 100644
index 00000000000..58b9f15677f
--- /dev/null
+++ b/openfl-workspace/torch_cnn_histology_fedcurv/requirements.txt
@@ -0,0 +1,2 @@
+torch==2.3.1
+torchvision==0.18.1
diff --git a/openfl-workspace/torch_cnn_histology_fedcurv/src/__init__.py b/openfl-workspace/torch_cnn_histology_fedcurv/src/__init__.py
new file mode 100644
index 00000000000..f1410b1298b
--- /dev/null
+++ b/openfl-workspace/torch_cnn_histology_fedcurv/src/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (C) 2020-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""You may copy this file as the starting point of your own model."""
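The `aggregation_type` override in `plan.yaml` above is the only non-default piece of the plan. As the name suggests, `FedCurvWeightedAverage` reduces collaborator tensors by a weighted average (FedCurv's extra state tensors are aggregated alongside the model weights); a plain-NumPy sketch of the weighted-average reduction itself, not the OpenFL internals:

```python
# Illustrative only: the weighted-average reduction implied by the plan's
# FedCurvWeightedAverage template, sketched in plain NumPy.
import numpy as np

def weighted_average(tensors, sample_counts):
    weights = np.asarray(sample_counts, dtype=np.float64)
    weights /= weights.sum()
    return sum(w * t for w, t in zip(weights, tensors))

merged = weighted_average([np.ones(3), np.zeros(3)], sample_counts=[80, 20])
print(merged)  # [0.8 0.8 0.8]
```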
diff --git a/openfl-workspace/torch_cnn_histology_fedcurv/src/dataloader.py b/openfl-workspace/torch_cnn_histology_fedcurv/src/dataloader.py
new file mode 100644
index 00000000000..dc7dd4e0c69
--- /dev/null
+++ b/openfl-workspace/torch_cnn_histology_fedcurv/src/dataloader.py
@@ -0,0 +1,180 @@
+# Copyright (C) 2020-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""You may copy this file as the starting point of your own model."""
+
+from collections.abc import Iterable
+from logging import getLogger
+from os import makedirs
+from pathlib import Path
+from urllib.request import urlretrieve
+from zipfile import ZipFile
+
+from openfl.federated import PyTorchDataLoader
+import numpy as np
+import torch
+from torch.utils.data import random_split
+from torchvision.datasets import ImageFolder
+from torchvision.transforms import ToTensor
+from tqdm import tqdm
+
+from openfl.utilities import validate_file_hash
+
+logger = getLogger(__name__)
+
+
+class PyTorchHistologyInMemory(PyTorchDataLoader):
+    """PyTorch data loader for Histology dataset."""
+
+    def __init__(self, data_path, batch_size, **kwargs):
+        """Instantiate the data object.
+
+        Args:
+            data_path: The file path to the data
+            batch_size: The batch size of the data loader
+            **kwargs: Additional arguments, passed to super init
+             and load_histology_shard
+        """
+        super().__init__(batch_size, random_seed=0, **kwargs)
+
+        _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
+            shard_num=int(data_path), **kwargs)
+
+        self.X_train = X_train
+        self.y_train = y_train
+        self.X_valid = X_valid
+        self.y_valid = y_valid
+
+        self.num_classes = num_classes
+
+
+class HistologyDataset(ImageFolder):
+    """Colorectal Histology Dataset."""
+
+    URL = ('https://zenodo.org/record/53169/files/Kather_'
+           'texture_2016_image_tiles_5000.zip?download=1')
+    FILENAME = 'Kather_texture_2016_image_tiles_5000.zip'
+    FOLDER_NAME = 'Kather_texture_2016_image_tiles_5000'
+    ZIP_SHA384 = ('7d86abe1d04e68b77c055820c2a4c582a1d25d2983e38ab724e'
+                  'ac75affce8b7cb2cbf5ba68848dcfd9d84005d87d6790')
+    DEFAULT_PATH = Path.cwd().absolute() / 'data'
+
+    def __init__(self, root: Path = DEFAULT_PATH, **kwargs) -> None:
+        """Initialize."""
+        makedirs(root, exist_ok=True)
+        filepath = root / HistologyDataset.FILENAME
+        if not filepath.is_file():
+            self.pbar = tqdm(total=None)
+            urlretrieve(HistologyDataset.URL, filepath, self.report_hook)  # nosec
+            validate_file_hash(filepath, HistologyDataset.ZIP_SHA384)
+            with ZipFile(filepath, 'r') as f:
+                f.extractall(root)
+
+        super(HistologyDataset, self).__init__(root / HistologyDataset.FOLDER_NAME, **kwargs)
+
+    def report_hook(self, count, block_size, total_size):
+        """Update progressbar."""
+        if self.pbar.total is None and total_size:
+            self.pbar.total = total_size
+        progress_bytes = count * block_size
+        self.pbar.update(progress_bytes - self.pbar.n)
+
+    def __getitem__(self, index):
+        """Allow getting items by slice index."""
+        if isinstance(index, Iterable):
+            return [super(HistologyDataset, self).__getitem__(i) for i in index]
+        else:
+            return super(HistologyDataset, self).__getitem__(index)
+
+
+def one_hot(labels, classes):
+    """
+    One Hot encode a vector.
+
+    Args:
+        labels (list): List of labels to onehot encode
+        classes (int): Total number of categorical classes
+
+    Returns:
+        np.array: Matrix of one-hot encoded labels
+    """
+    return np.eye(classes)[labels]
+
+
+def _load_raw_datashards(shard_num, collaborator_count, train_split_ratio=0.8):
+    """
+    Load the raw data by shard.
+
+    Returns tuples of the dataset shard divided into training and validation.
+
+    Args:
+        shard_num (int): The shard number to use
+        collaborator_count (int): The number of collaborators in the federation
+
+    Returns:
+        2 tuples: (image, label) of the training, validation dataset
+    """
+    dataset = HistologyDataset(transform=ToTensor())
+    n_train = int(train_split_ratio * len(dataset))
+    n_valid = len(dataset) - n_train
+    ds_train, ds_val = random_split(
+        dataset, lengths=[n_train, n_valid], generator=torch.manual_seed(0))
+
+    # create the shards
+    X_train, y_train = list(zip(*ds_train[shard_num::collaborator_count]))
+    X_train, y_train = np.stack(X_train), np.array(y_train)
+
+    X_valid, y_valid = list(zip(*ds_val[shard_num::collaborator_count]))
+    X_valid, y_valid = np.stack(X_valid), np.array(y_valid)
+
+    return (X_train, y_train), (X_valid, y_valid)
+
+
+def load_histology_shard(shard_num, collaborator_count,
+                         categorical=False, channels_last=False, **kwargs):
+    """
+    Load the Histology dataset.
+
+    Args:
+        shard_num (int): The shard to use from the dataset
+        collaborator_count (int): The number of collaborators in the federation
+        categorical (bool): True = convert the labels to one-hot encoded
+            vectors (Default = False)
+        channels_last (bool): True = The input images have the channels
+            last (Default = False)
+        **kwargs: Additional parameters to pass to the function
+
+    Returns:
+        list: The input shape
+        int: The number of classes
+        numpy.ndarray: The training data
+        numpy.ndarray: The training labels
+        numpy.ndarray: The validation data
+        numpy.ndarray: The validation labels
+    """
+    img_rows, img_cols = 150, 150
+    num_classes = 8
+
+    (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
+        shard_num, collaborator_count)
+
+    if channels_last:
+        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 3)
+        X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 3)
+        input_shape = (img_rows, img_cols, 3)
+    else:
+        X_train = X_train.reshape(X_train.shape[0], 3, img_rows, img_cols)
+        X_valid = X_valid.reshape(X_valid.shape[0], 3, img_rows, img_cols)
+        input_shape = (3, img_rows, img_cols)
+
+    logger.info(f'Histology > X_train Shape : {X_train.shape}')
+    logger.info(f'Histology > y_train Shape : {y_train.shape}')
+    logger.info(f'Histology > Train Samples : {X_train.shape[0]}')
+    logger.info(f'Histology > Valid Samples : {X_valid.shape[0]}')
+
+    if categorical:
+        # convert class vectors to binary class matrices
+        y_train = one_hot(y_train, num_classes)
+        y_valid = one_hot(y_valid, num_classes)
+
+    return input_shape, num_classes, X_train, y_train, X_valid, y_valid
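A note on the sharding above: `_load_raw_datashards` gives collaborator `k` every `collaborator_count`-th sample starting at index `k`, via strided slicing. A standalone illustration of the pattern:

```python
# Standalone illustration of the [shard_num::collaborator_count] slicing
# used in _load_raw_datashards above.
samples = list(range(10))
collaborator_count = 2

shard_0 = samples[0::collaborator_count]
shard_1 = samples[1::collaborator_count]
print(shard_0)  # [0, 2, 4, 6, 8]
print(shard_1)  # [1, 3, 5, 7, 9]
```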
diff --git a/openfl-workspace/torch_cnn_histology_fedcurv/src/taskrunner.py b/openfl-workspace/torch_cnn_histology_fedcurv/src/taskrunner.py
new file mode 100644
index 00000000000..4047dfddfe4
--- /dev/null
+++ b/openfl-workspace/torch_cnn_histology_fedcurv/src/taskrunner.py
@@ -0,0 +1,145 @@
+# Copyright (C) 2020-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""You may copy this file as the starting point of your own model."""
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from typing import Iterator, Tuple
+from openfl.utilities.fedcurv.torch.fedcurv import FedCurv
+
+from openfl.federated import PyTorchTaskRunner
+from openfl.utilities import Metric
+
+
+class PyTorchCNNWithFedCurv(PyTorchTaskRunner):
+    """
+    Simple CNN for classification.
+
+    PyTorchTaskRunner inherits from nn.Module, so you can define your model
+    in the same way that you would for PyTorch
+    """
+
+    def __init__(self, device="cpu", **kwargs):
+        """Initialize.
+
+        Args:
+            device: The hardware device to use for training (Default = "cpu")
+            **kwargs: Additional arguments to pass to the function
+
+        """
+        super().__init__(device=device, **kwargs)
+
+        # Define the model
+        channel = self.data_loader.get_feature_shape()[0]  # (channel, dim1, dim2)
+        self.conv1 = nn.Conv2d(channel, 16, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
+        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+        self.conv5 = nn.Conv2d(128 + 32, 256, kernel_size=3, stride=1, padding=1)
+        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
+        self.conv7 = nn.Conv2d(512 + 128 + 32, 256, kernel_size=3, stride=1, padding=1)
+        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
+        self.fc1 = nn.Linear(1184 * 9 * 9, 128)
+        self.fc2 = nn.Linear(128, 8)
+
+        # `self.optimizer` must be set for optimizer weights to be federated
+        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)
+        self._fedcurv = FedCurv(self, importance=1e7)
+
+        # Set the loss function
+        self.loss_fn = F.cross_entropy
+
+    def forward(self, x):
+        """Forward pass of the model.
+
+        Args:
+            x: Data input to the model for the forward pass
+        """
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        maxpool = F.max_pool2d(x, 2, 2)
+
+        x = F.relu(self.conv3(maxpool))
+        x = F.relu(self.conv4(x))
+        concat = torch.cat([maxpool, x], dim=1)
+        maxpool = F.max_pool2d(concat, 2, 2)
+
+        x = F.relu(self.conv5(maxpool))
+        x = F.relu(self.conv6(x))
+        concat = torch.cat([maxpool, x], dim=1)
+        maxpool = F.max_pool2d(concat, 2, 2)
+
+        x = F.relu(self.conv7(maxpool))
+        x = F.relu(self.conv8(x))
+        concat = torch.cat([maxpool, x], dim=1)
+        maxpool = F.max_pool2d(concat, 2, 2)
+
+        x = maxpool.flatten(start_dim=1)
+        x = F.dropout(self.fc1(x), p=0.5, training=self.training)
+        x = self.fc2(x)
+        return x
+
+    def train_(
+        self, train_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
+    ) -> Metric:
+        """Train single epoch.
+
+        Override this function in order to use custom training.
+
+        Args:
+            train_dataloader: Train dataset batch generator. Yields (samples, targets) tuples of
+                size = `self.data_loader.batch_size`.
+        Returns:
+            Metric: An object containing name and np.ndarray value.
+        """
+        losses = []
+        model = self
+        self._fedcurv.on_train_begin(model)
+
+        for data, target in train_dataloader:
+            data, target = torch.tensor(data).to(self.device), torch.tensor(target).to(
+                self.device
+            )
+            self.optimizer.zero_grad()
+            output = self(data)
+            # Add the FedCurv proximal penalty to the task loss
+            loss = self.loss_fn(output, target) + self._fedcurv.get_penalty(model)
+            loss.backward()
+            self.optimizer.step()
+            losses.append(loss.detach().cpu().numpy())
+        loss = np.mean(losses)
+        return Metric(name=self.loss_fn.__name__, value=np.array(loss))
+
+    def validate_(
+        self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
+    ) -> Metric:
+        """
+        Perform validation on PyTorch Model
+
+        Override this function for your own custom validation function
+
+        Args:
+            validation_dataloader: Validation dataset batch generator.
+                Yields (samples, targets) tuples.
+        Returns:
+            Metric: An object containing name and np.ndarray value
+        """
+
+        total_samples = 0
+        val_score = 0
+        with torch.no_grad():
+            for data, target in validation_dataloader:
+                samples = target.shape[0]
+                total_samples += samples
+                data, target = torch.tensor(data).to(self.device), torch.tensor(
+                    target
+                ).to(self.device, dtype=torch.int64)
+                output = self(data)
+                # get the index of the max log-probability
+                pred = output.argmax(dim=1)
+                val_score += pred.eq(target).sum().cpu().numpy()
+
+        accuracy = val_score / total_samples
+        return Metric(name="accuracy", value=np.array(accuracy))
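For readers unfamiliar with FedCurv: `get_penalty` contributes an EWC-style quadratic term that pulls each parameter toward the values from the previous round, weighted by a diagonal Fisher-information estimate and the `importance` factor. A conceptual sketch of that penalty from the paper (https://arxiv.org/pdf/1910.07796), not the `openfl.utilities.fedcurv` internals:

```python
# Conceptual sketch, not the OpenFL implementation:
# penalty = importance * sum_j F_j * (theta_j - theta_j_prev)^2
import torch

def fedcurv_penalty(params, prev_params, fisher_diag, importance=1e7):
    return importance * sum(
        (f * (p - q) ** 2).sum()
        for p, q, f in zip(params, prev_params, fisher_diag)
    )

p = [torch.ones(2, requires_grad=True)]
penalty = fedcurv_penalty(p, [torch.zeros(2)], [torch.full((2,), 0.1)])
print(penalty)  # tensor(2000000., grad_fn=...)
```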
diff --git a/openfl/federated/task/fl_model.py b/openfl/federated/task/fl_model.py
index 70bc537324b..eb7365225f9 100644
--- a/openfl/federated/task/fl_model.py
+++ b/openfl/federated/task/fl_model.py
@@ -32,6 +32,8 @@ class FederatedModel(TaskRunner):
             pytorch).
         tensor_dict_split_fn_kwargs (dict): Keyword arguments for the tensor
             dict split function.
+        data_loader (FederatedDataSet): A dataset to distribute among the collaborators,
+            see TaskRunner for more details.
     """
 
     def __init__(self, build_model, optimizer=None, loss_fn=None, **kwargs):
@@ -75,10 +77,17 @@ def __init__(self, build_model, optimizer=None, loss_fn=None, **kwargs):
         self.runner.validate = lambda *args, **kwargs: build_model.validate(
             self.runner, *args, **kwargs
         )
+
         if hasattr(self.model, "train_epoch"):
             self.runner.train_epoch = lambda *args, **kwargs: build_model.train_epoch(
                 self.runner, *args, **kwargs
            )
+
+        # Hook a custom train_ implementation when the model provides one
+        if hasattr(self.model, "train_"):
+            self.runner.train_ = lambda *args, **kwargs: build_model.train_(
+                self.runner, *args, **kwargs
+            )
         self.runner.model = self.model
         self.runner.optimizer = self.optimizer
         self.loss_fn = loss_fn
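The `fl_model.py` change follows the same delegation pattern already used for `validate` and `train_epoch`: if the wrapped model defines `train_`, the runner forwards calls to it with itself as the first argument. A stripped-down sketch of that pattern, with `Runner` and `UserModel` as hypothetical stand-ins for `FederatedModel`'s runner and `build_model`:

```python
# Stripped-down sketch of the hook pattern above; names are illustrative.
class Runner:
    pass

class UserModel:
    def train_(self, runner, batches):
        return f"trained on {len(batches)} batches"

model, runner = UserModel(), Runner()
if hasattr(model, "train_"):
    runner.train_ = lambda *args, **kwargs: model.train_(runner, *args, **kwargs)

print(runner.train_([1, 2, 3]))  # trained on 3 batches
```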
diff --git a/openfl/interface/workspace.py b/openfl/interface/workspace.py
index 52129a967e0..d3cb1713c55 100644
--- a/openfl/interface/workspace.py
+++ b/openfl/interface/workspace.py
@@ -79,10 +79,10 @@ def create_temp(prefix, template):
         template: The template to use for creating the workspace.
     """
 
-    echo("Creating Workspace Templates")
-
+    src = template if os.path.isabs(template) else WORKSPACE / template
+    echo(f"Creating Workspace Templates from {src} in {prefix}")
     shutil.copytree(
-        src=WORKSPACE / template,
+        src=src,
         dst=prefix,
         dirs_exist_ok=True,
         ignore=shutil.ignore_patterns("__pycache__"),
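The `workspace.py` change lets a template be resolved from an absolute path as well as from the built-in `WORKSPACE` directory. A sketch of that resolution rule, with an assumed `WORKSPACE` root for illustration:

```python
# Sketch of the template resolution introduced above; the WORKSPACE value
# here is an assumption for illustration.
import os

WORKSPACE = "/opt/openfl/openfl-workspace"

def resolve_template(template: str) -> str:
    return template if os.path.isabs(template) else os.path.join(WORKSPACE, template)

print(resolve_template("torch_cnn_histology_fedcurv"))
print(resolve_template("/home/user/my_custom_template"))
```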