diff --git a/.gitignore b/.gitignore index ba8fa6b..6cdaa03 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,6 @@ cython_debug/ # Caches cache/* -lightning_logs/* \ No newline at end of file +logs/* +lightning_logs/* +slurm-* \ No newline at end of file diff --git a/mnistSimpleCNN/train_adam.py b/mnistSimpleCNN/mnist_training.py similarity index 66% rename from mnistSimpleCNN/train_adam.py rename to mnistSimpleCNN/mnist_training.py index ef956ed..d7cd7c8 100644 --- a/mnistSimpleCNN/train_adam.py +++ b/mnistSimpleCNN/mnist_training.py @@ -1,15 +1,23 @@ +""" Train MNIST """ + from __future__ import annotations from typing import Callable, Any + +import functools as func + from lightning.pytorch.core.optimizer import LightningOptimizer import torch from torch import optim from torch.utils.data import DataLoader -import torchvision as tv import torch.nn.functional as F +import torchvision as tv import lightning as L import torchmetrics.functional.classification as metrics from mnistSimpleCNN.models.modelM3 import ModelM3 +from mnistSimpleCNN.models.modelM5 import ModelM5 +from mnistSimpleCNN.models.modelM7 import ModelM7 +from pyrfd import RFD, covariance class Classifier(L.LightningModule): @@ -18,7 +26,7 @@ def __init__(self, model, optimizer=optim.Adam): self.optimizer = optimizer self.model = model - def training_step(self, batch, batch_idx): + def training_step(self, batch, *args, **kwargs): x_in, y_out = batch prediction: torch.Tensor = self.model(x_in) loss_value = F.nll_loss(prediction, y_out) @@ -66,21 +74,19 @@ def optimizer_step( self.log(f"learning_rate_{idx}", learning_rate, on_step=True) -def run(): - trainer = L.Trainer( - max_epochs=2, - log_every_n_steps=1, - ) +def mnist_training(): + train_dataset = tv.datasets.MNIST( + root="mnistSimpleCNN/data", + train=True, + transform=tv.transforms.ToTensor(), + ) train_loader = DataLoader( - tv.datasets.MNIST( - root="mnistSimpleCNN/data", - train=True, - transform=tv.transforms.ToTensor(), - ), + train_dataset, batch_size=120, shuffle=True, ) + test_loader = DataLoader( tv.datasets.MNIST( root="mnistSimpleCNN/data", @@ -91,10 +97,32 @@ def run(): shuffle=False, ) - model = Classifier(ModelM3()) - trainer.fit(model=model, train_dataloaders=train_loader) - trainer.test(model=model, dataloaders=test_loader) + model: torch.nn.Module + for model in [ModelM3, ModelM5, ModelM7]: + cov_model = covariance.SquaredExponential() + cov_model.auto_fit( + model_factory=model, + loss=F.nll_loss, + data=train_dataset, + cache=f"logs/mnist/{model.__name__}/covariance_cache/nll.csv" + ) + + classifiers = {} + for (name, opt) in { + "RFD": func.partial(RFD, covariance_model=cov_model), + "Adam": optim.Adam, + "SGD": optim.SGD + }.items(): + trainer = L.Trainer( + max_epochs=2, + log_every_n_steps=1, + default_root_dir=f"logs/mnist/{model.__name__}/{name}" + ) + classifier = Classifier(model(), optimizer=opt) + classifiers[name] = classifier + trainer.fit(model=classifier, train_dataloaders=train_loader) + trainer.test(model=classifier, dataloaders=test_loader) if __name__ == "__main__": - run() + mnist_training() diff --git a/poetry.lock b/poetry.lock index ffbbd83..c61938d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "absl-py" +version = "2.1.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +optional = false +python-versions = ">=3.7" +files = [ + {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, + {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, +] + [[package]] name = "aiohttp" version = "3.9.5" @@ -553,6 +564,64 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] +[[package]] +name = "grpcio" +version = "1.63.0" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, + {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, + {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, + {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, + {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, + {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, + {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, + {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, + {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, + {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, + {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, + {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, + {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, + {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, + {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, + {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, + {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, + {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, + {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, + {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, + {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.63.0)"] + [[package]] name = "idna" version = "3.7" @@ -564,6 +633,25 @@ files = [ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] +[[package]] +name = "importlib-metadata" +version = "7.1.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] + [[package]] name = "importlib-resources" version = "6.1.2" @@ -823,6 +911,24 @@ cli = ["fire"] docs = ["requests (>=2.0.0)"] typing = ["mypy (>=1.0.0)", "types-setuptools"] +[[package]] +name = "markdown" +version = "3.6" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, + {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, +] + +[package.dependencies] +importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -1508,6 +1614,26 @@ files = [ [package.dependencies] colorama = "*" +[[package]] +name = "protobuf" +version = "5.26.1" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "protobuf-5.26.1-cp310-abi3-win32.whl", hash = "sha256:3c388ea6ddfe735f8cf69e3f7dc7611e73107b60bdfcf5d0f024c3ccd3794e23"}, + {file = "protobuf-5.26.1-cp310-abi3-win_amd64.whl", hash = "sha256:e6039957449cb918f331d32ffafa8eb9255769c96aa0560d9a5bf0b4e00a2a33"}, + {file = "protobuf-5.26.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:38aa5f535721d5bb99861166c445c4105c4e285c765fbb2ac10f116e32dcd46d"}, + {file = "protobuf-5.26.1-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:fbfe61e7ee8c1860855696e3ac6cfd1b01af5498facc6834fcc345c9684fb2ca"}, + {file = "protobuf-5.26.1-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:f7417703f841167e5a27d48be13389d52ad705ec09eade63dfc3180a959215d7"}, + {file = "protobuf-5.26.1-cp38-cp38-win32.whl", hash = "sha256:d693d2504ca96750d92d9de8a103102dd648fda04540495535f0fec7577ed8fc"}, + {file = "protobuf-5.26.1-cp38-cp38-win_amd64.whl", hash = "sha256:9b557c317ebe6836835ec4ef74ec3e994ad0894ea424314ad3552bc6e8835b4e"}, + {file = "protobuf-5.26.1-cp39-cp39-win32.whl", hash = "sha256:b9ba3ca83c2e31219ffbeb9d76b63aad35a3eb1544170c55336993d7a18ae72c"}, + {file = "protobuf-5.26.1-cp39-cp39-win_amd64.whl", hash = "sha256:7ee014c2c87582e101d6b54260af03b6596728505c79f17c8586e7523aaa8f8c"}, + {file = "protobuf-5.26.1-py3-none-any.whl", hash = "sha256:da612f2720c0183417194eeaa2523215c4fcc1a1949772dc65f05047e08d5932"}, + {file = "protobuf-5.26.1.tar.gz", hash = "sha256:8ca2a1d97c290ec7b16e4e5dff2e5ae150cc1582f55b5ab300d45cb0dfa90e51"}, +] + [[package]] name = "pycodestyle" version = "2.11.1" @@ -1837,6 +1963,39 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tensorboard" +version = "2.16.2" +description = "TensorBoard lets you watch Tensors Flow" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, +] + +[package.dependencies] +absl-py = ">=0.4" +grpcio = ">=1.48.2" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +protobuf = ">=3.19.6,<4.24.0 || >4.24.0" +setuptools = ">=41.0.0" +six = ">1.9" +tensorboard-data-server = ">=0.7.0,<0.8.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +description = "Fast data loading for TensorBoard" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, + {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, + {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, +] + [[package]] name = "threadpoolctl" version = "3.3.0" @@ -2075,6 +2234,23 @@ files = [ {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, ] +[[package]] +name = "werkzeug" +version = "3.0.3" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, + {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "yarl" version = "1.9.4" @@ -2196,4 +2372,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "5acb029b37c2a6ba63d098b658e090ff86c02165f54f76d5302c7b3a5309846d" +content-hash = "ee13d7a153a82d8cc208d8ac39a9f01f8a28b49eb1064612172ac309c91f588b" diff --git a/pyproject.toml b/pyproject.toml index 71c2588..dcc982e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ pandas = "^2.2.0" matplotlib = "^3.8.3" lightning = "^2.2.4" torchmetrics = "^1.4.0" +tensorboard = "^2.16.2" [tool.poetry.group.dev.dependencies] diff --git a/sbatch.sh b/sbatch.sh index a0e7c9d..1c0a915 100755 --- a/sbatch.sh +++ b/sbatch.sh @@ -2,7 +2,8 @@ #SBATCH --partition single #SBATCH --ntasks=1 -#SBATCH --time=00:02:00 +#SBATCH --time=00:20:00 #SBATCH --gres=gpu:1 +#SBATCH --mem-per-cpu=20gb -poetry run python test.py +poetry run python mnistSimpleCNN/mnist_training.py