Commit

working training
FelixBenning committed May 11, 2024
1 parent c4672f6 commit 71573e5
Showing 5 changed files with 228 additions and 20 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -165,4 +165,6 @@ cython_debug/
 
 # Caches
 cache/*
-lightning_logs/*
+logs/*
+lightning_logs/*
+slurm-*
60 changes: 44 additions & 16 deletions mnistSimpleCNN/train_adam.py → mnistSimpleCNN/mnist_training.py
@@ -1,15 +1,23 @@
""" Train MNIST """

from __future__ import annotations
from typing import Callable, Any

import functools as func

from lightning.pytorch.core.optimizer import LightningOptimizer
import torch
from torch import optim
from torch.utils.data import DataLoader
import torchvision as tv
import torch.nn.functional as F
import torchvision as tv
import lightning as L
import torchmetrics.functional.classification as metrics

from mnistSimpleCNN.models.modelM3 import ModelM3
from mnistSimpleCNN.models.modelM5 import ModelM5
from mnistSimpleCNN.models.modelM7 import ModelM7
from pyrfd import RFD, covariance


class Classifier(L.LightningModule):
@@ -18,7 +26,7 @@ def __init__(self, model, optimizer=optim.Adam):
         self.optimizer = optimizer
         self.model = model
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, *args, **kwargs):
         x_in, y_out = batch
         prediction: torch.Tensor = self.model(x_in)
         loss_value = F.nll_loss(prediction, y_out)
@@ -66,21 +74,19 @@ def optimizer_step(
         self.log(f"learning_rate_{idx}", learning_rate, on_step=True)
 
 
-def run():
-    trainer = L.Trainer(
-        max_epochs=2,
-        log_every_n_steps=1,
-    )
+def mnist_training():
+
+    train_dataset = tv.datasets.MNIST(
+        root="mnistSimpleCNN/data",
+        train=True,
+        transform=tv.transforms.ToTensor(),
+    )
     train_loader = DataLoader(
-        tv.datasets.MNIST(
-            root="mnistSimpleCNN/data",
-            train=True,
-            transform=tv.transforms.ToTensor(),
-        ),
+        train_dataset,
         batch_size=120,
         shuffle=True,
     )
 
     test_loader = DataLoader(
         tv.datasets.MNIST(
             root="mnistSimpleCNN/data",
@@ -91,10 +97,32 @@ def run():
         shuffle=False,
     )
 
-    model = Classifier(ModelM3())
-    trainer.fit(model=model, train_dataloaders=train_loader)
-    trainer.test(model=model, dataloaders=test_loader)
+    model: torch.nn.Module
+    for model in [ModelM3, ModelM5, ModelM7]:
+        cov_model = covariance.SquaredExponential()
+        cov_model.auto_fit(
+            model_factory=model,
+            loss=F.nll_loss,
+            data=train_dataset,
+            cache=f"logs/mnist/{model.__name__}/covariance_cache/nll.csv"
+        )
+
+        classifiers = {}
+        for (name, opt) in {
+            "RFD": func.partial(RFD, covariance_model=cov_model),
+            "Adam": optim.Adam,
+            "SGD": optim.SGD
+        }.items():
+            trainer = L.Trainer(
+                max_epochs=2,
+                log_every_n_steps=1,
+                default_root_dir=f"logs/mnist/{model.__name__}/{name}"
+            )
+            classifier = Classifier(model(), optimizer=opt)
+            classifiers[name] = classifier
+            trainer.fit(model=classifier, train_dataloaders=train_loader)
+            trainer.test(model=classifier, dataloaders=test_loader)
 
 
 if __name__ == "__main__":
-    run()
+    mnist_training()
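
Note: the collapsed regions of this diff hide the rest of Classifier, so its
configure_optimizers hook never appears above. For optim.Adam, optim.SGD and
func.partial(RFD, covariance_model=cov_model) to be interchangeable as the
optimizer argument, that hook presumably just applies the stored factory to the
model parameters. A minimal sketch under that assumption (hypothetical, not
part of this commit):

    import lightning as L
    from torch import optim


    class Classifier(L.LightningModule):
        def __init__(self, model, optimizer=optim.Adam):
            super().__init__()
            self.model = model
            # a class or functools.partial that builds an optimizer, not an instance
            self.optimizer = optimizer

        def configure_optimizers(self):
            # Any callable mapping a parameter iterable to a torch optimizer
            # works here; RFD is assumed to follow the torch.optim convention
            # of taking the parameters as its first positional argument.
            return self.optimizer(self.parameters())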
178 changes: 177 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -17,6 +17,7 @@ pandas = "^2.2.0"
 matplotlib = "^3.8.3"
 lightning = "^2.2.4"
 torchmetrics = "^1.4.0"
+tensorboard = "^2.16.2"
 
 
 [tool.poetry.group.dev.dependencies]
5 changes: 3 additions & 2 deletions sbatch.sh
@@ -2,7 +2,8 @@
 
 #SBATCH --partition single
 #SBATCH --ntasks=1
-#SBATCH --time=00:02:00
+#SBATCH --time=00:20:00
+#SBATCH --gres=gpu:1
 #SBATCH --mem-per-cpu=20gb
 
-poetry run python test.py
+poetry run python mnistSimpleCNN/mnist_training.py
