Skip to content

Commit

Permalink
Clean up definition of datasets to differentiate between Distributed …
Browse files Browse the repository at this point in the history
…and Federated datasets
  • Loading branch information
JMGaljaard committed Sep 4, 2022
1 parent 4b3128b commit 45d1017
Show file tree
Hide file tree
Showing 12 changed files with 74 additions and 23 deletions.
33 changes: 33 additions & 0 deletions fltk/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,36 @@
from .cifar100 import CIFAR100Dataset
from .fashion_mnist import FashionMNISTDataset
from .mnist import MNIST
from .dataset import Dataset

def available_dataparallel_datasets():
return {
Dataset.cifar10: CIFAR10Dataset,
Dataset.cifar100: CIFAR100Dataset,
Dataset.fashion_mnist: FashionMNISTDataset,
Dataset.mnist: MNIST
}


def get_train_loader_path(name: Dataset) -> str:
paths = {
Dataset.cifar10: 'data_loaders/cifar10/train_data_loader.pickle',
Dataset.fashion_mnist: 'data_loaders/fashion-mnist/train_data_loader.pickle',
Dataset.cifar100: 'data_loaders/cifar100/train_data_loader.pickle',
Dataset.mnist: 'data_loaders/mnist/train_data_loader.pickle',
}
return paths[name]


def get_test_loader_path(name: Dataset) -> str:
paths = {
Dataset.cifar10: 'data_loaders/cifar10/test_data_loader.pickle',
Dataset.fashion_mnist: 'data_loaders/fashion-mnist/test_data_loader.pickle',
Dataset.cifar100: 'data_loaders/cifar100/test_data_loader.pickle',
Dataset.mnist: 'data_loaders/mnist/test_data_loader.pickle',
}
return paths[name]


def get_dist_dataset(name: Dataset):
return available_dataparallel_datasets()[name]
2 changes: 1 addition & 1 deletion fltk/datasets/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torchvision import datasets
from torchvision import transforms

from .dataset import Dataset
from fltk.datasets.dataset import Dataset


class CIFAR10Dataset(Dataset):
Expand Down
2 changes: 1 addition & 1 deletion fltk/datasets/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torchvision import datasets
from torchvision import transforms

from .dataset import Dataset
from fltk.datasets.dataset import Dataset


class CIFAR100Dataset(Dataset):
Expand Down
6 changes: 5 additions & 1 deletion fltk/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations
from abc import abstractmethod

import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

from fltk.util.config import DistLearningConfig
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from fltk.util.config import DistLearningConfig


class Dataset:
Expand Down
2 changes: 1 addition & 1 deletion fltk/datasets/fashion_mnist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=missing-function-docstring,missing-class-docstring,invalid-name
from .dataset import Dataset
from fltk.datasets.dataset import Dataset
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, DistributedSampler
Expand Down
24 changes: 19 additions & 5 deletions fltk/datasets/federated/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
from .cifar10 import DistCIFAR10Dataset
from .cifar100 import DistCIFAR100Dataset
from .fashion_mnist import DistFashionMNISTDataset
from .mnist import DistMNISTDataset
from .dataset import DistDataset
from .cifar10 import FedCIFAR10Dataset
from .cifar100 import FedCIFAR100Dataset
from .fashion_mnist import FedFashionMNISTDataset
from .mnist import FedMNISTDataset
from .dataset import FedDataset
from ...util.config.definitions import Dataset


def available_fed_datasets():
return {
Dataset.cifar10: FedCIFAR10Dataset,
Dataset.cifar100: FedCIFAR100Dataset,
Dataset.fashion_mnist: FedFashionMNISTDataset,
Dataset.mnist: FedMNISTDataset
}


def get_fed_dataset(name: Dataset):
return available_fed_datasets()[name]
6 changes: 3 additions & 3 deletions fltk/datasets/federated/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
from torchvision import datasets
from torchvision import transforms

from fltk.datasets.distributed import DistDataset
from fltk.datasets.federated.dataset import FedDataset
from fltk.samplers import get_sampler
from fltk.util.config import FedLearningConfig


class DistCIFAR10Dataset(DistDataset):
class FedCIFAR10Dataset(FedDataset):
"""
CIFAR10 Dataset implementation for Distributed learning experiments.
"""

def __init__(self, args: FedLearningConfig):
super(DistCIFAR10Dataset, self).__init__(args)
super(FedCIFAR10Dataset, self).__init__(args)
self.init_train_dataset()
self.init_test_dataset()

Expand Down
6 changes: 3 additions & 3 deletions fltk/datasets/federated/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from fltk.datasets.distributed.dataset import DistDataset
from fltk.datasets.federated.dataset import FedDataset
from fltk.samplers import get_sampler


class DistCIFAR100Dataset(DistDataset):
class FedCIFAR100Dataset(FedDataset):
"""
CIFAR100 Dataset implementation for Distributed learning experiments.
"""

def __init__(self, args):
super(DistCIFAR100Dataset, self).__init__(args)
super(FedCIFAR100Dataset, self).__init__(args)
self.init_train_dataset()
self.init_test_dataset()

Expand Down
2 changes: 1 addition & 1 deletion fltk/datasets/federated/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fltk.util.log import getLogger


class DistDataset:
class FedDataset:
train_sampler = None
test_sampler = None
train_dataset = None
Expand Down
6 changes: 3 additions & 3 deletions fltk/datasets/federated/fashion_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from torchvision import datasets
from torchvision import transforms

from fltk.datasets.distributed.dataset import DistDataset
from fltk.datasets.federated.dataset import FedDataset
from fltk.samplers import get_sampler


class DistFashionMNISTDataset(DistDataset):
class FedFashionMNISTDataset(FedDataset):

def __init__(self, args):
super(DistFashionMNISTDataset, self).__init__(args)
super(FedFashionMNISTDataset, self).__init__(args)
self.init_train_dataset()
self.init_test_dataset()

Expand Down
6 changes: 3 additions & 3 deletions fltk/datasets/federated/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from fltk.datasets.distributed.dataset import DistDataset
from fltk.datasets.federated.dataset import FedDataset
from fltk.samplers import get_sampler

if TYPE_CHECKING:
pass

class DistMNISTDataset(DistDataset):
class FedMNISTDataset(FedDataset):

def __init__(self, args):
super(DistMNISTDataset, self).__init__(args)
super(FedMNISTDataset, self).__init__(args)
self.init_train_dataset()
self.init_test_dataset()

Expand Down
2 changes: 1 addition & 1 deletion fltk/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torchvision import datasets
from torchvision import transforms

from .dataset import Dataset
from fltk.datasets.dataset import Dataset


class MNIST(Dataset):
Expand Down

0 comments on commit 45d1017

Please sign in to comment.