Commit 56dda0c
move plotting to its own space
FelixBenning committed May 13, 2024
1 parent fdd1a55 commit 56dda0c
Showing 7 changed files with 62 additions and 42 deletions.
79 changes: 44 additions & 35 deletions benchmarking/classification/__main__.py
@@ -2,74 +2,81 @@
 from typing import Dict, Any
 
-from torch import optim
+from torch import optim, nn
+import torch.nn.functional as F
 import lightning as L
 
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 
 # pylint: disable=import-error,wrong-import-order
+from benchmarking.classification.cifar.data import CIFAR100
+from benchmarking.classification.cifar.models import Resnet18
 from mnist.data import MNIST
 from mnist.models import CNN3, CNN5, CNN7  # pylint: disable=unused-import
 from classifier import Classifier
 
 from pyrfd import RFD, covariance
 
 
-def train_classic(
-    problem,
-    opt: optim.Optimizer,
-    hyperparameters: Dict[str, Any],
-):
-    """Train a Classifier with RFD"""
-
-    data: L.LightningDataModule = problem["dataset"](batch_size=problem["batch_size"])
-    classifier = Classifier(problem["model"](), optimizer=opt, **hyperparameters)
-
-    trainer = trainer_from_problem(problem, opt_name=opt.__name__, hyperparameters=hyperparameters)
-    trainer.fit(classifier, data)
-    trainer.test(classifier, data)
-
-
-def train_rfd(problem, hyperparameters):
-    """Train a Classifier with RFD"""
+def train(problem, opt: optim.Optimizer, hyperparameters):
+    """Train a Classifier"""
+
     data: L.LightningDataModule = problem["dataset"](batch_size=problem["batch_size"])
     data.prepare_data()
     data.setup("fit")
 
     classifier = Classifier(
         problem["model"](),
-        optimizer=RFD,
+        optimizer=opt,
+        loss=problem["loss"],
         **hyperparameters,
     )
 
-    trainer = trainer_from_problem(problem, opt_name="RFD", hyperparameters=hyperparameters)
+    trainer = trainer_from_problem(problem, opt_name=opt.__name__, hyperparameters=hyperparameters)
     trainer.fit(classifier, data)
     trainer.test(classifier, data)
 
 
 def trainer_from_problem(problem, opt_name, hyperparameters):
     problem_id = f"{problem['dataset'].__name__}_{problem['model'].__name__}_b={problem['batch_size']}"
 
     hparams = "_".join([f"{key}={value}" for key, value in hyperparameters.items()])
-    name = problem_id + "/" + opt_name + "(" + hparams + ")"
+    name = problem_id + "/" + opt_name + "(" + hparams + ")" + f"/seed={problem['seed']}"
 
     tlogger = TensorBoardLogger("logs/TensorBoard", name=name)
     csvlogger = CSVLogger("logs", name=name)
 
     L.seed_everything(problem["seed"], workers=True)
 
     return L.Trainer(
-        max_epochs=30,
-        log_every_n_steps=1,
+        **problem["trainer_params"],
         logger=[tlogger, csvlogger],
     )
 
 
-def main():
-    problem = {
+PROBLEMS = {
+    "MNIST_CNN7": {
         "dataset": MNIST,
         "model": CNN7,
-        "batch_size": 1000,
-    }
+        "loss": F.nll_loss,
+        "batch_size": 1024,
+        "seed": 42,
+        "trainer_params": {
+            "max_epochs": 30,
+            "log_every_n_steps": 1,
+        }
+    },
+    "CIFAR100_resnet18": {
+        "dataset": CIFAR100,
+        "model": Resnet18,
+        "loss": nn.CrossEntropyLoss(label_smoothing=0),
+        "batch_size": 1024,
+        "seed": 42,
+        "trainer_params": {
+            "max_epochs": 50,
+            "log_every_n_steps": 1,
+        }
+    },
+}
+
+
+def main():
+    problem = PROBLEMS["CIFAR100_resnet18"]
 
     # fit covariance model
     data: L.LightningDataModule = problem["dataset"](batch_size=problem["batch_size"])
@@ -84,22 +91,24 @@ def main():
     )
     # ------
 
-    train_rfd(
+    train(
         problem,
+        opt=RFD,
         hyperparameters={
             "covariance_model": covariance_model,
         }
     )
 
-    train_rfd(
+    train(
         problem,
+        opt=RFD,
         hyperparameters={
             "covariance_model": covariance_model,
             "b_size_inv": 1/problem["batch_size"],
         }
    )
 
-    train_classic(
+    train(
         problem,
         opt=optim.SGD,
         hyperparameters={
@@ -108,7 +117,7 @@ def main():
         },
     )
 
-    train_classic(
+    train(
         problem,
         opt=optim.Adam,
         hyperparameters={
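With this change every run is described by a PROBLEMS entry and every optimizer goes through the single train() entry point. As a rough sketch of the pattern (the "MNIST_CNN5" entry name and the lr/momentum values are hypothetical, not part of the commit; CNN5 is already imported in __main__.py), adding a third benchmark would look like:

# Hypothetical sketch: a third registry entry next to "MNIST_CNN7" and
# "CIFAR100_resnet18". The keys mirror the two entries the commit adds;
# assumes the imports at the top of __main__.py.
PROBLEMS["MNIST_CNN5"] = {
    "dataset": MNIST,
    "model": CNN5,  # imported (currently unused) in __main__.py
    "loss": F.nll_loss,
    "batch_size": 1024,
    "seed": 42,
    "trainer_params": {
        "max_epochs": 30,
        "log_every_n_steps": 1,
    },
}

# Any optimizer class then runs through the unified entry point:
train(
    PROBLEMS["MNIST_CNN5"],
    opt=optim.SGD,
    hyperparameters={"lr": 0.1, "momentum": 0.9},  # illustrative values only
)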
4 changes: 2 additions & 2 deletions benchmarking/classification/cifar/data.py
@@ -7,9 +7,9 @@
 
 
 class CIFAR100(L.LightningDataModule):
-    """Represents the CIFAR10 dataset."""
+    """Represents the CIFAR100 dataset."""
 
-    def __init__(self, data_dir: str = " ./data", batch_size: int = 100) -> None:
+    def __init__(self, data_dir: str = "./data/CIFAR", batch_size: int = 100) -> None:
         super().__init__()
         self.data_dir = data_dir
         self.batch_size = batch_size
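Besides correcting the docstring, this fixes the old default path " ./data", which contained a stray leading space, and gives CIFAR its own subdirectory. A minimal standalone sketch of the datamodule flow, mirroring what train() in __main__.py now does with it:

# Sketch, assuming the standard LightningDataModule flow used in __main__.py.
from benchmarking.classification.cifar.data import CIFAR100

data = CIFAR100(batch_size=1024)  # defaults to the new ./data/CIFAR directory
data.prepare_data()               # download the dataset if it is not cached
data.setup("fit")                 # prepare the fit stage before training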
2 changes: 1 addition & 1 deletion benchmarking/classification/cifar/models.py
@@ -6,7 +6,7 @@ class Resnet18(nn.Module):
     """ Resnet18 model for CIFAR-100 (modifications based on FOB benchmark) """
     def __init__(self):
         super().__init__()
-        self.model = resnet18(num_classes=100, pretrained=False)
+        self.model = resnet18(num_classes=100, weights=None)
         # 7x7 conv is too large for 32x32 images
         self.model.conv1 = nn.Conv2d(
             in_channels=3,  # rgb color
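pretrained=False is the legacy torchvision argument, deprecated since the multi-weight API arrived in torchvision 0.13; weights=None is the equivalent modern spelling and yields a randomly initialized network. For comparison:

from torchvision.models import resnet18

# Legacy API (deprecated since torchvision 0.13):
# model = resnet18(num_classes=100, pretrained=False)

# Multi-weight API: weights=None means random initialization,
# i.e. no checkpoint is downloaded.
model = resnet18(num_classes=100, weights=None)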
7 changes: 4 additions & 3 deletions benchmarking/classification/classifier.py
@@ -14,10 +14,11 @@
 class Classifier(L.LightningModule):
     """Abstract Classifier for training and testing a model on a classification dataset"""
 
-    def __init__(self, model, optimizer, **hyperameters):
+    def __init__(self, model, optimizer, loss=F.nll_loss, **hyperameters):
         super().__init__()
         self.optimizer = optimizer
         self.model = model
+        self.loss = loss
         self.hyperparemeters = hyperameters
         self.save_hyperparameters(ignore=["model", "optimizer" ])
 
@@ -27,7 +28,7 @@ def training_step(self, batch, *args, **kwargs):
         """Apply Negative Log-Likelihood loss to the model output and logging"""
         x_in, y_out = batch
         prediction: torch.Tensor = self.model(x_in)
-        loss_value = F.nll_loss(prediction, y_out)
+        loss_value = self.loss(prediction, y_out)
         self.log("train/loss", loss_value, prog_bar=True)
         acc = metrics.multiclass_accuracy(prediction, y_out, prediction.size(dim=1))
         self.log("train/accuracy", acc, on_epoch=True)
@@ -37,7 +38,7 @@ def validation_step(batch):
         """Apply Negative Log-Likelihood loss to the model output and logging"""
         x_in, y_out = batch
         prediction: torch.Tensor = self.model(x_in)
-        loss_value = F.nll_loss(prediction, y_out)
+        loss_value = self.loss(prediction, y_out)
         self.log("val/loss", loss_value, on_epoch=True)
 
         acc = metrics.multiclass_accuracy(prediction, y_out, prediction.size(dim=1))
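The new loss parameter is what lets __main__.py drive both benchmarks through one Classifier: it defaults to F.nll_loss, so existing MNIST runs behave as before, while the CIFAR problem passes nn.CrossEntropyLoss. A minimal sketch of the injection point (assumes the repo's Classifier, CNN7, and Resnet18; any callable of the form loss(prediction, target) fits):

from torch import nn, optim

# Default keeps the old behavior (negative log-likelihood):
mnist_clf = Classifier(CNN7(), optimizer=optim.Adam)

# CIFAR-style: swap in cross-entropy; extra keyword arguments are
# collected into **hyperameters and recorded by save_hyperparameters().
cifar_clf = Classifier(
    Resnet18(),
    optimizer=optim.SGD,
    loss=nn.CrossEntropyLoss(label_smoothing=0),
)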
10 changes: 10 additions & 0 deletions benchmarking/classification/mnist/data.py
@@ -27,6 +27,16 @@ def __init__(self, data_dir: str = "./data", batch_size: int = 32):
             ]
         )
 
+        # # Possible data augmentation
+        # self.train_transform = transforms.Compose(
+        #     [
+        #         transforms.RandomRotation(20),
+        #         transforms.RandomAffine(0, translate=(0.2, 0.2)),
+        #         self.transform,
+        #     ]
+        # )
+
+
     def prepare_data(self):
         # download
         MNISTDataset(self.data_dir, train=True, download=True)
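The augmentation is only sketched in a comment for now. Spelled out, it would rotate digits by up to 20 degrees and translate them by up to 20% of the image size in each direction before the base pipeline runs. A standalone sketch (the base transform below is an assumption standing in for the module's self.transform, using the conventional MNIST ToTensor + Normalize):

from torchvision import transforms

# Assumed stand-in for the module's existing self.transform.
base_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

train_transform = transforms.Compose(
    [
        transforms.RandomRotation(20),                     # rotate up to ±20°
        transforms.RandomAffine(0, translate=(0.2, 0.2)),  # shift up to 20% in x/y
        base_transform,                                    # then convert and normalize
    ]
)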
File renamed without changes.
2 changes: 1 addition & 1 deletion sbatch.sh
@@ -2,7 +2,7 @@
 
 #SBATCH --partition single
 #SBATCH --ntasks=1
-#SBATCH --time=00:30:00
+#SBATCH --time=01:00:00
 #SBATCH --gres=gpu:1
 #SBATCH --mem-per-cpu=20gb
 
