Remove assumption that all epochs have same number of batches #411

Closed · wants to merge 3 commits
60 changes: 60 additions & 0 deletions nupic/research/frameworks/pytorch/dataset_utils/common.py
@@ -48,6 +48,7 @@
"UnionDataset",
"split_dataset",
"PreprocessedDataset",
"FunctionalPreprocessedDataset",
"CachedDatasetFolder",
"ProgressiveRandomResizedCrop",
"HDF5Dataset",
@@ -253,6 +254,65 @@ def load_qualifier(self, qualifier):
return file_name


class FunctionalPreprocessedDataset(Dataset):
def __init__(self, cachefilepath, basename, qualifiers):
"""
Like the PreprocessedDataset, but designed to be immutable, using a more
functional programming approach. Rather than calling load_next() and
modifying the dataset object, call get_variant() to get a new dataset
object.

This class directly provides only limited Dataset functionality (it
provides the dataset length). It is not directly compatible with
dataloaders. Call get_variant() to get a dataset that is ready for
dataloaders.

:param cachefilepath: String for the directory containing pre-processed
data.

:param basename: Base file name from which to construct actual file
names. Actual file name will be "basename{}.npz".format(i) where i
cycles through the list of qualifiers.

:param qualifiers: List of qualifiers for each preprocessed file in
this dataset.
"""
self.path = cachefilepath
self.basename = basename
self.qualifiers = qualifiers

# Compute the length once so that __len__ is fast.
self.length = len(self.get_variant(0))

def __getitem__(self, index):
raise TypeError(
"The PreprocessedDataset doesn't support enumeration. "
"Must first use get_variant."
)

def __len__(self):
return self.length

def num_variants(self):
"""
:return: number of variant datasets
"""
return len(self.qualifiers)

def get_variant(self, variant):
"""
:param variant: index of the dataset variant
:return: variant dataset
"""
qualifier = self.qualifiers[variant]
file_name = os.path.join(self.path,
self.basename + "{}.npz".format(qualifier))
x, y = np.load(file_name).values()
return torch.utils.data.TensorDataset(
torch.tensor(x), torch.tensor(y)
)
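
For reference, here is a minimal usage sketch of the class added above; the cache directory, basename, and qualifier count are hypothetical placeholders, not values from this PR.

```python
import torch
from nupic.research.frameworks.pytorch.dataset_utils import FunctionalPreprocessedDataset

# Hypothetical cache layout: /data/cache/example0.npz ... /data/cache/example9.npz,
# each containing pre-augmented (x, y) arrays.
dataset = FunctionalPreprocessedDataset(
    cachefilepath="/data/cache",
    basename="example",
    qualifiers=range(10),
)

for epoch in range(3):
    # Each epoch trains on a different preprocessed variant of the data.
    variant = dataset.get_variant(epoch % dataset.num_variants())
    loader = torch.utils.data.DataLoader(variant, batch_size=32, shuffle=True)
    for x, y in loader:
        pass  # training step goes here
```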


class CachedDatasetFolder(DatasetFolder):
"""A cached version of `torchvision.datasets.DatasetFolder` where the
classes and image list are static and cached, skipping the costly `os.walk`
30 changes: 16 additions & 14 deletions nupic/research/frameworks/pytorch/datasets/gsc_factory.py
@@ -22,10 +22,14 @@
import tarfile
import urllib

import numpy as np
import torch.utils.data
from filelock import FileLock
from tqdm import tqdm

from nupic.research.frameworks.pytorch.dataset_utils import PreprocessedDataset
from nupic.research.frameworks.pytorch.dataset_utils import (
FunctionalPreprocessedDataset,
)

DATA_URL = "http://public.numenta.com/datasets/google_speech_commands/gsc_preprocessed_v0.01.tar.gz" # noqa: E501

@@ -41,9 +45,8 @@ def preprocessed_gsc(root, train=True, download=True):
Create train or validation dataset from preprocessed GSC data, downloading if
necessary.

Warning: Be sure to call dataset.load_next() following each epoch of training.
Otherwise, no new augmentations will be loaded, and the same exact samples
will be reused.
Warning: During training, be sure to call dataset.get_variant(), passing a
different index each time so that training uses different augmentations.

:param root: directory to store or load downloaded data
:param train: whether to load train or validation data
@@ -55,17 +58,16 @@
download_gsc_data(root)

if train:
basename = "gsc_train"
qualifiers = range(30)
dataset = FunctionalPreprocessedDataset(
cachefilepath=root,
basename="gsc_train",
qualifiers=range(30),
)
else:
basename = "gsc_valid"
qualifiers = [""]

dataset = PreprocessedDataset(
cachefilepath=root,
basename=basename,
qualifiers=qualifiers,
)
x, y = np.load(os.path.join(root, "gsc_valid.npz")).values()
dataset = torch.utils.data.TensorDataset(
torch.tensor(x), torch.tensor(y)
)

return dataset
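
A brief sketch of how the two return values might be consumed after this change, assuming the module path matches the file location; the root path, batch size, and epoch count below are illustrative only. The training dataset exposes variants, while the validation dataset is a plain TensorDataset that can be passed straight to a DataLoader.

```python
import torch
from nupic.research.frameworks.pytorch.datasets.gsc_factory import preprocessed_gsc

train_dataset = preprocessed_gsc("/data/gsc", train=True)   # FunctionalPreprocessedDataset
valid_dataset = preprocessed_gsc("/data/gsc", train=False)  # TensorDataset

valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=64)

for epoch in range(5):
    # Pick a different pre-augmented training variant each epoch; the number
    # of batches may differ from one variant to the next.
    variant = train_dataset.get_variant(epoch % train_dataset.num_variants())
    train_loader = torch.utils.data.DataLoader(variant, batch_size=64, shuffle=True)
    for x, y in train_loader:
        pass  # training step
```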

@@ -50,21 +50,16 @@ def setup_experiment(self, config):
else:
self.model = DataParallel(self.model)

def pre_epoch(self):
super().pre_epoch()
if self.distributed:
self.train_loader.sampler.set_epoch(self.current_epoch)

@classmethod
def create_train_sampler(cls, config, dataset):
def create_train_sampler(cls, config, epoch, dataset, task_indices):
if config.get("distributed", False):
task_indices = cls.compute_task_indices(config, dataset)
return TaskDistributedSampler(
dataset,
task_indices
)
sampler = TaskDistributedSampler(dataset, task_indices)
sampler.set_epoch(epoch)
sampler.set_active_tasks(epoch // config["epochs"])
return sampler
else:
return super().create_train_sampler(config, dataset)
return super().create_train_sampler(config, epoch, dataset,
task_indices)

@classmethod
def get_execution_order(cls):
@@ -73,10 +68,9 @@ def get_execution_order(cls):

# Extended methods
eo["setup_experiment"].append(exp + ": DistributedDataParallel")
eo["pre_epoch"].append(exp + ": Update distributed sampler")
eo["create_train_sampler"].insert(0,
("If distributed { "
"create distribited sampler "
"create distributed sampler "
"} else {"))
eo["create_train_sampler"].append("}")
# FIXME: Validation is not currently distributed. Implement
@@ -53,15 +53,11 @@ def setup_experiment(self, config):
else:
self.model = DataParallel(self.model)

def pre_epoch(self):
super().pre_epoch()
if self.distributed:
self.train_loader.sampler.set_epoch(self.current_epoch)

@classmethod
def create_train_sampler(cls, config, dataset):
def create_train_sampler(cls, config, dataset, epoch):
if config.get("distributed", False):
sampler = DistributedSampler(dataset)
sampler.set_epoch(epoch)
else:
sampler = None
return sampler
@@ -98,7 +94,6 @@ def get_execution_order(cls):

# Extended methods
eo["setup_experiment"].append(exp + ": DistributedDataParallel")
eo["pre_epoch"].append(exp + ": Update distributed sampler")

eo.update(
# Overwritten methods
45 changes: 0 additions & 45 deletions nupic/research/frameworks/vernon/experiment_utils.py
@@ -17,54 +17,9 @@
#
# http://numenta.org/licenses/
#
import copy
import socket
from contextlib import closing

from torch.optim.lr_scheduler import OneCycleLR

from nupic.research.frameworks.pytorch.lr_scheduler import ComposedLRScheduler


def create_lr_scheduler(optimizer, lr_scheduler_class, lr_scheduler_args,
steps_per_epoch):
"""
Configure learning rate scheduler

:param optimizer:
Wrapped optimizer
:param lr_scheduler_class:
LR scheduler class to use. Must inherit from _LRScheduler
:param lr_scheduler_args:
LR scheduler class constructor arguments
:param steps_per_epoch:
The total number of batches in the epoch.
Only used if lr_scheduler_class is :class:`ComposedLRScheduler` or
:class:`OneCycleLR`
"""
if issubclass(lr_scheduler_class, OneCycleLR):
# Update OneCycleLR parameters
lr_scheduler_args = copy.deepcopy(lr_scheduler_args)
lr_scheduler_args.update(steps_per_epoch=steps_per_epoch)
elif issubclass(lr_scheduler_class, ComposedLRScheduler):
# Update ComposedLRScheduler parameters
lr_scheduler_args = copy.deepcopy(lr_scheduler_args)
schedulers = lr_scheduler_args.get("schedulers", None)
if schedulers is not None:
# Convert dict from ray/json {str:dict} style to {int:dict}
schedulers = {int(k): v for k, v in schedulers.items()}

# Update OneCycleLR "steps_per_epoch" parameter
for _, item in schedulers.items():
lr_class = item.get("lr_scheduler_class", None)
if lr_class is not None and issubclass(lr_class, OneCycleLR):
lr_args = item.get("lr_scheduler_args", {})
lr_args.update(steps_per_epoch=steps_per_epoch)
lr_scheduler_args["schedulers"] = schedulers
lr_scheduler_args["steps_per_epoch"] = steps_per_epoch

return lr_scheduler_class(optimizer, **lr_scheduler_args)
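
The removed helper passed a single steps_per_epoch value into OneCycleLR, which bakes in the assumption this PR is removing. For callers that still want OneCycleLR with loaders whose lengths vary by epoch, one option is to size the schedule from the true total step count instead; a rough sketch with stand-in loaders, not code from this repository:

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, TensorDataset

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Stand-in loaders whose lengths differ from epoch to epoch.
loaders = [
    DataLoader(TensorDataset(torch.randn(n, 10),
                             torch.zeros(n, dtype=torch.long)),
               batch_size=16)
    for n in (96, 128, 64)
]

# Size OneCycleLR from the actual total number of steps rather than a
# single steps_per_epoch value.
total_steps = sum(len(loader) for loader in loaders)
scheduler = OneCycleLR(optimizer, max_lr=0.1, total_steps=total_steps)

for loader in loaders:
    for x, y in loader:
        loss = torch.nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
```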


def get_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
90 changes: 81 additions & 9 deletions nupic/research/frameworks/vernon/experiments/cl_experiment.py
@@ -45,6 +45,12 @@ class ContinualLearningExperiment(
def setup_experiment(self, config):

super().setup_experiment(config)

self.train_task_indices = self.compute_task_indices(config,
self.train_dataset)
self.val_task_indices = self.compute_task_indices(config,
self.val_dataset)

# Override epochs_to_validate so that validation does not run within the inner loop over epochs
self.epochs_to_validate = []

@@ -90,7 +96,6 @@ def run_task(self):
"""Run outer loop over tasks"""
# configure the sampler to load only samples from current task
self.logger.info("Training...")
self.train_loader.sampler.set_active_tasks(self.current_task)

# Run epochs, inner loop
# TODO: return the results from run_epoch
@@ -112,15 +117,85 @@ def run_task(self):
self.current_task += 1
return ret

def validate(self, tasks):
loader = self.fast_create_validation_loader(tasks)
return super().validate(loader=loader)

@classmethod
def create_train_sampler(cls, config, dataset):
def create_train_sampler(cls, config, dataset, task_indices, epoch):
current_task = epoch // config["epochs"]
sampler = TaskRandomSampler(task_indices)
sampler.set_active_tasks(current_task)
return sampler

@classmethod
def create_train_dataloader(cls, config, epoch=0):
"""
This classmethod makes it possible to create an experiment's train
dataloaders without instantiating the experiment.

:param config: experiment config
:param epoch: epoch number. Subclasses may vary the loader by epoch.
:return: dataloader
"""
dataset = cls.load_dataset(config, train=True)
task_indices = cls.compute_task_indices(config, dataset)
return TaskRandomSampler(task_indices)
sampler = cls.create_train_sampler(config, dataset, task_indices, epoch)

return cls._create_train_dataloader(config, dataset, sampler, epoch)

def fast_create_train_loader(self, epoch):
"""
Like create_train_dataloader, but is an instance method that uses cached
dataset and task_indices objects. This enables using dataloaders in a
functional way, quickly creating them and discarding them, while reusing
the underlying dataset and task_indices objects (which are not mutated).

:param epoch: epoch number
:return: dataloader
"""
sampler = self.create_train_sampler(self.config, self.train_dataset,
self.train_task_indices, epoch)
return self._create_train_dataloader(self.config, self.train_dataset,
sampler, epoch)
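
As the docstring notes, these loaders are meant to be created and discarded freely. A rough sketch of the calling pattern; the `experiment` argument stands for an already set up ContinualLearningExperiment and is hypothetical:

```python
def train_all_epochs(experiment, num_epochs):
    """Sketch only: experiment is an already set up ContinualLearningExperiment."""
    for epoch in range(num_epochs):
        # A fresh loader each epoch; its sampler is restricted to the task
        # implied by the epoch, and len(loader) may vary between epochs.
        loader = experiment.fast_create_train_loader(epoch)
        for x, y in loader:
            pass  # training step
```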

@classmethod
def create_validation_sampler(cls, config, dataset):
def create_validation_sampler(cls, config, dataset, task_indices, tasks):
sampler = TaskRandomSampler(task_indices)
sampler.set_active_tasks(tasks)
return sampler

@classmethod
def create_validation_dataloader(cls, config, tasks=0):
"""
This classmethod makes it possible to create an experiment's validation
dataloaders without instantiating the experiment.

:param config: experiment config
:param tasks: a task number or list of task numbers
:return: dataloader
"""
dataset = cls.load_dataset(config, train=False)
task_indices = cls.compute_task_indices(config, dataset)
return TaskRandomSampler(task_indices)
sampler = cls.create_validation_sampler(config, dataset, task_indices,
tasks)
return cls._create_validation_dataloader(config, dataset, sampler)

def fast_create_validation_loader(self, tasks):
"""
Like create_validation_dataloader, but is an instance method that uses
cached dataset and task_indices objects. This enables using dataloaders
in a functional way, quickly creating them and discarding them, while
reusing the underlying dataset and task_indices objects (which are not
mutated).

:param tasks: a task number or list of task numbers
:return: dataloader
"""
sampler = self.create_validation_sampler(self.config, self.val_dataset,
self.val_task_indices, tasks)
return self._create_validation_dataloader(self.config, self.val_dataset,
sampler)
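
Similarly, a short sketch of how validation over a set of tasks might be driven through validate(), which builds a throwaway loader via fast_create_validation_loader(); the helper below is hypothetical:

```python
def validate_seen_tasks(experiment):
    """Sketch only: run validation over every task seen so far."""
    seen_tasks = list(range(experiment.current_task + 1))
    # validate() restricts the throwaway validation loader to these tasks.
    return experiment.validate(tasks=seen_tasks)
```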

@classmethod
def compute_task_indices(cls, config, dataset):
@@ -148,17 +223,14 @@ def get_execution_order(cls):

# Extended methods
eo["setup_experiment"].append(exp + ".setup_experiment")
eo["validate"].insert(0, exp + ": Create loader for specified tasks")

eo.update(
# Overwritten methods
should_stop=[exp + ".should_stop"],
run_iteration=[exp + ": Call run_task"],
create_train_sampler=[exp + ".create_train_sampler"],
create_validation_sampler=[exp + ".create_validation_sampler"],
aggregate_results=[exp + ".aggregate_results"],
aggregate_pre_experiment_results=[
exp + ".aggregate_pre_experiment_results"
],

# New methods
run_task=[exp + ".run_task"],