From 45d10175b8b2d27e7d6abb84dc7be552bd2cf15b Mon Sep 17 00:00:00 2001
From: JMGaljaard
Date: Fri, 2 Sep 2022 12:51:48 +0200
Subject: [PATCH] Clean up definition of datasets to differentiate between Distributed and Federated datasets

---
Review notes (placed after the three-dash separator, so they are not part of
the commit message):

 * Layout: this patch had been collapsed onto a handful of physical lines.
   The line structure below was restored from the hunk ranges. Blank context
   lines were not recoverable from the collapsed text; their number follows
   from each "@@" range, their exact position is assumed -- TODO confirm
   against the original blobs before applying.
 * fltk/datasets/__init__.py imports `Dataset` from `.dataset`, which the
   dataset.py hunk in this patch shows as a plain `class Dataset:`, yet the new
   lookup tables are keyed on `Dataset.cifar10`, `Dataset.mnist`, ...
   fltk/datasets/federated/__init__.py builds the same kind of table with
   `Dataset` taken from `...util.config.definitions`. Presumably that
   definitions `Dataset` was intended here as well -- TODO confirm; as written
   the helpers look like they raise AttributeError on first call unless the
   base class defines those attributes.
 * federated/cifar10.py and federated/cifar100.py: the class docstrings still
   read "Distributed learning experiments" after the Dist -> Fed rename. They
   are unchanged context lines here, so a follow-up commit would have to
   reword them.

 fltk/datasets/__init__.py                | 33 ++++++++++++++++++++++++
 fltk/datasets/cifar10.py                 |  2 +-
 fltk/datasets/cifar100.py                |  2 +-
 fltk/datasets/dataset.py                 |  6 ++++-
 fltk/datasets/fashion_mnist.py           |  2 +-
 fltk/datasets/federated/__init__.py      | 24 +++++++++++++----
 fltk/datasets/federated/cifar10.py       |  6 ++---
 fltk/datasets/federated/cifar100.py      |  6 ++---
 fltk/datasets/federated/dataset.py       |  2 +-
 fltk/datasets/federated/fashion_mnist.py |  6 ++---
 fltk/datasets/federated/mnist.py         |  6 ++---
 fltk/datasets/mnist.py                   |  2 +-
 12 files changed, 74 insertions(+), 23 deletions(-)

diff --git a/fltk/datasets/__init__.py b/fltk/datasets/__init__.py
index 6caa5832..a44e82ac 100644
--- a/fltk/datasets/__init__.py
+++ b/fltk/datasets/__init__.py
@@ -2,3 +2,36 @@
 from .cifar100 import CIFAR100Dataset
 from .fashion_mnist import FashionMNISTDataset
 from .mnist import MNIST
+from .dataset import Dataset
+
+def available_dataparallel_datasets():
+    return {
+        Dataset.cifar10: CIFAR10Dataset,
+        Dataset.cifar100: CIFAR100Dataset,
+        Dataset.fashion_mnist: FashionMNISTDataset,
+        Dataset.mnist: MNIST
+    }
+
+
+def get_train_loader_path(name: Dataset) -> str:
+    paths = {
+        Dataset.cifar10: 'data_loaders/cifar10/train_data_loader.pickle',
+        Dataset.fashion_mnist: 'data_loaders/fashion-mnist/train_data_loader.pickle',
+        Dataset.cifar100: 'data_loaders/cifar100/train_data_loader.pickle',
+        Dataset.mnist: 'data_loaders/mnist/train_data_loader.pickle',
+    }
+    return paths[name]
+
+
+def get_test_loader_path(name: Dataset) -> str:
+    paths = {
+        Dataset.cifar10: 'data_loaders/cifar10/test_data_loader.pickle',
+        Dataset.fashion_mnist: 'data_loaders/fashion-mnist/test_data_loader.pickle',
+        Dataset.cifar100: 'data_loaders/cifar100/test_data_loader.pickle',
+        Dataset.mnist: 'data_loaders/mnist/test_data_loader.pickle',
+    }
+    return paths[name]
+
+
+def get_dist_dataset(name: Dataset):
+    return available_dataparallel_datasets()[name]
diff --git a/fltk/datasets/cifar10.py b/fltk/datasets/cifar10.py
index a0c6ea3a..e028b426 100644
--- a/fltk/datasets/cifar10.py
+++ b/fltk/datasets/cifar10.py
@@ -2,7 +2,7 @@
 from torchvision import datasets
 from torchvision import transforms
 
-from .dataset import Dataset
+from fltk.datasets.dataset import Dataset
 
 
 class CIFAR10Dataset(Dataset):
diff --git a/fltk/datasets/cifar100.py b/fltk/datasets/cifar100.py
index d6e6b47c..b85aded8 100644
--- a/fltk/datasets/cifar100.py
+++ b/fltk/datasets/cifar100.py
@@ -2,7 +2,7 @@
 from torchvision import datasets
 from torchvision import transforms
 
-from .dataset import Dataset
+from fltk.datasets.dataset import Dataset
 
 
 class CIFAR100Dataset(Dataset):
diff --git a/fltk/datasets/dataset.py b/fltk/datasets/dataset.py
index 7cb94832..c7ee29bb 100644
--- a/fltk/datasets/dataset.py
+++ b/fltk/datasets/dataset.py
@@ -1,10 +1,14 @@
+from __future__ import annotations
 from abc import abstractmethod
 
 import torch
 from torch.utils.data import DataLoader
 from torch.utils.data import TensorDataset
 
-from fltk.util.config import DistLearningConfig
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from fltk.util.config import DistLearningConfig
 
 
 class Dataset:
diff --git a/fltk/datasets/fashion_mnist.py b/fltk/datasets/fashion_mnist.py
index 801e693e..e03a93f6 100644
--- a/fltk/datasets/fashion_mnist.py
+++ b/fltk/datasets/fashion_mnist.py
@@ -1,5 +1,5 @@
 # pylint: disable=missing-function-docstring,missing-class-docstring,invalid-name
-from .dataset import Dataset
+from fltk.datasets.dataset import Dataset
 from torchvision import datasets
 from torchvision import transforms
 from torch.utils.data import DataLoader, DistributedSampler
diff --git a/fltk/datasets/federated/__init__.py b/fltk/datasets/federated/__init__.py
index 2934cf74..c7ca4b16 100644
--- a/fltk/datasets/federated/__init__.py
+++ b/fltk/datasets/federated/__init__.py
@@ -1,5 +1,19 @@
-from .cifar10 import DistCIFAR10Dataset
-from .cifar100 import DistCIFAR100Dataset
-from .fashion_mnist import DistFashionMNISTDataset
-from .mnist import DistMNISTDataset
-from .dataset import DistDataset
+from .cifar10 import FedCIFAR10Dataset
+from .cifar100 import FedCIFAR100Dataset
+from .fashion_mnist import FedFashionMNISTDataset
+from .mnist import FedMNISTDataset
+from .dataset import FedDataset
+from ...util.config.definitions import Dataset
+
+
+def available_fed_datasets():
+    return {
+        Dataset.cifar10: FedCIFAR10Dataset,
+        Dataset.cifar100: FedCIFAR100Dataset,
+        Dataset.fashion_mnist: FedFashionMNISTDataset,
+        Dataset.mnist: FedMNISTDataset
+    }
+
+
+def get_fed_dataset(name: Dataset):
+    return available_fed_datasets()[name]
diff --git a/fltk/datasets/federated/cifar10.py b/fltk/datasets/federated/cifar10.py
index 1d62e45d..b9bfb330 100644
--- a/fltk/datasets/federated/cifar10.py
+++ b/fltk/datasets/federated/cifar10.py
@@ -2,18 +2,18 @@
 from torchvision import datasets
 from torchvision import transforms
 
-from fltk.datasets.distributed import DistDataset
+from fltk.datasets.federated.dataset import FedDataset
 from fltk.samplers import get_sampler
 from fltk.util.config import FedLearningConfig
 
 
-class DistCIFAR10Dataset(DistDataset):
+class FedCIFAR10Dataset(FedDataset):
     """
     CIFAR10 Dataset implementation for Distributed learning experiments.
     """
 
     def __init__(self, args: FedLearningConfig):
-        super(DistCIFAR10Dataset, self).__init__(args)
+        super(FedCIFAR10Dataset, self).__init__(args)
         self.init_train_dataset()
         self.init_test_dataset()
 
diff --git a/fltk/datasets/federated/cifar100.py b/fltk/datasets/federated/cifar100.py
index 0603cc67..22979b53 100644
--- a/fltk/datasets/federated/cifar100.py
+++ b/fltk/datasets/federated/cifar100.py
@@ -2,17 +2,17 @@
 from torchvision import datasets
 from torchvision import transforms
 from torch.utils.data import DataLoader
-from fltk.datasets.distributed.dataset import DistDataset
+from fltk.datasets.federated.dataset import FedDataset
 from fltk.samplers import get_sampler
 
 
-class DistCIFAR100Dataset(DistDataset):
+class FedCIFAR100Dataset(FedDataset):
     """
     CIFAR100 Dataset implementation for Distributed learning experiments.
     """
 
     def __init__(self, args):
-        super(DistCIFAR100Dataset, self).__init__(args)
+        super(FedCIFAR100Dataset, self).__init__(args)
         self.init_train_dataset()
         self.init_test_dataset()
 
diff --git a/fltk/datasets/federated/dataset.py b/fltk/datasets/federated/dataset.py
index 4a5cc56b..84503bea 100644
--- a/fltk/datasets/federated/dataset.py
+++ b/fltk/datasets/federated/dataset.py
@@ -4,7 +4,7 @@
 from fltk.util.log import getLogger
 
 
-class DistDataset:
+class FedDataset:
     train_sampler = None
     test_sampler = None
     train_dataset = None
diff --git a/fltk/datasets/federated/fashion_mnist.py b/fltk/datasets/federated/fashion_mnist.py
index 0ba4c964..64e80ddf 100644
--- a/fltk/datasets/federated/fashion_mnist.py
+++ b/fltk/datasets/federated/fashion_mnist.py
@@ -3,14 +3,14 @@
 from torchvision import datasets
 from torchvision import transforms
 
-from fltk.datasets.distributed.dataset import DistDataset
+from fltk.datasets.federated.dataset import FedDataset
 from fltk.samplers import get_sampler
 
 
-class DistFashionMNISTDataset(DistDataset):
+class FedFashionMNISTDataset(FedDataset):
 
     def __init__(self, args):
-        super(DistFashionMNISTDataset, self).__init__(args)
+        super(FedFashionMNISTDataset, self).__init__(args)
         self.init_train_dataset()
         self.init_test_dataset()
 
diff --git a/fltk/datasets/federated/mnist.py b/fltk/datasets/federated/mnist.py
index 58824bff..e2c26246 100644
--- a/fltk/datasets/federated/mnist.py
+++ b/fltk/datasets/federated/mnist.py
@@ -6,16 +6,16 @@
 from torch.utils.data import DataLoader
 from torchvision import datasets, transforms
 
-from fltk.datasets.distributed.dataset import DistDataset
+from fltk.datasets.federated.dataset import FedDataset
 from fltk.samplers import get_sampler
 
 if TYPE_CHECKING:
     pass
 
 
-class DistMNISTDataset(DistDataset):
+class FedMNISTDataset(FedDataset):
     def __init__(self, args):
-        super(DistMNISTDataset, self).__init__(args)
+        super(FedMNISTDataset, self).__init__(args)
         self.init_train_dataset()
         self.init_test_dataset()
 
diff --git a/fltk/datasets/mnist.py b/fltk/datasets/mnist.py
index 4711f126..2c1dd0bc 100644
--- a/fltk/datasets/mnist.py
+++ b/fltk/datasets/mnist.py
@@ -3,7 +3,7 @@
 from torchvision import datasets
 from torchvision import transforms
 
-from .dataset import Dataset
+from fltk.datasets.dataset import Dataset
 
 
 class MNIST(Dataset):