Functional programming approach to dataloaders
Rather than mutating a central dataloader over time, create new
dataloaders on demand (while storing internal immutable state like
datasets). This reduces the potential for bugs, and it also makes the
create_train_dataloader method fully capable even when called as a
classmethod.
mrcslws committed Nov 16, 2020
1 parent 9c3df2b commit 74d8c76
Showing 14 changed files with 314 additions and 159 deletions.
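
For illustration, a minimal self-contained sketch of the pattern this commit moves toward, using a toy TensorDataset rather than the project's classes (all names below are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Immutable state that is created once and cached.
dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))

def create_train_dataloader(epoch):
    # Build a fresh, short-lived dataloader from the cached dataset.
    # The epoch argument lets callers vary sampling per epoch if needed.
    return DataLoader(dataset, batch_size=4, shuffle=True)

for epoch in range(2):
    loader = create_train_dataloader(epoch)  # created on demand...
    for x, y in loader:
        pass                                 # ...used for one epoch...
    # ...and then discarded; nothing is mutated between epochs.
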
53 changes: 53 additions & 0 deletions nupic/research/frameworks/pytorch/dataset_utils/common.py
@@ -48,6 +48,7 @@
"UnionDataset",
"split_dataset",
"PreprocessedDataset",
"FunctionalPreprocessedDataset",
"CachedDatasetFolder",
"ProgressiveRandomResizedCrop",
"HDF5Dataset",
@@ -253,6 +254,58 @@ def load_qualifier(self, qualifier):
return file_name


class FunctionalPreprocessedDataset(Dataset):
def __init__(self, cachefilepath, basename, qualifiers):
"""
Like the PreprocessedDataset, but designed to be immutable, using a more
functional programming approach. Rather than calling load_next() and
modifying the dataset object, call get_variant() to get a new dataset
object.
This class only directly provides limited Dataset functionality (it
provides the dataset length). It is not directly compatible with
dataloaders. Call get_variant() to get a dataset that is ready for
dataloaders.
:param cachefilepath: String for the directory containing pre-processed
data.
:param basename: Base file name from which to construct actual file
names. Actual file name will be "basename{}.npz".format(i) where i
cycles through the list of qualifiers.
:param qualifiers: List of qualifiers for each preprocessed file in
this dataset.
"""
self.path = cachefilepath
self.basename = basename
self.qualifiers = qualifiers

# Compute the length once so that __len__ is fast.
self.length = len(self.get_variant(0))

def __getitem__(self, index):
raise TypeError(
"The PreprocessedDataset doesn't support enumeration. "
"Must first use get_variant."
)

def __len__(self):
return self.length

def num_variants(self):
return len(self.qualifiers)

def get_variant(self, variant):
qualifier = self.qualifiers[variant]
file_name = os.path.join(self.path,
self.basename + "{}.npz".format(qualifier))
x, y = np.load(file_name).values()
return torch.utils.data.TensorDataset(
torch.tensor(x), torch.tensor(y)
)
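
A usage sketch for the class above (the cache path, base name, and qualifier count are assumptions; only get_variant, num_variants, and __len__ come from the class itself):

import torch

# Hypothetical cache layout: cache/gsc_train0.npz ... cache/gsc_train29.npz
dataset = FunctionalPreprocessedDataset(
    cachefilepath="cache",
    basename="gsc_train",
    qualifiers=range(30),
)
print(len(dataset))  # length is precomputed once in __init__

epoch = 0  # pick a different variant each epoch to rotate augmentations
variant = dataset.get_variant(epoch % dataset.num_variants())  # a TensorDataset
loader = torch.utils.data.DataLoader(variant, batch_size=64, shuffle=True)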


class CachedDatasetFolder(DatasetFolder):
"""A cached version of `torchvision.datasets.DatasetFolder` where the
classes and image list are static and cached skiping the costly `os.walk`
30 changes: 16 additions & 14 deletions nupic/research/frameworks/pytorch/datasets/gsc_factory.py
@@ -22,10 +22,14 @@
import tarfile
import urllib

import numpy as np
import torch.utils.data
from filelock import FileLock
from tqdm import tqdm

from nupic.research.frameworks.pytorch.dataset_utils import PreprocessedDataset
from nupic.research.frameworks.pytorch.dataset_utils import (
FunctionalPreprocessedDataset,
)

DATA_URL = "http://public.numenta.com/datasets/google_speech_commands/gsc_preprocessed_v0.01.tar.gz" # noqa: E501

@@ -41,9 +45,8 @@ def preprocessed_gsc(root, train=True, download=True):
Create train or validation dataset from preprocessed GSC data, downloading if
necessary.
Warning: Be sure to call dataset.load_next() following each epoch of training.
Otherwise, no new augmentations will be loaded, and the same exact samples
will be reused.
Warning: During training, be sure to call dataset.get_variant(), and call it
with different indices to train with different augmentations.
:param root: directory to store or load downloaded data
:param train: whether to load train or validation data
@@ -55,17 +58,16 @@
download_gsc_data(root)

if train:
basename = "gsc_train"
qualifiers = range(30)
dataset = FunctionalPreprocessedDataset(
cachefilepath=root,
basename="gsc_train",
qualifiers=range(30),
)
else:
basename = "gsc_valid"
qualifiers = [""]

dataset = PreprocessedDataset(
cachefilepath=root,
basename=basename,
qualifiers=qualifiers,
)
x, y = np.load(os.path.join(root, "gsc_valid.npz")).values()
dataset = torch.utils.data.TensorDataset(
torch.tensor(x), torch.tensor(y)
)

return dataset
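
A sketch of how the two return types might be consumed (paths and batch sizes are illustrative); the train dataset must be resolved to a variant before building a loader, while the validation dataset is a plain TensorDataset:

import torch

train_dataset = preprocessed_gsc("data/gsc", train=True)   # FunctionalPreprocessedDataset
val_dataset = preprocessed_gsc("data/gsc", train=False)    # TensorDataset

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64)

for epoch in range(3):
    # Rotate through the pre-augmented training variants, one per epoch.
    variant = train_dataset.get_variant(epoch % train_dataset.num_variants())
    train_loader = torch.utils.data.DataLoader(variant, batch_size=64, shuffle=True)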

@@ -50,21 +50,16 @@ def setup_experiment(self, config):
else:
self.model = DataParallel(self.model)

def prepare_loaders_for_epoch(self, epoch):
super().prepare_loaders_for_epoch(epoch)
if self.distributed:
self.train_loader.sampler.set_epoch(epoch)

@classmethod
def create_train_sampler(cls, config, dataset):
def create_train_sampler(cls, config, epoch, dataset, task_indices):
if config.get("distributed", False):
task_indices = cls.compute_task_indices(config, dataset)
return TaskDistributedSampler(
dataset,
task_indices
)
sampler = TaskDistributedSampler(dataset, task_indices)
sampler.set_epoch(epoch)
sampler.set_active_tasks(epoch // config["epochs"])
return sampler
else:
return super().create_train_sampler(config, dataset)
return super().create_train_sampler(config, epoch, dataset,
task_indices)

@classmethod
def get_execution_order(cls):
@@ -73,11 +68,9 @@ def get_execution_order(cls):

# Extended methods
eo["setup_experiment"].append(exp + ": DistributedDataParallel")
eo["prepare_loaders_for_epoch"].append(
exp + ": Update distributed sampler")
eo["create_train_sampler"].insert(0,
("If distributed { "
"create distribited sampler "
"create distributed sampler "
"} else {"))
eo["create_train_sampler"].append("}")
# FIXME: Validation is not currently distributed. Implement
@@ -53,15 +53,11 @@ def setup_experiment(self, config):
else:
self.model = DataParallel(self.model)

def prepare_loaders_for_epoch(self, epoch):
super().prepare_loaders_for_epoch(epoch)
if self.distributed:
self.train_loader.sampler.set_epoch(epoch)

@classmethod
def create_train_sampler(cls, config, dataset):
def create_train_sampler(cls, config, dataset, epoch):
if config.get("distributed", False):
sampler = DistributedSampler(dataset)
sampler.set_epoch(epoch)
else:
sampler = None
return sampler
@@ -98,8 +94,6 @@ def get_execution_order(cls):

# Extended methods
eo["setup_experiment"].append(exp + ": DistributedDataParallel")
eo["prepare_loaders_for_epoch"].append(
exp + ": Update distributed sampler")

eo.update(
# Overwritten methods
93 changes: 80 additions & 13 deletions nupic/research/frameworks/vernon/experiments/cl_experiment.py
@@ -45,6 +45,12 @@ class ContinualLearningExperiment(
def setup_experiment(self, config):

super().setup_experiment(config)

self.train_task_indices = self.compute_task_indices(config,
self.train_dataset)
self.val_task_indices = self.compute_task_indices(config,
self.val_dataset)

# Override epochs_to_validate so that validation doesn't run within the inner loop over epochs
self.epochs_to_validate = []

@@ -111,20 +117,85 @@ def run_task(self):
self.current_task += 1
return ret

def prepare_loaders_for_epoch(self, epoch):
super().prepare_loaders_for_epoch(epoch)
task = epoch // self.epochs
self.train_loader.sampler.set_active_tasks(task)
def validate(self, tasks):
loader = self.fast_create_validation_loader(tasks)
return super().validate(loader=loader)

@classmethod
def create_train_sampler(cls, config, dataset):
def create_train_sampler(cls, config, dataset, task_indices, epoch):
current_task = epoch // config["epochs"]
sampler = TaskRandomSampler(task_indices)
sampler.set_active_tasks(current_task)
return sampler
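
A quick worked example of the epoch-to-task mapping used above (the numbers are purely illustrative):

# With config["epochs"] = 5 (epochs per task), global epochs map to tasks as
# 0-4 -> task 0, 5-9 -> task 1, 10-14 -> task 2, and so on.
config = {"epochs": 5}
assert [epoch // config["epochs"] for epoch in range(12)] == \
    [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2]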

@classmethod
def create_train_dataloader(cls, config, epoch=0):
"""
This classmethod makes it possible to create an experiment's train
dataloaders without instantiating the experiment.
:param config: experiment config
:param epoch: epoch number. subclasses may vary the loader by epoch.
:return: dataloader
"""
dataset = cls.load_dataset(config, train=True)
task_indices = cls.compute_task_indices(config, dataset)
return TaskRandomSampler(task_indices)
sampler = cls.create_train_sampler(config, dataset, task_indices, epoch)

return cls._create_train_dataloader(config, dataset, sampler, epoch)

def fast_create_train_loader(self, epoch):
"""
Like create_train_dataloader, but is an instance method that uses cached
dataset and task_indices objects. This enables using dataloaders in a
functional way, quickly creating them and discarding them, while reusing
the underlying dataset and task_indices objects (which are not mutated).
:param epoch: epoch number
:return: dataloader
"""
sampler = self.create_train_sampler(self.config, self.train_dataset,
self.train_task_indices, epoch)
return self._create_train_dataloader(self.config, self.train_dataset,
sampler, epoch)
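
A hedged sketch of how a run loop might consume this (the loop itself is hypothetical; only the method name comes from this diff):

# `self` is assumed to be a ContinualLearningExperiment instance.
for epoch in range(total_epochs):                  # total_epochs is hypothetical
    loader = self.fast_create_train_loader(epoch)  # cheap: reuses cached dataset
    for data, target in loader:
        pass                                       # one training step per batch
    # The loader goes out of scope here; the dataset and task_indices are untouched.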

@classmethod
def create_validation_sampler(cls, config, dataset):
def create_validation_sampler(cls, config, dataset, task_indices, tasks):
sampler = TaskRandomSampler(task_indices)
sampler.set_active_tasks(tasks)
return sampler

@classmethod
def create_validation_dataloader(cls, config, tasks=0):
"""
This classmethod makes it possible to create an experiment's validation
dataloaders without instantiating the experiment.
:param config: experiment config
:param tasks: task index or list of task indices to activate in the loader
:return: dataloader
"""
dataset = cls.load_dataset(config, train=False)
task_indices = cls.compute_task_indices(config, dataset)
return TaskRandomSampler(task_indices)
sampler = cls.create_validation_sampler(config, dataset, task_indices,
tasks)
return cls._create_validation_dataloader(config, dataset, sampler)

def fast_create_validation_loader(self, tasks):
"""
Like create_validation_dataloader, but is an instance method that uses
cached dataset and task_indices objects. This enables using dataloaders
in a functional way, quickly creating them and discarding them, while
reusing the underlying dataset and task_indices objects (which are not
mutated).
:param tasks: task index or list of task indices to activate in the loader
:return: dataloader
"""
sampler = self.create_validation_sampler(self.config, self.val_dataset,
self.val_task_indices, tasks)
return self._create_validation_dataloader(self.config, self.val_dataset,
sampler)
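
And a similar sketch for validation (the config contents and the task indices chosen below are assumptions):

# Without instantiating the experiment, assuming `config` is a full experiment config:
val_loader = ContinualLearningExperiment.create_validation_dataloader(config, tasks=range(2))

# On a live experiment, a short-lived loader for all tasks seen so far:
loader = self.fast_create_validation_loader(tasks=range(self.current_task + 1))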

@classmethod
def compute_task_indices(cls, config, dataset):
@@ -152,18 +223,14 @@ def get_execution_order(cls):

# Extended methods
eo["setup_experiment"].append(exp + ".setup_experiment")
eo["prepare_loaders_for_epoch"].append(exp + ": Set current task")
eo["validate"].insert(0, exp + ": Create loader for specified tasks")

eo.update(
# Overwritten methods
should_stop=[exp + ".should_stop"],
run_iteration=[exp + ": Call run_task"],
create_train_sampler=[exp + ".create_train_sampler"],
create_validation_sampler=[exp + ".create_validation_sampler"],
aggregate_results=[exp + ".aggregate_results"],
aggregate_pre_experiment_results=[
exp + ".aggregate_pre_experiment_results"
],

# New methods
run_task=[exp + ".run_task"],
@@ -31,38 +31,33 @@ def eval_current_task(self):
"""
Evaluates accuracy at current task only. Used for debugging.
"""
self.val_loader.sampler.set_active_tasks(self.current_task)
return self.validate()
return self.validate(self.current_task)

def eval_first_task(self):
"""
Evaluates accuracy at first task only. Used for debugging.
"""
self.val_loader.sampler.set_active_tasks(0)
return self.validate()
return self.validate(tasks=0)

def eval_all_visited_tasks(self):
"""
Evaluates all tasks seen so far jointly. Equivalent to average accuracy
"""
self.val_loader.sampler.set_active_tasks(range(0, self.current_task + 1))
return self.validate()
return self.validate(tasks=range(self.current_task + 1))

def eval_all_tasks(self):
"""
Evaluates all tasks, including visited and not visited tasks.
"""
self.val_loader.sampler.set_active_tasks(range(0, self.num_tasks))
return self.validate()
return self.validate(tasks=range(self.num_tasks))

def eval_individual_tasks(self):
"""
Most common scenario in continual learning.
Evaluates all tasks seen so far, and reports accuracy for each individually.
"""
task_results = {}
for task_id in range(0, self.current_task + 1):
self.val_loader.sampler.set_active_tasks(task_id)
for k, v in self.validate().items():
for task_id in range(self.current_task + 1):
for k, v in self.validate(tasks=task_id).items():
task_results[f"task{task_id}__{k}"] = v
return task_results