From 6d8a219438b3703552cfbdb640509a67da2050ea Mon Sep 17 00:00:00 2001 From: Kevin Ta <116312994+kta-intel@users.noreply.github.com> Date: Tue, 21 May 2024 14:59:25 -0400 Subject: [PATCH] updates to torch_cnn_histology openfl-workspace for Task Runner API (#976) * update torch_cnn_histology openfl-workspace for taskrunner api Signed-off-by: kta-intel * fix lint error Signed-off-by: kta-intel * upgrade torch and torchvision to latest. fix lint Signed-off-by: kta-intel --------- Signed-off-by: kta-intel --- .../torch_cnn_histology/plan/plan.yaml | 4 +- .../torch_cnn_histology/requirements.txt | 4 +- .../src/{histology_utils.py => dataloader.py} | 26 ++++ .../src/pthistology_inmemory.py | 32 ---- .../torch_cnn_histology/src/requirements.txt | 7 - .../torch_cnn_histology/src/taskrunner.py | 140 ++++++++++++++++++ 6 files changed, 170 insertions(+), 43 deletions(-) rename openfl-workspace/torch_cnn_histology/src/{histology_utils.py => dataloader.py} (86%) delete mode 100644 openfl-workspace/torch_cnn_histology/src/pthistology_inmemory.py delete mode 100644 openfl-workspace/torch_cnn_histology/src/requirements.txt create mode 100644 openfl-workspace/torch_cnn_histology/src/taskrunner.py diff --git a/openfl-workspace/torch_cnn_histology/plan/plan.yaml b/openfl-workspace/torch_cnn_histology/plan/plan.yaml index ed7fdf13c0..18a3c0cde5 100644 --- a/openfl-workspace/torch_cnn_histology/plan/plan.yaml +++ b/openfl-workspace/torch_cnn_histology/plan/plan.yaml @@ -18,7 +18,7 @@ collaborator : opt_treatment : RESET data_loader : - template : src.pthistology_inmemory.PyTorchHistologyInMemory + template : src.dataloader.PyTorchHistologyInMemory settings : collaborator_count : 2 data_group_name : histology @@ -26,7 +26,7 @@ data_loader : task_runner: defaults : plan/defaults/task_runner.yaml - template: src.pt_cnn.PyTorchCNN + template: src.taskrunner.PyTorchCNN network: defaults: plan/defaults/network.yaml diff --git a/openfl-workspace/torch_cnn_histology/requirements.txt 
b/openfl-workspace/torch_cnn_histology/requirements.txt index 7e3566ee3b..c44f6b0daa 100644 --- a/openfl-workspace/torch_cnn_histology/requirements.txt +++ b/openfl-workspace/torch_cnn_histology/requirements.txt @@ -1,2 +1,2 @@ -torchvision==0.14.1 -f https://download.pytorch.org/whl/torch_stable.html -torch==1.13.1 -f https://download.pytorch.org/whl/torch_stable.html +torchvision==0.18.0 -f https://download.pytorch.org/whl/torch_stable.html +torch==2.3.0 -f https://download.pytorch.org/whl/torch_stable.html \ No newline at end of file diff --git a/openfl-workspace/torch_cnn_histology/src/histology_utils.py b/openfl-workspace/torch_cnn_histology/src/dataloader.py similarity index 86% rename from openfl-workspace/torch_cnn_histology/src/histology_utils.py rename to openfl-workspace/torch_cnn_histology/src/dataloader.py index 233961c497..dc7dd4e0c6 100644 --- a/openfl-workspace/torch_cnn_histology/src/histology_utils.py +++ b/openfl-workspace/torch_cnn_histology/src/dataloader.py @@ -10,6 +10,7 @@ from urllib.request import urlretrieve from zipfile import ZipFile +from openfl.federated import PyTorchDataLoader import numpy as np import torch from torch.utils.data import random_split @@ -22,6 +23,31 @@ logger = getLogger(__name__) +class PyTorchHistologyInMemory(PyTorchDataLoader): + """PyTorch data loader for Histology dataset.""" + + def __init__(self, data_path, batch_size, **kwargs): + """Instantiate the data object. 
+ + Args: + data_path: The file path to the data + batch_size: The batch size of the data loader + **kwargs: Additional arguments, passed to super init + and load_histology_shard + """ + super().__init__(batch_size, random_seed=0, **kwargs) + + _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard( + shard_num=int(data_path), **kwargs) + + self.X_train = X_train + self.y_train = y_train + self.X_valid = X_valid + self.y_valid = y_valid + + self.num_classes = num_classes + + class HistologyDataset(ImageFolder): """Colorectal Histology Dataset.""" diff --git a/openfl-workspace/torch_cnn_histology/src/pthistology_inmemory.py b/openfl-workspace/torch_cnn_histology/src/pthistology_inmemory.py deleted file mode 100644 index 3ebecc2123..0000000000 --- a/openfl-workspace/torch_cnn_histology/src/pthistology_inmemory.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (C) 2020-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -"""You may copy this file as the starting point of your own model.""" - -from openfl.federated import PyTorchDataLoader -from .histology_utils import load_histology_shard - - -class PyTorchHistologyInMemory(PyTorchDataLoader): - """PyTorch data loader for Histology dataset.""" - - def __init__(self, data_path, batch_size, **kwargs): - """Instantiate the data object. 
- - Args: - data_path: The file path to the data - batch_size: The batch size of the data loader - **kwargs: Additional arguments, passed to super init - and load_mnist_shard - """ - super().__init__(batch_size, random_seed=0, **kwargs) - - _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard( - shard_num=int(data_path), **kwargs) - - self.X_train = X_train - self.y_train = y_train - self.X_valid = X_valid - self.y_valid = y_valid - - self.num_classes = num_classes diff --git a/openfl-workspace/torch_cnn_histology/src/requirements.txt b/openfl-workspace/torch_cnn_histology/src/requirements.txt deleted file mode 100644 index 5bb0f7b988..0000000000 --- a/openfl-workspace/torch_cnn_histology/src/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -torchvision==0.14.1 -Pillow==10.3.0 -tqdm==4.66.3 -numpy==1.22.2 -torch>=1.13.1 # not directly required, pinned by Snyk to avoid a vulnerability -setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability -wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/openfl-workspace/torch_cnn_histology/src/taskrunner.py b/openfl-workspace/torch_cnn_histology/src/taskrunner.py new file mode 100644 index 0000000000..933c1c1e1e --- /dev/null +++ b/openfl-workspace/torch_cnn_histology/src/taskrunner.py @@ -0,0 +1,140 @@ +# Copyright (C) 2020-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from typing import Iterator, Tuple + +from openfl.federated import PyTorchTaskRunner +from openfl.utilities import Metric + + +class PyTorchCNN(PyTorchTaskRunner): + """ + Simple CNN for classification. 
+ + PyTorchTaskRunner inherits from nn.Module, so you can define your model + in the same way that you would for PyTorch + """ + + def __init__(self, device="cpu", **kwargs): + """Initialize. + + Args: + device: The hardware device to use for training (Default = "cpu") + **kwargs: Additional arguments to pass to the function + + """ + super().__init__(device=device, **kwargs) + + # Define the model + channel = self.data_loader.get_feature_shape()[0] # (channel, dim1, dim2) + self.conv1 = nn.Conv2d(channel, 16, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1) + self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv5 = nn.Conv2d(128 + 32, 256, kernel_size=3, stride=1, padding=1) + self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv7 = nn.Conv2d(512 + 128 + 32, 256, kernel_size=3, stride=1, padding=1) + self.conv8 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.fc1 = nn.Linear(1184 * 9 * 9, 128) + self.fc2 = nn.Linear(128, 8) + + # `self.optimizer` must be set for optimizer weights to be federated + self.optimizer = optim.Adam(self.parameters(), lr=1e-3) + + # Set the loss function + self.loss_fn = F.cross_entropy + + def forward(self, x): + """Forward pass of the model. 
+ + Args: + x: Data input to the model for the forward pass + """ + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + maxpool = F.max_pool2d(x, 2, 2) + + x = F.relu(self.conv3(maxpool)) + x = F.relu(self.conv4(x)) + concat = torch.cat([maxpool, x], dim=1) + maxpool = F.max_pool2d(concat, 2, 2) + + x = F.relu(self.conv5(maxpool)) + x = F.relu(self.conv6(x)) + concat = torch.cat([maxpool, x], dim=1) + maxpool = F.max_pool2d(concat, 2, 2) + + x = F.relu(self.conv7(maxpool)) + x = F.relu(self.conv8(x)) + concat = torch.cat([maxpool, x], dim=1) + maxpool = F.max_pool2d(concat, 2, 2) + + x = maxpool.flatten(start_dim=1) + x = F.dropout(self.fc1(x), p=0.5) + x = self.fc2(x) + return x + + def train_( + self, train_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]] + ) -> Metric: + """Train single epoch. + + Override this function in order to use custom training. + + Args: + train_dataloader: Train dataset batch generator. Yields (samples, targets) tuples of + size = `self.data_loader.batch_size`. + Returns: + Metric: An object containing name and np.ndarray value. + """ + losses = [] + for data, target in train_dataloader: + data, target = torch.tensor(data).to(self.device), torch.tensor(target).to( + self.device + ) + self.optimizer.zero_grad() + output = self(data) + loss = self.loss_fn(output, target) + loss.backward() + self.optimizer.step() + losses.append(loss.detach().cpu().numpy()) + loss = np.mean(losses) + return Metric(name=self.loss_fn.__name__, value=np.array(loss)) + + def validate_( + self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]] + ) -> Metric: + """ + Perform validation on PyTorch Model + + Override this function for your own custom validation function + + Args: + validation_dataloader: Validation dataset batch generator. + Yields (samples, targets) tuples. 
+ Returns: + Metric: An object containing name and np.ndarray value + """ + + total_samples = 0 + val_score = 0 + with torch.no_grad(): + for data, target in validation_dataloader: + samples = target.shape[0] + total_samples += samples + data, target = torch.tensor(data).to(self.device), torch.tensor( + target + ).to(self.device, dtype=torch.int64) + output = self(data) + # get the index of the max log-probability + pred = output.argmax(dim=1) + val_score += pred.eq(target).sum().cpu().numpy() + + accuracy = val_score / total_samples + return Metric(name="accuracy", value=np.array(accuracy))