From 949803b0c4ed985ba78df5d2887fa05da3c5232f Mon Sep 17 00:00:00 2001 From: JMGaljaard Date: Fri, 2 Sep 2022 12:56:17 +0200 Subject: [PATCH] Make intention fo distributed datasets more clear --- fltk/datasets/cifar10.py | 2 +- fltk/datasets/cifar100.py | 2 +- fltk/datasets/dataset.py | 44 ++++++++++++++++++---------------- fltk/datasets/fashion_mnist.py | 3 +++ fltk/datasets/mnist.py | 4 ++++ 5 files changed, 33 insertions(+), 22 deletions(-) diff --git a/fltk/datasets/cifar10.py b/fltk/datasets/cifar10.py index e028b426..46f84116 100644 --- a/fltk/datasets/cifar10.py +++ b/fltk/datasets/cifar10.py @@ -7,7 +7,7 @@ class CIFAR10Dataset(Dataset): """ - CIFAR10 Dataset implementation for Federated learning experiments. + CIFAR10 Dataset implementation for Distributed learning experiments. """ def __init__(self, config, learning_param, rank: int = 0, world_size: int = None): diff --git a/fltk/datasets/cifar100.py b/fltk/datasets/cifar100.py index b85aded8..f75eff2c 100644 --- a/fltk/datasets/cifar100.py +++ b/fltk/datasets/cifar100.py @@ -7,7 +7,7 @@ class CIFAR100Dataset(Dataset): """ - CIFAR100 Dataset implementation for Federated learning experiments. + CIFAR100 Dataset implementation for Distributed learning experiments. """ DEFAULT_TRANSFORM = transforms.Compose([ diff --git a/fltk/datasets/dataset.py b/fltk/datasets/dataset.py index c7ee29bb..8bb01d9c 100644 --- a/fltk/datasets/dataset.py +++ b/fltk/datasets/dataset.py @@ -1,4 +1,6 @@ from __future__ import annotations + +import abc from abc import abstractmethod import torch @@ -11,7 +13,10 @@ from fltk.util.config import DistLearningConfig -class Dataset: +class Dataset(abc.ABC): + """ + Dataset implementation for Distributed learning experiments. + """ def __init__(self, config, learning_params: DistLearningConfig, rank: int, world_size: int): self.config = config @@ -39,24 +44,6 @@ def get_test_dataset(self): """ return self.test_loader - @abstractmethod - def load_train_dataset(self): - """ - Loads & returns the training dataset. - - :return: tuple - """ - raise NotImplementedError("load_train_dataset() isn't implemented") - - @abstractmethod - def load_test_dataset(self): - """ - Loads & returns the test dataset. - - :return: tuple - """ - raise NotImplementedError("load_test_dataset() isn't implemented") - def get_train_loader(self, **kwargs): """ Return the data loader for the train dataset. @@ -77,7 +64,24 @@ def get_test_loader(self, **kwargs): """ return self.test_loader - @staticmethod + @abstractmethod + def load_train_dataset(self): + """ + Loads & returns the training dataset. + + :return: tuple + """ + raise NotImplementedError("load_train_dataset() isn't implemented") + + @abstractmethod + def load_test_dataset(self): + """ + Loads & returns the test dataset. + + :return: tuple + """ + raise NotImplementedError("load_test_dataset() isn't implemented") + def get_data_loader_from_data(batch_size, X, Y, **kwargs): """ Get a data loader created from a given set of data. diff --git a/fltk/datasets/fashion_mnist.py b/fltk/datasets/fashion_mnist.py index e03a93f6..8ad23cf6 100644 --- a/fltk/datasets/fashion_mnist.py +++ b/fltk/datasets/fashion_mnist.py @@ -6,6 +6,9 @@ class FashionMNISTDataset(Dataset): + """ + FashionMNIST Dataset implementation for Distributed learning experiments. + """ def __init__(self, config, learning_param, rank: int = 0, world_size: int = None): super(FashionMNISTDataset, self).__init__(config, learning_param, rank, world_size) diff --git a/fltk/datasets/mnist.py b/fltk/datasets/mnist.py index 2c1dd0bc..ab477149 100644 --- a/fltk/datasets/mnist.py +++ b/fltk/datasets/mnist.py @@ -7,6 +7,10 @@ class MNIST(Dataset): + """ + MNIST Dataset implementation for Distributed learning experiments. + """ + DEFAULT_TRANSFORM = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))