-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial github copilot based refactor * learning rate multiples * some better descriptions * plotting gets own file * forgotten return * linting * redo qq plots * some bugfixes * towards pytorch lightning * start introducing metrics * towards a working benchmark * working training * benchmark directory * adjust learning rate * remove wandb * batch size for learning rate * generalize sample cache * visualize covariance * bug in lightning? * factour out regression and simplify repr * oversampled MNIST_CNN7 * get closer to FOB benchmarks * copy implementation from FOB * computer modern font * remove title add x-label * remove mnistSimpleCNN folder * move plotting to its own space * remove predict ugliness * move appendix * add rational quadratic implementation * formatting * caches * fix comments * logging everything about mnist * fix bug * savefig * gitignore * learning rate output * norm lock feature * back to mnist * plot metrics * plot step behavior * plot_metrics * add logs * more metrics * more logs * metrics and covariance fit * more logs * multiple runs of covariance fit * add logs for cnn3 * logs and new plotting * 128 logs * more logging * fix logging * make linters happy * conservative rfd * tested conservatism * black formatter * logs * more logs * fix type annotations * legacy workarounds * add conservatism to optimizer * more logs * fix bug * more logs * fashion mnist logs * fashion mnist logs * fix bug * algo - perf addition * AlgoPerf * logs of conservatism for debugging * logs of conservatism * more plotting machinery * better covariance tests * identify conservatism problem * fix bug * repo cleanup * fix benchmarking module structure * small refactor --------- Co-authored-by: Simon Forbat <simon.forbat@uni-mannheim.de> Co-authored-by: Simon Forbat <54510453+simon-forb@users.noreply.github.com>
- Loading branch information
1 parent
243e740
commit cefa998
Showing
85 changed files
with
11,299 additions
and
1,018 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,5 @@ | ||
[flake8] | ||
max-line-length = 120 | ||
max-line-length = 120 | ||
|
||
# E266: too many leading '#' for block comment | ||
extend-ignore = E266 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
{ | ||
"python.testing.pytestArgs": [ | ||
"." | ||
"tests" | ||
], | ||
"python.testing.unittestEnabled": false, | ||
"python.testing.pytestEnabled": true | ||
"python.testing.pytestEnabled": true, | ||
"python.defaultInterpreterPath": ".venv/bin/python" | ||
} |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
""" Classification examples. """ | ||
|
||
from .classifier import Classifier | ||
from . import mnist | ||
|
||
__all__ = ["Classifier", "mnist"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
""" Benchmarking RFD """ | ||
|
||
import torch | ||
from torch import optim, nn | ||
import torch.nn.functional as F | ||
import lightning as L | ||
import sys | ||
|
||
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger | ||
|
||
# pylint: disable=import-error,wrong-import-order | ||
from benchmarking.classification.cifar.data import CIFAR100 | ||
from benchmarking.classification.cifar.resnet18 import Resnet18 | ||
from benchmarking.classification.mnist.data import MNIST, FashionMNIST | ||
from benchmarking.classification.mnist.models import CNN3, CNN5, CNN7, AlgoPerf # pylint: disable=unused-import | ||
from benchmarking.classification.classifier import Classifier | ||
|
||
from pyrfd import RFD, covariance | ||
|
||
|
||
def train(problem, opt: "type[optim.Optimizer]", hyperparameters) -> None:
    """Train and then test a Classifier on one benchmark problem.

    Args:
        problem: Problem description. The keys "dataset", "model", "loss" and
            "batch_size" are read here; `trainer_from_problem` reads the rest.
        opt: Optimizer *class* (not an instance), e.g. `optim.Adam`. It is
            handed to `Classifier` and its `__name__` labels the run.
        hyperparameters: Keyword arguments forwarded to `Classifier` next to
            the optimizer; they are also used to label the run.
    """
    datamodule: L.LightningDataModule = problem["dataset"](batch_size=problem["batch_size"])

    # Build the trainer first: `trainer_from_problem` seeds the random number
    # generators, and the model must be instantiated *after* that so that its
    # weight initialisation is governed by the run's seed as well.
    trainer = trainer_from_problem(problem, opt_name=opt.__name__, hyperparameters=hyperparameters)

    classifier = Classifier(
        problem["model"](),
        optimizer=opt,
        loss=problem["loss"],
        **hyperparameters,
    )

    trainer.fit(classifier, datamodule)
    trainer.test(classifier, datamodule)
|
||
|
||
def trainer_from_problem(problem, opt_name, hyperparameters):
    """Seed the run and build a Lightning trainer with run-specific loggers.

    The log name has the form
    ``<dataset>_<model>_b=<batch_size>/<opt_name>(<key=value>_...)/seed=<seed>``.

    Args:
        problem: Problem description; reads "dataset", "model", "batch_size",
            "seed" and "trainer_params".
        opt_name: Optimizer name used in the log path.
        hyperparameters: Mapping whose ``key=value`` pairs are joined with
            underscores and embedded in the log path.

    Returns:
        An `L.Trainer` configured with `problem["trainer_params"]` that logs to
        both TensorBoard ("logs/TensorBoard") and CSV ("logs").
    """
    dataset_name = problem["dataset"].__name__
    model_name = problem["model"].__name__
    problem_id = f"{dataset_name}_{model_name}_b={problem['batch_size']}"
    hparams = "_".join(f"{key}={value}" for key, value in hyperparameters.items())
    run_name = f"{problem_id}/{opt_name}({hparams})/seed={problem['seed']}"

    loggers = [
        TensorBoardLogger("logs/TensorBoard", name=run_name),
        CSVLogger("logs", name=run_name),
    ]

    # Seed after the loggers exist, right before the trainer is created.
    L.seed_everything(problem["seed"], workers=True)

    return L.Trainer(**problem["trainer_params"], logger=loggers)
|
||
|
||
def _make_problem(dataset, model, loss, batch_size, max_epochs=30, log_every_n_steps=1):
    """Return a fresh problem description dict with the shared defaults filled in."""
    return {
        "dataset": dataset,
        "model": model,
        "loss": loss,
        "batch_size": batch_size,
        "seed": 42,
        "tol": 0.3,
        "trainer_params": {
            "max_epochs": max_epochs,
            "log_every_n_steps": log_every_n_steps,
        },
    }


# Benchmark problems selectable by name (first command line argument).
PROBLEMS = {
    "FashionMNIST_CNN5": _make_problem(FashionMNIST, CNN5, F.nll_loss, batch_size=128),
    "MNIST_AlgoPerf": _make_problem(MNIST, AlgoPerf, F.nll_loss, batch_size=128),
    "MNIST_CNN3": _make_problem(MNIST, CNN3, F.nll_loss, batch_size=128),
    "MNIST_CNN7": _make_problem(MNIST, CNN7, F.nll_loss, batch_size=1024),
    "CIFAR100_resnet18": _make_problem(
        CIFAR100,
        Resnet18,
        nn.CrossEntropyLoss(label_smoothing=0),
        batch_size=1024,
        max_epochs=50,
        log_every_n_steps=5,
    ),
}
|
||
def _fit_covariance(cov_model, problem, data, cache):
    """Fit `cov_model` for `problem` on the training split of `data` and return it."""
    cov_model.auto_fit(
        model_factory=problem["model"],
        loss=problem["loss"],
        data=data.data_train,
        tol=problem["tol"],
        cache=cache,
    )
    return cov_model


def main(problem_name, opt):
    """Run one benchmark mode on one of the problems in `PROBLEMS`.

    Args:
        problem_name: Key into `PROBLEMS`.
        opt: Mode to run, one of
            "cov"    - fit the squared exponential covariance 20 times, each
                       run with its own cache file,
            "RFD-SE" - RFD with a squared exponential covariance model,
            "RFD-RQ" - RFD with a rational quadratic covariance model,
            "Adam"   - Adam over a grid of learning rates,
            "SGD"    - SGD over a grid of learning rates.
            All training modes repeat over the seeds 0..19.

    Raises:
        KeyError: If `problem_name` is not a key of `PROBLEMS`.
        ValueError: If `opt` is not one of the modes listed above.
    """
    # Shallow copy: the seed is overwritten per run below and that must not
    # leak into the shared module-level PROBLEMS entry.
    problem = dict(PROBLEMS[problem_name])

    known_modes = ("cov", "RFD-SE", "RFD-RQ", "Adam", "SGD")
    if opt not in known_modes:
        # Fail before any data is downloaded instead of silently doing nothing.
        raise ValueError(f"unknown optimizer {opt!r}, expected one of {known_modes}")

    num_runs = 20
    cache_dir = f"cache/{problem['dataset'].__name__}/{problem['model'].__name__}"

    # data used to fit the covariance model
    data: L.LightningDataModule = problem["dataset"](batch_size=problem["batch_size"])
    data.prepare_data()
    data.setup("fit")

    if opt == "cov":
        for run in range(num_runs):
            _fit_covariance(
                covariance.SquaredExponential(),
                problem,
                data,
                cache=f"{cache_dir}_run={run}/covariance_cache.csv",
            )

    # Deliberately set only after the "cov" fits, as in the original ordering.
    torch.set_float32_matmul_precision("high")

    if opt == "RFD-SE":
        sq_exp_cov_model = _fit_covariance(
            covariance.SquaredExponential(),
            problem,
            data,
            cache=f"{cache_dir}/covariance_cache.csv",
        )
        for seed in range(num_runs):
            problem["seed"] = seed
            train(
                problem,
                opt=RFD,
                hyperparameters={
                    "covariance_model": sq_exp_cov_model,
                    "conservatism": 0.1,
                },
            )
            train(
                problem,
                opt=RFD,
                hyperparameters={
                    "covariance_model": sq_exp_cov_model,
                    "b_size_inv": 1 / problem["batch_size"],
                    "conservatism": 0.1,
                },
            )

    elif opt == "RFD-RQ":
        rat_quad_cov_model = _fit_covariance(
            covariance.RationalQuadratic(beta=1),
            problem,
            data,
            cache=f"{cache_dir}/covariance_cache.csv",
        )
        for seed in range(num_runs):
            problem["seed"] = seed
            train(
                problem,
                opt=RFD,
                hyperparameters={
                    "covariance_model": rat_quad_cov_model,
                },
            )

    elif opt == "Adam":
        for seed in range(num_runs):
            problem["seed"] = seed
            for lr in [1e-1, 1e-2, 1e-3, 1e-4]:
                train(
                    problem,
                    opt=optim.Adam,
                    hyperparameters={
                        "lr": lr,
                        "betas": (0.9, 0.999),
                    },
                )

    elif opt == "SGD":
        for seed in range(num_runs):
            problem["seed"] = seed
            for lr in [1e1, 1e0, 1e-1, 1e-2]:
                train(
                    problem,
                    opt=optim.SGD,
                    hyperparameters={
                        "lr": lr,
                    },
                )
|
||
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Both positional arguments are required; print a usage hint instead of
    # failing with a bare IndexError when one is missing.
    if len(sys.argv) < 3:
        sys.exit(
            f"usage: {sys.argv[0]} PROBLEM OPTIMIZER\n"
            f"  PROBLEM:   one of {', '.join(PROBLEMS)}\n"
            "  OPTIMIZER: one of cov, RFD-SE, RFD-RQ, Adam, SGD"
        )
    main(sys.argv[1], sys.argv[2])
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import torch | ||
|
||
from torch.utils.data import DataLoader | ||
from torchvision.datasets import CIFAR100 as CIFAR100Dataset | ||
from torchvision.transforms import v2 | ||
import lightning as L | ||
|
||
|
||
class CIFAR100(L.LightningDataModule):
    """Lightning data module around torchvision's CIFAR100 dataset."""

    def __init__(self, data_dir: str = "./data/CIFAR", batch_size: int = 100) -> None:
        """Store the data location and batch size and build the transform pipelines.

        Args:
            data_dir: Directory the dataset is downloaded to and read from.
            batch_size: Batch size used by every dataloader of this module.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        # cifar 100 has 60000 32x32 color images (600 images per class)
        # NOTE(review): these constants look like the widely quoted CIFAR-10
        # channel statistics rather than CIFAR-100 ones - confirm this is intended.
        mean = (0.4914, 0.4822, 0.4465)
        stddev = (0.2023, 0.1994, 0.2010)

        def to_normalized_tensor():
            # Fresh transform instances for every pipeline.
            return [
                v2.ToDtype(torch.float, scale=True),
                v2.Normalize(mean, stddev),
                v2.ToPureTensor(),
            ]

        # Augmentations applied to training images only.
        augmentations = [
            v2.RandomCrop(size=32, padding=4, padding_mode="reflect"),
            v2.RandomHorizontalFlip(0.5),
            v2.TrivialAugmentWide(interpolation=v2.InterpolationMode.BILINEAR),
        ]

        self.train_transforms = v2.Compose([v2.ToImage(), *augmentations, *to_normalized_tensor()])
        self.val_transforms = v2.Compose([v2.ToImage(), *to_normalized_tensor()])

    def prepare_data(self) -> None:
        """Download the train and test splits into `data_dir`."""
        for is_train in (True, False):
            CIFAR100Dataset(self.data_dir, train=is_train, download=True)

    # pylint: disable=attribute-defined-outside-init
    def setup(self, stage: str):
        """Create the datasets needed for `stage`.

        "fit" sets `data_train` and `data_val`, "validate" sets `data_val`,
        "test" sets `data_test` and "predict" sets `data_predict`. Everything
        except `data_train` is built from the `train=False` split.
        """
        if stage == "fit":
            self.data_train = self._get_dataset(train=True)

        if stage in ("fit", "validate"):
            self.data_val = self._get_dataset(train=False)
        elif stage == "test":
            self.data_test = self._get_dataset(train=False)
        elif stage == "predict":
            self.data_predict = self._get_dataset(train=False)

    def _get_dataset(self, train: bool):
        """Return the train or test split with the matching transform pipeline."""
        transform = self.train_transforms if train else self.val_transforms
        return CIFAR100Dataset(str(self.data_dir), train=bool(train), transform=transform)

    def train_dataloader(self):
        """Return a shuffling dataloader over `data_train`."""
        return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return a dataloader over `data_val`."""
        return DataLoader(self.data_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a dataloader over `data_test`."""
        return DataLoader(self.data_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        """Return a dataloader over `data_predict`."""
        return DataLoader(self.data_predict, batch_size=self.batch_size)
Oops, something went wrong.