Skip to content

Commit

Permalink
update torch_cnn_histology openfl-workspace for taskrunner api
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <[email protected]>
  • Loading branch information
kta-intel committed May 20, 2024
1 parent 32d5dc4 commit 9c654dd
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 42 deletions.
4 changes: 2 additions & 2 deletions openfl-workspace/torch_cnn_histology/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ collaborator :
opt_treatment : RESET

data_loader :
template : src.pthistology_inmemory.PyTorchHistologyInMemory
template : src.dataloader.PyTorchHistologyInMemory
settings :
collaborator_count : 2
data_group_name : histology
batch_size : 32

task_runner:
defaults : plan/defaults/task_runner.yaml
template: src.pt_cnn.PyTorchCNN
template: src.taskrunner.PyTorchCNN

network:
defaults: plan/defaults/network.yaml
Expand Down
2 changes: 1 addition & 1 deletion openfl-workspace/torch_cnn_histology/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torchvision==0.14.1 -f https://download.pytorch.org/whl/torch_stable.html
torch==1.13.1 -f https://download.pytorch.org/whl/torch_stable.html
torch==1.13.1 -f https://download.pytorch.org/whl/torch_stable.html
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from urllib.request import urlretrieve
from zipfile import ZipFile

from openfl.federated import PyTorchDataLoader
import numpy as np
import torch
from torch.utils.data import random_split
Expand All @@ -22,6 +23,31 @@
logger = getLogger(__name__)


class PyTorchHistologyInMemory(PyTorchDataLoader):
"""PyTorch data loader for Histology dataset."""

def __init__(self, data_path, batch_size, **kwargs):
"""Instantiate the data object.
Args:
data_path: The file path to the data
batch_size: The batch size of the data loader
**kwargs: Additional arguments, passed to super init
and load_mnist_shard
"""
super().__init__(batch_size, random_seed=0, **kwargs)

_, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
shard_num=int(data_path), **kwargs)

self.X_train = X_train
self.y_train = y_train
self.X_valid = X_valid
self.y_valid = y_valid

self.num_classes = num_classes


class HistologyDataset(ImageFolder):
"""Colorectal Histology Dataset."""

Expand Down
32 changes: 0 additions & 32 deletions openfl-workspace/torch_cnn_histology/src/pthistology_inmemory.py

This file was deleted.

7 changes: 0 additions & 7 deletions openfl-workspace/torch_cnn_histology/src/requirements.txt

This file was deleted.

139 changes: 139 additions & 0 deletions openfl-workspace/torch_cnn_histology/src/taskrunner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (C) 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Iterator, Tuple

from openfl.federated import PyTorchTaskRunner
from openfl.utilities import Metric

class PyTorchCNN(PyTorchTaskRunner):
"""
Simple CNN for classification.
PyTorchTaskRunner inherits from nn.module, so you can define your model
in the same way that you would for PyTorch
"""

def __init__(self, device="cpu", **kwargs):
"""Initialize.
Args:
device: The hardware device to use for training (Default = "cpu")
**kwargs: Additional arguments to pass to the function
"""
super().__init__(device=device, **kwargs)

# Define the model
channel = self.data_loader.get_feature_shape()[0] # (channel, dim1, dim2)
self.conv1 = nn.Conv2d(channel, 16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.conv5 = nn.Conv2d(128 + 32, 256, kernel_size=3, stride=1, padding=1)
self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.conv7 = nn.Conv2d(512 + 128 + 32, 256, kernel_size=3, stride=1, padding=1)
self.conv8 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(1184 * 9 * 9, 128)
self.fc2 = nn.Linear(128, 8)

# `self.optimizer` must be set for optimizer weights to be federated
self.optimizer = optim.Adam(self.parameters(), lr=1e-3)

# Set the loss function
self.loss_fn = F.cross_entropy

def forward(self, x):
"""Forward pass of the model.
Args:
x: Data input to the model for the forward pass
"""
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
maxpool = F.max_pool2d(x, 2, 2)

x = F.relu(self.conv3(maxpool))
x = F.relu(self.conv4(x))
concat = torch.cat([maxpool, x], dim=1)
maxpool = F.max_pool2d(concat, 2, 2)

x = F.relu(self.conv5(maxpool))
x = F.relu(self.conv6(x))
concat = torch.cat([maxpool, x], dim=1)
maxpool = F.max_pool2d(concat, 2, 2)

x = F.relu(self.conv7(maxpool))
x = F.relu(self.conv8(x))
concat = torch.cat([maxpool, x], dim=1)
maxpool = F.max_pool2d(concat, 2, 2)

x = maxpool.flatten(start_dim=1)
x = F.dropout(self.fc1(x), p=0.5)
x = self.fc2(x)
return x

def train_(
self, train_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
) -> Metric:
"""Train single epoch.
Override this function in order to use custom training.
Args:
batch_generator: Train dataset batch generator. Yields (samples, targets) tuples of
size = `self.data_loader.batch_size`.
Returns:
Metric: An object containing name and np.ndarray value.
"""
losses = []
for data, target in train_dataloader:
data, target = torch.tensor(data).to(self.device), torch.tensor(target).to(
self.device
)
self.optimizer.zero_grad()
output = self(data)
loss = self.loss_fn(output, target)
loss.backward()
self.optimizer.step()
losses.append(loss.detach().cpu().numpy())
loss = np.mean(losses)
return Metric(name=self.loss_fn.__name__, value=np.array(loss))

def validate_(
self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
) -> Metric:
"""
Perform validation on PyTorch Model
Override this function for your own custom validation function
Args:
validation_data_loader: Validation dataset batch generator.
Yields (samples, targets) tuples.
Returns:
Metric: An object containing name and np.ndarray value
"""

total_samples = 0
val_score = 0
with torch.no_grad():
for data, target in validation_dataloader:
samples = target.shape[0]
total_samples += samples
data, target = torch.tensor(data).to(self.device), torch.tensor(
target
).to(self.device, dtype=torch.int64)
output = self(data)
# get the index of the max log-probability
pred = output.argmax(dim=1)
val_score += pred.eq(target).sum().cpu().numpy()

accuracy = val_score / total_samples
return Metric(name="accuracy", value=np.array(accuracy))

0 comments on commit 9c654dd

Please sign in to comment.