Skip to content

Commit

Permalink
Make intention fo distributed datasets more clear
Browse files Browse the repository at this point in the history
  • Loading branch information
JMGaljaard committed Sep 4, 2022
1 parent 45d1017 commit 949803b
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 22 deletions.
2 changes: 1 addition & 1 deletion fltk/datasets/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class CIFAR10Dataset(Dataset):
"""
CIFAR10 Dataset implementation for Federated learning experiments.
CIFAR10 Dataset implementation for Distributed learning experiments.
"""

def __init__(self, config, learning_param, rank: int = 0, world_size: int = None):
Expand Down
2 changes: 1 addition & 1 deletion fltk/datasets/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class CIFAR100Dataset(Dataset):
"""
CIFAR100 Dataset implementation for Federated learning experiments.
CIFAR100 Dataset implementation for Distributed learning experiments.
"""

DEFAULT_TRANSFORM = transforms.Compose([
Expand Down
44 changes: 24 additions & 20 deletions fltk/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import annotations

import abc
from abc import abstractmethod

import torch
Expand All @@ -11,7 +13,10 @@
from fltk.util.config import DistLearningConfig


class Dataset:
class Dataset(abc.ABC):
"""
Dataset implementation for Distributed learning experiments.
"""

def __init__(self, config, learning_params: DistLearningConfig, rank: int, world_size: int):
self.config = config
Expand Down Expand Up @@ -39,24 +44,6 @@ def get_test_dataset(self):
"""
return self.test_loader

@abstractmethod
def load_train_dataset(self):
"""
Loads & returns the training dataset.
:return: tuple
"""
raise NotImplementedError("load_train_dataset() isn't implemented")

@abstractmethod
def load_test_dataset(self):
"""
Loads & returns the test dataset.
:return: tuple
"""
raise NotImplementedError("load_test_dataset() isn't implemented")

def get_train_loader(self, **kwargs):
"""
Return the data loader for the train dataset.
Expand All @@ -77,7 +64,24 @@ def get_test_loader(self, **kwargs):
"""
return self.test_loader

@staticmethod
@abstractmethod
def load_train_dataset(self):
"""
Loads & returns the training dataset.
:return: tuple
"""
raise NotImplementedError("load_train_dataset() isn't implemented")

@abstractmethod
def load_test_dataset(self):
"""
Loads & returns the test dataset.
:return: tuple
"""
raise NotImplementedError("load_test_dataset() isn't implemented")

def get_data_loader_from_data(batch_size, X, Y, **kwargs):
"""
Get a data loader created from a given set of data.
Expand Down
3 changes: 3 additions & 0 deletions fltk/datasets/fashion_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@


class FashionMNISTDataset(Dataset):
"""
FashionMNIST Dataset implementation for Distributed learning experiments.
"""

def __init__(self, config, learning_param, rank: int = 0, world_size: int = None):
super(FashionMNISTDataset, self).__init__(config, learning_param, rank, world_size)
Expand Down
4 changes: 4 additions & 0 deletions fltk/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@


class MNIST(Dataset):
"""
MNIST Dataset implementation for Distributed learning experiments.
"""

DEFAULT_TRANSFORM = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
Expand Down

0 comments on commit 949803b

Please sign in to comment.