get closer to FOB benchmarks
FelixBenning committed May 12, 2024
1 parent 9344d3c commit d981bda
Showing 7 changed files with 124 additions and 122 deletions.
59 changes: 37 additions & 22 deletions benchmarking/classification/__main__.py
@@ -26,46 +26,39 @@ def train_classic(
     data: L.LightningDataModule = problem["dataset"](batch_size=problem["batch_size"])
     classifier = Classifier(problem["model"](), optimizer=opt, **hyperparameters)
 
-    trainer = trainer_from_problem(problem, opt_name=opt.__name__)
+    trainer = trainer_from_problem(problem, opt_name=opt.__name__, hyperparameters=hyperparameters)
     trainer.fit(classifier, data)
     trainer.test(classifier, data)
 
 
-def train_rfd(problem, cov_string, covariance_model):
+def train_rfd(problem, hyperparameters):
     """Train a Classifier with RFD"""
     data: L.LightningDataModule = problem["dataset"](batch_size=problem["batch_size"])
     data.prepare_data()
     data.setup("fit")
 
-    covariance_model.auto_fit(
-        model_factory=problem["model"],
-        loss=F.nll_loss,
-        data=data.train_data,
-        cache=f"""cache/{problem["dataset"].__name__}/{problem["model"].__name__}/covariance_cache.csv""",
-    )
-
     classifier = Classifier(
         problem["model"](),
         optimizer=RFD,
-        covariance_model=covariance_model,
-        b_size_inv=1/problem["batch_size"],
         **hyperparameters,
     )
 
-    trainer = trainer_from_problem(problem, opt_name=f"RFD({cov_string})")
+    trainer = trainer_from_problem(problem, opt_name="RFD", hyperparameters=hyperparameters)
     trainer.fit(classifier, data)
     trainer.test(classifier, data)
 
 
-def trainer_from_problem(problem, opt_name):
-    problem_id = f"{problem['dataset'].__name__}_{problem['model'].__name__}_{problem['batch_size']}"
+def trainer_from_problem(problem, opt_name, hyperparameters):
+    problem_id = f"{problem['dataset'].__name__}_{problem['model'].__name__}_b={problem['batch_size']}"
 
-    tlogger = TensorBoardLogger("logs/TensorBoard", name=problem_id + opt_name)
-    csvlogger = CSVLogger("logs", name=problem_id + opt_name)
+    hparams = "_".join([f"{key}={value}" for key, value in hyperparameters.items()])
+    name = problem_id + "/" + opt_name + "(" + hparams + ")"
+    tlogger = TensorBoardLogger("logs/TensorBoard", name=name)
+    csvlogger = CSVLogger("logs", name=name)
 
     tlogger.log_hyperparams(params={"batch_size": problem["batch_size"]})
     csvlogger.log_hyperparams(params={"batch_size": problem["batch_size"]})
     return L.Trainer(
-        max_epochs=5,
+        max_epochs=30,
         log_every_n_steps=1,
         logger=[tlogger, csvlogger],
     )
@@ -74,14 +74,36 @@ def trainer_from_problem(problem, opt_name):
 def main():
     problem = {
         "dataset": MNIST,
-        "model": CNN3,
-        "batch_size": 100,
+        "model": CNN7,
+        "batch_size": 1000,
     }
 
+    # fit covariance model
+    data: L.LightningDataModule = problem["dataset"](batch_size=problem["batch_size"])
+    data.prepare_data()
+    data.setup("fit")
+    covariance_model = covariance.SquaredExponential()
+    covariance_model.auto_fit(
+        model_factory=problem["model"],
+        loss=F.nll_loss,
+        data=data.data_train,
+        cache=f"""cache/{problem["dataset"].__name__}/{problem["model"].__name__}/covariance_cache.csv""",
+    )
+    # ------
+
+    train_rfd(
+        problem,
+        hyperparameters={
+            "covariance_model": covariance_model,
+        }
+    )
+
     train_rfd(
         problem,
-        cov_string="SquaredExponential",
-        covariance_model=covariance.SquaredExponential(),
+        hyperparameters={
+            "covariance_model": covariance_model,
+            "b_size_inv": 1/problem["batch_size"],
+        }
     )
 
     train_classic(
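
With hyperparameters now threaded into trainer_from_problem, every run gets a log directory that encodes both the problem and the optimizer settings. A minimal sketch of the naming scheme introduced above (the input values are illustrative, not taken from the commit):

    # Reproduces the naming logic of trainer_from_problem; inputs are made up.
    problem_id = "MNIST_CNN7_b=1000"
    hyperparameters = {"covariance_model": "SquaredExponential(...)"}

    hparams = "_".join(f"{key}={value}" for key, value in hyperparameters.items())
    name = problem_id + "/" + "RFD" + "(" + hparams + ")"
    print(name)  # MNIST_CNN7_b=1000/RFD(covariance_model=SquaredExponential(...))

Hoisting auto_fit out of train_rfd into main() also means the covariance model is fitted once and shared by both train_rfd calls, instead of being refitted for every run.
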
131 changes: 59 additions & 72 deletions benchmarking/classification/cifar/data.py
@@ -1,69 +1,50 @@
 """ CIFAR Data Modules """
 
-import lightning as L
 import torch
-from torch.utils.data import random_split, DataLoader
-
-from torchvision.datasets import CIFAR10 as CIFAR10Dataset, CIFAR100 as CIFAR100Dataset
-from torchvision import transforms
-
-class CIFAR10(L.LightningDataModule):
-    """Represents the CIFAR10 dataset.
-    """
-
-    def __init__(self, data_dir: str=" ./data", batch_size: int = 100) -> None:
-        super().__init__()
-        self.data_dir = data_dir
-        self.batch_size = batch_size
-        # self.transform = transforms.Compose()
-
-    def prepare_data(self) -> None:
-        # download
-        CIFAR10Dataset(self.data_dir, train=True, download=True)
-        CIFAR10Dataset(self.data_dir, train=False, download=True)
-
-    # pylint: disable=attribute-defined-outside-init
-    def setup(self, stage: str):
-        # Assign train/val datasets for use in dataloaders
-        if stage == "fit":
-            cifar10_full = CIFAR10Dataset(self.data_dir, train=True)
-            self.train_data, self.validation_data = random_split(
-                cifar10_full, [0.9, 0.1], generator=torch.Generator().manual_seed(42)
-            )
-
-        # Assign test dataset for use in dataloader(s)
-        if stage == "test":
-            self.test_data = CIFAR10Dataset(
-                self.data_dir, train=False
-            )
-
-        if stage == "predict":
-            self.prediction_data = CIFAR10Dataset(
-                self.data_dir, train=False
-            )
-
-    def train_dataloader(self):
-        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)
-
-    def val_dataloader(self):
-        return DataLoader(self.validation_data, batch_size=self.batch_size)
-
-    def test_dataloader(self):
-        return DataLoader(self.test_data, batch_size=self.batch_size)
-
-    def predict_dataloader(self):
-        return DataLoader(self.prediction_data, batch_size=self.batch_size)
+from torch.utils.data import DataLoader
+from torchvision.datasets import CIFAR100 as CIFAR100Dataset
+from torchvision.transforms import v2
+import lightning as L
 
 
 class CIFAR100(L.LightningDataModule):
-    """Represents the CIFAR10 dataset.
-    """
+    """Represents the CIFAR100 dataset."""
 
-    def __init__(self, data_dir: str=" ./data", batch_size: int = 100) -> None:
+    def __init__(self, data_dir: str = " ./data", batch_size: int = 100) -> None:
         super().__init__()
         self.data_dir = data_dir
         self.batch_size = batch_size
-        # self.transform = transforms.Compose()
+        # cifar 100 has 60000 32x32 color images (600 images per class)
+        cifar100_mean = (0.4914, 0.4822, 0.4465)
+        cifar100_stddev = (0.2023, 0.1994, 0.2010)
+        random_crop = v2.RandomCrop(
+            size=32,
+            padding=4,
+            padding_mode="reflect",
+        )
+        horizontal_flip = v2.RandomHorizontalFlip(0.5)
+        trivial_augment = v2.TrivialAugmentWide(
+            interpolation=v2.InterpolationMode.BILINEAR
+        )
+
+        self.train_transforms = v2.Compose(
+            [
+                v2.ToImage(),
+                random_crop,
+                horizontal_flip,
+                trivial_augment,
+                v2.ToDtype(torch.float, scale=True),
+                v2.Normalize(cifar100_mean, cifar100_stddev),
+                v2.ToPureTensor(),
+            ]
+        )
+        self.val_transforms = v2.Compose(
+            [
+                v2.ToImage(),
+                v2.ToDtype(torch.float, scale=True),
+                v2.Normalize(cifar100_mean, cifar100_stddev),
+                v2.ToPureTensor(),
+            ]
+        )
 
     def prepare_data(self) -> None:
         # download
@@ -72,32 +53,38 @@ def prepare_data(self) -> None:
 
     # pylint: disable=attribute-defined-outside-init
     def setup(self, stage: str):
-        # Assign train/val datasets for use in dataloaders
+        """setup is called from every process across all the nodes. Setting state here is recommended."""
         if stage == "fit":
-            cifar100_full = CIFAR100Dataset(self.data_dir, train=True)
-            self.train_data, self.validation_data = random_split(
-                cifar100_full, [0.9, 0.1], generator=torch.Generator().manual_seed(42)
-            )
+            self.data_train = self._get_dataset(train=True)
+            self.data_val = self._get_dataset(train=False)
+
+        if stage == "validate":
+            self.data_val = self._get_dataset(train=False)
 
-        # Assign test dataset for use in dataloader(s)
         if stage == "test":
-            self.test_data = CIFAR100Dataset(
-                self.data_dir, train=False
-            )
+            self.data_test = self._get_dataset(train=False)
 
         if stage == "predict":
-            self.prediction_data = CIFAR100Dataset(
-                self.data_dir, train=False
-            )
+            self.data_predict = self._get_dataset(train=False)
+
+    def _get_dataset(self, train: bool):
+        if train:
+            return CIFAR100Dataset(
+                str(self.data_dir), train=True, transform=self.train_transforms
+            )
+        else:
+            return CIFAR100Dataset(
+                str(self.data_dir), train=False, transform=self.val_transforms
+            )
 
     def train_dataloader(self):
-        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)
+        return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)
 
     def val_dataloader(self):
-        return DataLoader(self.validation_data, batch_size=self.batch_size)
+        return DataLoader(self.data_val, batch_size=self.batch_size)
 
     def test_dataloader(self):
-        return DataLoader(self.test_data, batch_size=self.batch_size)
+        return DataLoader(self.data_test, batch_size=self.batch_size)
 
     def predict_dataloader(self):
-        return DataLoader(self.prediction_data, batch_size=self.batch_size)
+        return DataLoader(self.data_predict, batch_size=self.batch_size)
2 changes: 1 addition & 1 deletion benchmarking/classification/classifier.py
@@ -19,7 +19,7 @@ def __init__(self, model, optimizer, **hyperameters):
         self.optimizer = optimizer
         self.model = model
         self.hyperparemeters = hyperameters
-        self.save_hyperparameters(ignore=["model", "optimizer", "covariance_model"])
+        self.save_hyperparameters(ignore=["model", "optimizer"])
 
 
     # pylint: disable=arguments-differ
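
Removing "covariance_model" from the ignore list means Lightning now records it in self.hparams, which is what lets the run name built in trainer_from_problem carry the covariance settings. A tiny sketch of the mechanism, using a hypothetical module rather than the repo's Classifier:

    import lightning as L

    class Demo(L.LightningModule):
        def __init__(self, model=None, lr=0.1, covariance_model="SquaredExponential(...)"):
            super().__init__()
            # Everything not listed in `ignore` is stored in self.hparams and logged.
            self.save_hyperparameters(ignore=["model"])

    print(Demo().hparams)  # contains lr and covariance_model, but not model
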
35 changes: 17 additions & 18 deletions benchmarking/classification/mnist/data.py
@@ -1,6 +1,5 @@
 import lightning as L
-import torch
-from torch.utils.data import random_split, DataLoader
+from torch.utils.data import DataLoader
 
 # Note - you must have torchvision installed for this example
 from torchvision.datasets import (
@@ -37,35 +36,35 @@ def prepare_data(self):
     def setup(self, stage: str):
         # Assign train/val datasets for use in dataloaders
         if stage == "fit":
-            self.train_data = MNISTDataset(
+            self.data_train = MNISTDataset(
                 self.data_dir, train=True, transform=self.transform
             )
-            self.validation_data = MNISTDataset(
+            self.data_val = MNISTDataset(
                 self.data_dir, train=False, transform=self.transform
             )
 
         # Assign test dataset for use in dataloader(s)
         if stage == "test":
-            self.test_data = MNISTDataset(
+            self.data_test = MNISTDataset(
                 self.data_dir, train=False, transform=self.transform
             )
 
         if stage == "predict":
-            self.prediction_data = MNISTDataset(
+            self.data_predict = MNISTDataset(
                 self.data_dir, train=False, transform=self.transform
             )
 
     def train_dataloader(self):
-        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)
+        return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)
 
     def val_dataloader(self):
-        return DataLoader(self.validation_data, batch_size=self.batch_size)
+        return DataLoader(self.data_val, batch_size=self.batch_size)
 
     def test_dataloader(self):
-        return DataLoader(self.test_data, batch_size=self.batch_size)
+        return DataLoader(self.data_test, batch_size=self.batch_size)
 
     def predict_dataloader(self):
-        return DataLoader(self.prediction_data, batch_size=self.batch_size)
+        return DataLoader(self.data_predict, batch_size=self.batch_size)
 
 
 class FashionMNIST(L.LightningDataModule):
@@ -95,32 +94,32 @@ def prepare_data(self):
     def setup(self, stage: str):
         # Assign train/val datasets for use in dataloaders
        if stage == "fit":
-            self.train_data = FashionMNISTDataset(
+            self.data_train = FashionMNISTDataset(
                 self.data_dir, train=True, transform=self.transform
             )
-            self.validation_data = FashionMNISTDataset(
+            self.data_val = FashionMNISTDataset(
                 self.data_dir, train=False, transform=self.transform
             )
 
         # Assign test dataset for use in dataloader(s)
         if stage == "test":
-            self.test_data = FashionMNISTDataset(
+            self.data_test = FashionMNISTDataset(
                 self.data_dir, train=False, transform=self.transform
             )
 
         if stage == "predict":
-            self.prediction_data = FashionMNISTDataset(
+            self.data_predict = FashionMNISTDataset(
                 self.data_dir, train=False, transform=self.transform
             )
 
     def train_dataloader(self):
-        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)
+        return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)
 
     def val_dataloader(self):
-        return DataLoader(self.validation_data, batch_size=self.batch_size)
+        return DataLoader(self.data_val, batch_size=self.batch_size)
 
     def test_dataloader(self):
-        return DataLoader(self.test_data, batch_size=self.batch_size)
+        return DataLoader(self.data_test, batch_size=self.batch_size)
 
     def predict_dataloader(self):
-        return DataLoader(self.prediction_data, batch_size=self.batch_size)
+        return DataLoader(self.data_predict, batch_size=self.batch_size)
11 changes: 6 additions & 5 deletions pyrfd/covariance.py
@@ -66,6 +66,7 @@ def __init__(
         if (self.mean is not None) and self.var_reg and self.g_var_reg and self.dims:
             self._fitted = True
 
+
     def __repr__(self) -> str:
         var = repr(None)
         if self.var_reg:
@@ -76,11 +77,11 @@ def __repr__(self) -> str:
             g_var = f"({self.g_var_reg.intercept}, {self.g_var_reg.slope})"
 
         return (
-            f"{self.__class__.__name__}(\n"
-            f" mean={self.mean},\n"
-            f" variance={var},\n"
-            f" gradient_var={g_var},\n"
-            f" dims={self.dims}\n"
+            f"{self.__class__.__name__}("
+            f"mean={self.mean}, "
+            f"variance={var}, "
+            f"gradient_var={g_var}, "
+            f"dims={self.dims}"
             ")"
         )
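
The one-line __repr__ matters for the logging change in __main__.py: trainer_from_problem builds the run name with f"{key}={value}", so the covariance model's repr lands verbatim in a directory name, where the old multi-line form would have embedded newlines. A tiny sketch of the interaction (Demo is a hypothetical stand-in for a fitted covariance model):

    class Demo:
        def __repr__(self):
            return "Demo(mean=0.3, variance=(0.1, 0.2), dims=42)"  # single line

    hyperparameters = {"covariance_model": Demo()}
    hparams = "_".join(f"{key}={value}" for key, value in hyperparameters.items())
    print(hparams)  # covariance_model=Demo(mean=0.3, variance=(0.1, 0.2), dims=42)
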
[diffs for the remaining 2 changed files were not loaded]
