From c90953c8531489b19656e632b6410276ced9a64d Mon Sep 17 00:00:00 2001 From: "Spyridon (Spyros) Bakas" Date: Thu, 14 Dec 2023 13:35:29 -0500 Subject: [PATCH 1/3] Update README.md (#902) Made changes to reflect the transition of the Bakas group from UPenn to IU. --- README.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a6f2779eff..8f26ba6e04 100644 --- a/README.md +++ b/README.md @@ -67,12 +67,11 @@ OpenFL supports training with TensorFlow 2+ or PyTorch 1.3+ which should be inst ### Background -OpenFL builds on a collaboration between Intel and the University of Pennsylvania (UPenn) to develop the [Federated Tumor Segmentation (FeTS, www.fets.ai)](https://www.fets.ai/) platform (grant award number: U01-CA242871). +OpenFL builds on a collaboration between Intel and the Bakas lab at the University of Pennsylvania (UPenn) to develop the [Federated Tumor Segmentation (FeTS, www.fets.ai)](https://www.fets.ai/) platform (grant award number: U01-CA242871). -The grant for FeTS was awarded to the [Center for Biomedical Image Computing and Analytics (CBICA)](https://www.cbica.upenn.edu/) at UPenn (PI: S. Bakas) from the [Informatics Technology for Cancer Research (ITCR)](https://itcr.cancer.gov/) program of the National Cancer Institute (NCI) of the National Institutes of Health (NIH). +The grant for FeTS was awarded from the [Informatics Technology for Cancer Research (ITCR)](https://itcr.cancer.gov/) program of the National Cancer Institute (NCI) of the National Institutes of Health (NIH) to Dr Spyridon Bakas (Principal Investigator) when he was affiliated with the [Center for Biomedical Image Computing and Analytics (CBICA)](https://www.cbica.upenn.edu/) at UPenn; he now heads up the [Division of Computational Pathology at Indiana University (IU)](https://medicine.iu.edu/pathology/research/computational-pathology). -FeTS is a real-world medical federated learning platform with international collaborators. 
The original OpenFederatedLearning project and OpenFL are designed to serve as the backend for the FeTS platform, -and OpenFL developers and researchers continue to work very closely with UPenn on the FeTS project. An example is the [FeTS-AI/Front-End](https://github.com/FETS-AI/Front-End), which integrates UPenn’s medical AI expertise with OpenFL framework to create a federated learning solution for medical imaging. +FeTS is a real-world medical federated learning platform with international collaborators. The original OpenFederatedLearning project and OpenFL are designed to serve as the backend for the FeTS platform, and OpenFL developers and researchers continue to work very closely with IU on the FeTS project. An example is the [FeTS-AI/Front-End](https://github.com/FETS-AI/Front-End), which integrates the group’s medical AI expertise with OpenFL framework to create a federated learning solution for medical imaging. Although initially developed for use in medical imaging, OpenFL designed to be agnostic to the use-case, the industry, and the machine learning framework. 
From a1275163f4c8d9cf3d4724a13d4f5a9353c19b9b Mon Sep 17 00:00:00 2001 From: Patrick Foley Date: Wed, 3 Jan 2024 14:39:48 -0800 Subject: [PATCH 2/3] Update Failing GaNDLF Test (#906) * Updates GaNDLF prerequisite packages --- .github/workflows/fets-challenge.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/fets-challenge.yml b/.github/workflows/fets-challenge.yml index 9ff94dcb85..ff57453efc 100644 --- a/.github/workflows/fets-challenge.yml +++ b/.github/workflows/fets-challenge.yml @@ -17,14 +17,14 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v3 with: - python-version: "3.8" + python-version: "3.10" - name: Install dependencies run: | python -m pip install --upgrade pip - pip install torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu + pip install torch==2.1.0+cpu torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu pip install . 
- name: Setup FeTS Challenge Prerequisites uses: actions/checkout@master From f1657abe88632d542504d6d71ca961de9333913f Mon Sep 17 00:00:00 2001 From: Kevin Ta <116312994+kta-intel@users.noreply.github.com> Date: Fri, 5 Jan 2024 18:44:59 -0500 Subject: [PATCH 3/3] incorporating support for federated evaluation (#897) * initialize federated evaluation document Signed-off-by: kta-intel * add workspace for federated evaluation Signed-off-by: kta-intel * Update federated_evaluation.rst Signed-off-by: kta-intel * Update federated_evaluation.rst Signed-off-by: kta-intel * Add comments to plan.yaml Signed-off-by: kta-intel * indexing federated evaluation page under manual Signed-off-by: kta-intel * fixing lint issues Signed-off-by: kta-intel --------- Signed-off-by: kta-intel --- docs/federated_evaluation.rst | 110 +++++++++++ docs/manual.rst | 1 + .../torch_cnn_mnist_fed_eval/.workspace | 2 + .../torch_cnn_mnist_fed_eval/plan/cols.yaml | 5 + .../torch_cnn_mnist_fed_eval/plan/data.yaml | 9 + .../torch_cnn_mnist_fed_eval/plan/defaults | 2 + .../torch_cnn_mnist_fed_eval/plan/plan.yaml | 61 ++++++ .../torch_cnn_mnist_fed_eval/requirements.txt | 4 + .../torch_cnn_mnist_fed_eval/src/__init__.py | 3 + .../src/mnist_utils.py | 115 +++++++++++ .../torch_cnn_mnist_fed_eval/src/pt_cnn.py | 179 ++++++++++++++++++ .../src/ptmnist_inmemory.py | 41 ++++ 12 files changed, 532 insertions(+) create mode 100644 docs/federated_evaluation.rst create mode 100644 openfl-workspace/torch_cnn_mnist_fed_eval/.workspace create mode 100644 openfl-workspace/torch_cnn_mnist_fed_eval/plan/cols.yaml create mode 100644 openfl-workspace/torch_cnn_mnist_fed_eval/plan/data.yaml create mode 100644 openfl-workspace/torch_cnn_mnist_fed_eval/plan/defaults create mode 100644 openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml create mode 100644 openfl-workspace/torch_cnn_mnist_fed_eval/requirements.txt create mode 100644 openfl-workspace/torch_cnn_mnist_fed_eval/src/__init__.py create mode 100644 
openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py create mode 100644 openfl-workspace/torch_cnn_mnist_fed_eval/src/pt_cnn.py create mode 100644 openfl-workspace/torch_cnn_mnist_fed_eval/src/ptmnist_inmemory.py diff --git a/docs/federated_evaluation.rst b/docs/federated_evaluation.rst new file mode 100644 index 0000000000..b38687d269 --- /dev/null +++ b/docs/federated_evaluation.rst @@ -0,0 +1,110 @@ +Federated Evaluation with OpenFL +================================= + +Introduction to Federated Evaluation +------------------------------------- + +Model evaluation is an essential part of the machine learning development cycle. In a traditional centralized learning system, all evaluation data is collected on a localized server. Because of this, centralized evaluation of machine learning models is a fairly straightforward task. However, in a federated learning system, data is distributed across multiple decentralized devices or nodes. In an effort to preserve the security and privacy of the distributed data, it is infeasible to simply aggregate all the data into a centralized system. Federated evaluation offers a solution by assessing the model at the client side and aggregating the accuracy without ever having to share the data. This is crucial for ensuring the model's effectiveness and reliability in diverse and real-world environments while respecting privacy and data locality + +OpenFL's Support for Federated Evaluation +----------------------------------------- + +OpenFL, a flexible framework for Federated Learning, has the capability to perform federated evaluation by modifying the federation plan. In this document, we will show how OpenFL can facilitate this process through its task runner API (aggregator-based workflow), where the model evaluation is distributed across various collaborators before being sent to the aggregator. 
For the task runner API, this involves minor modifications to the ``plan.yaml`` file, which defines the workflow and tasks for the federation. In particular, the federation plan should be defined to run for one round and perform only aggregated model validation. + +In general, the pipeline is as follows: + +1. **Setup**: Initialize the federation with the modified ``plan.yaml`` set to run for one round and only perform aggregated model validation. +2. **Execution**: Run the federation. The model is distributed across collaborators for evaluation. +3. **Evaluation**: Each collaborator evaluates the model on its local data. +4. **Aggregation**: The aggregator collects and aggregates these metrics to assess overall model performance. + +Example Using the Task Runner API (Aggregator-based Workflow) +------------------------------------------------------------------- + +To demonstrate usage of the task runner API (aggregator-based workflow) for federated evaluation, consider the `Hello Federation example `_. This sample script creates a simple federation with two collaborator nodes and one aggregator node, and executes based on a user-specified workspace template. We provide a ``torch_cnn_mnist_fed_eval`` template, which is a federated evaluation template adapted from ``torch_cnn_mnist``. + +This script can be directly executed as follows: + +.. code-block:: console + + python test_hello_federation.py --template torch_cnn_mnist_fed_eval + +In order to adapt this template for federated evaluation, the following modifications were made to ``plan.yaml``: + +.. code-block:: yaml + + # Copyright (C) 2020-2023 Intel Corporation + # Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. 
+ + aggregator : + defaults : plan/defaults/aggregator.yaml + template : openfl.component.Aggregator + settings : + init_state_path : save/torch_cnn_mnist_init.pbuf + best_state_path : save/torch_cnn_mnist_best.pbuf + last_state_path : save/torch_cnn_mnist_last.pbuf + ######################## + rounds_to_train : 1 + ######################## + log_metric_callback : + template : src.mnist_utils.write_metric + + collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + settings : + delta_updates : false + opt_treatment : RESET + + data_loader : + defaults : plan/defaults/data_loader.yaml + template : src.ptmnist_inmemory.PyTorchMNISTInMemory + settings : + collaborator_count : 2 + data_group_name : mnist + batch_size : 256 + + task_runner : + defaults : plan/defaults/task_runner.yaml + template : src.pt_cnn.PyTorchCNN + + network : + defaults : plan/defaults/network.yaml + + assigner : + ######################## + template : openfl.component.RandomGroupedAssigner + settings : + task_groups : + - name : validate + percentage : 1.0 + tasks : + - aggregated_model_validation + ######################## + + tasks : + ######################## + aggregated_model_validation: + function : validate + kwargs : + apply : global + metrics : + - acc + ######################## + + compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml + +Key Changes for Federated Evaluation: + +1. **aggregator.settings.rounds_to_train**: Set to 1 +2. **assigner**: Assign to aggregated_model_validation instead of default assignments +3. 
**tasks**: Set to aggregated_model_validation instead of default tasks + +**Optional**: modify ``src/pt_cnn.py`` to remove optimizer initialization and definition of loss function as these are not needed for evaluation. + +This sample script will create a federation based on the ``torch_cnn_mnist_fed_eval`` template using the ``plan.yaml`` file defined above, spawning two collaborator nodes and a single aggregator node. The model will be sent to the two collaborator nodes, where each collaborator will perform model validation on its own local data. The accuracy from this model validation will then be sent back to the aggregator, where it will be aggregated into a final accuracy metric. The federation will then be shut down. + +--- + +Congratulations, you have successfully performed federated evaluation across two decentralized collaborator nodes. diff --git a/docs/manual.rst b/docs/manual.rst index cdab47482a..a0a7254c76 100644 --- a/docs/manual.rst +++ b/docs/manual.rst @@ -35,6 +35,7 @@ Explore new and experimental features: install running_the_federation running_the_federation_with_gandlf + federated_evaluation source/utilities/utilities advanced_topics source/workflow/running_the_federation.tutorial diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/.workspace b/openfl-workspace/torch_cnn_mnist_fed_eval/.workspace new file mode 100644 index 0000000000..3c2c5d08b4 --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/.workspace @@ -0,0 +1,2 @@ +current_plan_name: default + diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/plan/cols.yaml b/openfl-workspace/torch_cnn_mnist_fed_eval/plan/cols.yaml new file mode 100644 index 0000000000..95307de3bc --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/plan/cols.yaml @@ -0,0 +1,5 @@ +# Copyright (C) 2020-2021 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. 
+ +collaborators: + \ No newline at end of file diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/plan/data.yaml b/openfl-workspace/torch_cnn_mnist_fed_eval/plan/data.yaml new file mode 100644 index 0000000000..cc7e6c9595 --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/plan/data.yaml @@ -0,0 +1,9 @@ +## Copyright (C) 2020-2021 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +# All keys under 'collaborators' correspond to a specific collaborator name; the corresponding dictionary has data_name, data_path pairs. +# Note that in the mnist case we do not store the data locally, and the data_path is used to pass an integer that helps the data object +# construct the shard of the mnist dataset to be used for this collaborator. + +# collaborator_name ,data_directory_path +one,1 \ No newline at end of file diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/plan/defaults b/openfl-workspace/torch_cnn_mnist_fed_eval/plan/defaults new file mode 100644 index 0000000000..fb82f9c5b6 --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/plan/defaults @@ -0,0 +1,2 @@ +../../workspace/plan/defaults + diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml b/openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml new file mode 100644 index 0000000000..2eca7b67c4 --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml @@ -0,0 +1,61 @@ +# Copyright (C) 2020-2023 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. 
+ +aggregator : + defaults : plan/defaults/aggregator.yaml + template : openfl.component.Aggregator + settings : + init_state_path : save/torch_cnn_mnist_init.pbuf + best_state_path : save/torch_cnn_mnist_best.pbuf + last_state_path : save/torch_cnn_mnist_last.pbuf + ######### SET ROUNDS TO 1 ############# + rounds_to_train : 1 + ####################################### + log_metric_callback : + template : src.mnist_utils.write_metric + +collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + settings : + delta_updates : false + opt_treatment : RESET + +data_loader : + defaults : plan/defaults/data_loader.yaml + template : src.ptmnist_inmemory.PyTorchMNISTInMemory + settings : + collaborator_count : 2 + data_group_name : mnist + batch_size : 256 + +task_runner : + defaults : plan/defaults/task_runner.yaml + template : src.pt_cnn.PyTorchCNN + +network : + defaults : plan/defaults/network.yaml + +assigner : + ######### SET ASSIGNER TO ONLY INCLUDE AGGREGATED MODEL VALIDATION ############# + template : openfl.component.RandomGroupedAssigner + settings : + task_groups : + - name : validate + percentage : 1.0 + tasks : + - aggregated_model_validation + ################################################################################ + +tasks : + ######### SET AGGREGATED MODEL VALIDATION AS ONLY TASK ############# + aggregated_model_validation: + function : validate + kwargs : + apply : global + metrics : + - acc + #################################################################### + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/requirements.txt b/openfl-workspace/torch_cnn_mnist_fed_eval/requirements.txt new file mode 100644 index 0000000000..16b349007c --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/requirements.txt @@ -0,0 +1,4 @@ +torch==1.13.1 +torchvision==0.14.1 +tensorboard +wheel>=0.38.0 # not directly required, 
pinned by Snyk to avoid a vulnerability diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/src/__init__.py b/openfl-workspace/torch_cnn_mnist_fed_eval/src/__init__.py new file mode 100644 index 0000000000..d5df5b8668 --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/src/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""You may copy this file as the starting point of your own model.""" diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py b/openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py new file mode 100644 index 0000000000..1eccd2a95d --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py @@ -0,0 +1,115 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +from logging import getLogger + +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from torchvision import datasets +from torchvision import transforms + +logger = getLogger(__name__) + +writer = None + + +def get_writer(): + """Create global writer object.""" + global writer + if not writer: + writer = SummaryWriter('./logs/cnn_mnist', flush_secs=5) + + +def write_metric(node_name, task_name, metric_name, metric, round_number): + """Write metric callback.""" + get_writer() + writer.add_scalar(f'{node_name}/{task_name}/{metric_name}', metric, round_number) + + +def one_hot(labels, classes): + """ + One Hot encode a vector. + + Args: + labels (list): List of labels to onehot encode + classes (int): Total number of categorical classes + + Returns: + np.array: Matrix of one-hot encoded labels + """ + return np.eye(classes)[labels] + + +def _load_raw_datashards(shard_num, collaborator_count, transform=None): + """ + Load the raw data by shard. + + Returns tuples of the dataset shard divided into training and validation. 
+ + Args: + shard_num (int): The shard number to use + collaborator_count (int): The number of collaborators in the federation + transform: torchvision.transforms.Transform to apply to images + + Returns: + 2 tuples: (image, label) of the training, validation dataset + """ + train_data, val_data = ( + datasets.MNIST('data', train=train, download=True, transform=transform) + for train in (True, False) + ) + X_train_tot, y_train_tot = train_data.train_data, train_data.train_labels + X_valid_tot, y_valid_tot = val_data.test_data, val_data.test_labels + + # create the shards + shard_num = int(shard_num) + X_train = X_train_tot[shard_num::collaborator_count].unsqueeze(1).float() + y_train = y_train_tot[shard_num::collaborator_count] + + X_valid = X_valid_tot[shard_num::collaborator_count].unsqueeze(1).float() + y_valid = y_valid_tot[shard_num::collaborator_count] + + return (X_train, y_train), (X_valid, y_valid) + + +def load_mnist_shard(shard_num, collaborator_count, + categorical=False, channels_last=True, **kwargs): + """ + Load the MNIST dataset. 
+ + Args: + shard_num (int): The shard to use from the dataset + collaborator_count (int): The number of collaborators in the + federation + categorical (bool): True = convert the labels to one-hot encoded + vectors (Default = True) + channels_last (bool): True = The input images have the channels + last (Default = True) + **kwargs: Additional parameters to pass to the function + + Returns: + list: The input shape + int: The number of classes + numpy.ndarray: The training data + numpy.ndarray: The training labels + numpy.ndarray: The validation data + numpy.ndarray: The validation labels + """ + num_classes = 10 + + (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards( + shard_num, collaborator_count, transform=transforms.ToTensor()) + + logger.info(f'MNIST > X_train Shape : {X_train.shape}') + logger.info(f'MNIST > y_train Shape : {y_train.shape}') + logger.info(f'MNIST > Train Samples : {X_train.shape[0]}') + logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}') + + if categorical: + # convert class vectors to binary class matrices + y_train = one_hot(y_train, num_classes) + y_valid = one_hot(y_valid, num_classes) + + return num_classes, X_train, y_train, X_valid, y_valid diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/src/pt_cnn.py b/openfl-workspace/torch_cnn_mnist_fed_eval/src/pt_cnn.py new file mode 100644 index 0000000000..37d4130a73 --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/src/pt_cnn.py @@ -0,0 +1,179 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import tqdm + +from openfl.federated import PyTorchTaskRunner +from openfl.utilities import TensorKey + + +class PyTorchCNN(PyTorchTaskRunner): + """Simple CNN for classification.""" + + def __init__(self, device='cpu', **kwargs): + """Initialize. 
+ + Args: + data: The data loader class + device: The hardware device to use for training (Default = "cpu") + **kwargs: Additional arguments to pass to the function + + """ + super().__init__(device=device, **kwargs) + + self.num_classes = self.data_loader.num_classes + self.init_network(device=self.device, **kwargs) + self.initialize_tensorkeys_for_functions() + + def init_network(self, + device, + print_model=True, + pool_sqrkernel_size=2, + conv_sqrkernel_size=5, + conv1_channels_out=20, + conv2_channels_out=50, + fc2_insize=500, + **kwargs): + """Create the network (model). + + Args: + device: The hardware device to use for training + print_model (bool): Print the model topology (Default=True) + pool_sqrkernel_size (int): Max pooling kernel size (Default=2), + assumes square 2x2 + conv_sqrkernel_size (int): Convolutional filter size (Default=5), + assumes square 5x5 + conv1_channels_out (int): Number of filters in first + convolutional layer (Default=20) + conv2_channels_out: Number of filters in second convolutional + layer (Default=50) + fc2_insize (int): Number of neurons in the + fully-connected layer (Default = 500) + **kwargs: Additional arguments to pass to the function + + FIXME: We are tracking only side lengths (rather than + length and width) as we are assuming square + shapes for feature and kernels. + In order that all of the input and activation components are + used (not cut off), we rely on a criterion: appropriate integers + are divisible so that all casting to int perfomed below does no + rounding (i.e. all int casting simply converts a float with '0' + in the decimal part to an int.) 
+ + (Note this criterion held for the original input sizes considered + for this model: 28x28 and 32x32 when used with the default values + above) + """ + self.pool_sqrkernel_size = pool_sqrkernel_size + channel = self.data_loader.get_feature_shape()[0] # (channel, dim1, dim2) + self.conv1 = nn.Conv2d(channel, conv1_channels_out, conv_sqrkernel_size, 1) + + # perform some calculations to track the size of the single channel activations + # channels are first for pytorch + conv1_sqrsize_in = self.feature_shape[-1] + conv1_sqrsize_out = conv1_sqrsize_in - (conv_sqrkernel_size - 1) + # a pool operation happens after conv1 out + # (note dependence on 'forward' function below) + conv2_sqrsize_in = int(conv1_sqrsize_out / pool_sqrkernel_size) + + self.conv2 = nn.Conv2d(conv1_channels_out, conv2_channels_out, conv_sqrkernel_size, 1) + + # more tracking of single channel activation size + conv2_sqrsize_out = conv2_sqrsize_in - (conv_sqrkernel_size - 1) + # a pool operation happens after conv2 out + # (note dependence on 'forward' function below) + l0 = int(conv2_sqrsize_out / pool_sqrkernel_size) + self.fc1_insize = l0 * l0 * conv2_channels_out + self.fc1 = nn.Linear(self.fc1_insize, fc2_insize) + self.fc2 = nn.Linear(fc2_insize, self.num_classes) + if print_model: + print(self) + self.to(device) + + def forward(self, x): + """Forward pass of the model. + + Args: + x: Data input to the model for the forward pass + """ + x = F.relu(self.conv1(x)) + pl = self.pool_sqrkernel_size + x = F.max_pool2d(x, pl, pl) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, pl, pl) + x = x.view(-1, self.fc1_insize) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x + + def validate(self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs): + """Validate. + + Run validation of the model on the local data. 
+ + Args: + col_name: Name of the collaborator + round_num: What round is it + input_tensor_dict: Required input tensors (for model) + use_tqdm (bool): Use tqdm to print a progress bar (Default=True) + + Returns: + global_output_dict: Tensors to send back to the aggregator + local_output_dict: Tensors to maintain in the local TensorDB + + """ + self.rebuild_model(round_num, input_tensor_dict, validation=True) + self.eval() + val_score = 0 + total_samples = 0 + + loader = self.data_loader.get_valid_loader() + if use_tqdm: + loader = tqdm.tqdm(loader, desc='validate') + + with torch.no_grad(): + for data, target in loader: + samples = target.shape[0] + total_samples += samples + data, target = torch.tensor(data).to( + self.device), torch.tensor(target).to( + self.device, dtype=torch.int64) + output = self(data) + # get the index of the max log-probability + pred = output.argmax(dim=1) + val_score += pred.eq(target).sum().cpu().numpy() + + origin = col_name + suffix = 'validate' + if kwargs['apply'] == 'local': + suffix += '_local' + else: + suffix += '_agg' + tags = ('metric', suffix) + # TODO figure out a better way to pass + # in metric for this pytorch validate function + output_tensor_dict = { + TensorKey('acc', origin, round_num, True, tags): + np.array(val_score / total_samples) + } + + # Empty list represents metrics that should only be stored locally + return output_tensor_dict, {} + + def save_native(self, filepath): + """ + Save model in a picked file specified by the filepath. + Uses torch.save(). + + Args: + filepath (string) : Path to pickle file to be + created by pt.save(). 
+ Returns: + None + """ + torch.save(self, filepath) diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/src/ptmnist_inmemory.py b/openfl-workspace/torch_cnn_mnist_fed_eval/src/ptmnist_inmemory.py new file mode 100644 index 0000000000..b570ef8bfb --- /dev/null +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/src/ptmnist_inmemory.py @@ -0,0 +1,41 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +from openfl.federated import PyTorchDataLoader +from .mnist_utils import load_mnist_shard + + +class PyTorchMNISTInMemory(PyTorchDataLoader): + """PyTorch data loader for MNIST dataset.""" + + def __init__(self, data_path, batch_size, **kwargs): + """Instantiate the data object. + + Args: + data_path: The file path to the data + batch_size: The batch size of the data loader + **kwargs: Additional arguments, passed to super + init and load_mnist_shard + """ + super().__init__(batch_size, **kwargs) + + # TODO: We should be downloading the dataset shard into a directory + # TODO: There needs to be a method to ask how many collaborators and + # what index/rank is this collaborator. + # Then we have a way to automatically shard based on rank and size + # of collaborator list. + + num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard( + shard_num=int(data_path), **kwargs) + + self.X_train = X_train + self.y_train = y_train + self.train_loader = self.get_train_loader() + + self.X_valid = X_valid + self.y_valid = y_valid + self.val_loader = self.get_valid_loader() + + self.num_classes = num_classes