diff --git a/.gitignore b/.gitignore index 7cfab6d..cb5661c 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,6 @@ checkpoints build dist -.vscode \ No newline at end of file +.vscode + +venv \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6dc30bb..a45408e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] diff --git a/readme.md b/readme.md index b851609..6760929 100644 --- a/readme.md +++ b/readme.md @@ -31,6 +31,10 @@ check [cupy documentation](https://docs.cupy.dev/en/stable/install.html#installi See [npnn WIKI](https://github.com/AIboy996/npnn/wiki). +### Known issues + +See [npnn known-issues](https://github.com/AIboy996/npnn/wiki#known-issues). + ## Work with npnn! > Here we will construct a image classification neural network with npnn. 
@@ -51,4 +55,16 @@ Construct and Train a neural network on [Fashion-MNIST](https://github.com/zalan - `search.py`: parameters searching - `test.py`: model testing - `viz.py`: visualization -- `utils.py`: some misc function, such as `save_model` \ No newline at end of file +- `utils.py`: some misc function, such as `save_model` + +Run `search.py` to get a results table like: + +no|train_id|accuracy|hidden_size|batch_size|learning_rate|regularization|regular_strength +--|--|--|--|--|--|--|-- +0|2024_0423(1713841292)|0.8306|[384]|3|0.002|None|0.0 +1|2024_0423(1713845802)|0.8145|[384]|3|0.002|l2|0.1 +2|2024_0423(1713849349)|0.8269|[384]|3|0.002|l2|0.01 +3|2024_0423(1713853939)|0.8255|[384]|3|0.002|l2|0.005 +4|2024_0423(1713857657)|0.8373|[384]|3|0.002|l2|0.001 + +Training log files and saved model weights can be found in the `./logs` and `./checkpoints` folders. \ No newline at end of file diff --git a/search.py b/search.py index 8235cc0..9c09469 100644 --- a/search.py +++ b/search.py @@ -1,53 +1,48 @@ """search hyperparameter""" +import json +from pathlib import Path import pandas as pd -from train import Trainer +from train import train lst = [] -for lr in range(1, 11): - lr *= 0.005 - for hidden_size, batch_size in [ - ([128], 32), - ([256], 8), - ([128, 64], 32), - ([256, 64], 8), - ([256, 128], 8), - ([512, 128], 2), - ([512, 256], 2), - ([512, 384], 1), - ]: - # control batch due to GPU memory limit +# control batch size due to GPU memory limit +for hidden_size, batch_size in [ + # ([512, 256], 2), + # ([256, 128], 4), + # ([256, 64], 8), + # ([128, 64], 8), + ([384], 3), + ([256], 8), + ([128], 16), +]: + for learning_rate in range(1, 11): + learning_rate *= 0.002 for regular in [ (None, 0), - ("l2", 0.5), ("l2", 0.1), ("l2", 0.01), + ("l2", 0.005), ("l2", 0.001), ]: - print(f"searching {lr=}, {hidden_size=}, {regular=}") + + print(f"searching {learning_rate=}, {hidden_size=}, {regular=}") regularization, regular_strength = regular - with Trainer( 
hidden_size=hidden_size, - regularization=regularization, - regular_strength=regular_strength, - lr=lr, - batch_size=batch_size, - ) as trainer: - train_hashcode = trainer.train_hashcode - trainer.train() - metric = trainer.test() - train_log = dict( - train_id=train_hashcode, - accuracy=metric, - learning_rate=lr, + train( hidden_size=hidden_size, batch_size=batch_size, - regularization=str(regularization), + learning_rate=learning_rate, + regularization=regularization, regular_strength=regular_strength, ) - lst.append(train_log) - print(train_hashcode, metric) +l = [] +for json_file in Path("./logs").glob("*.json"): + with open(json_file) as f: + train_log = json.load(f) + del train_log["train_loss"] + del train_log["valid_metric"] + l.append(train_log) -df = pd.DataFrame(lst) +df = pd.DataFrame(l) df.to_csv("./search_result.csv") diff --git a/search_result.csv b/search_result.csv new file mode 100644 index 0000000..45fef3b --- /dev/null +++ b/search_result.csv @@ -0,0 +1,151 @@ +,train_id,accuracy,hidden_size,batch_size,learning_rate,regularization,regular_strength +0,2024_0423(1713841292),0.8306,[384],3,0.002,None,0.0 +1,2024_0423(1713845802),0.8145,[384],3,0.002,l2,0.1 +2,2024_0423(1713849349),0.8269,[384],3,0.002,l2,0.01 +3,2024_0423(1713853939),0.8255,[384],3,0.002,l2,0.005 +4,2024_0423(1713857657),0.8373,[384],3,0.002,l2,0.001 +5,2024_0423(1713862340),0.8148,[384],3,0.004,None,0.0 +6,2024_0423(1713865622),0.8189,[384],3,0.004,l2,0.1 +7,2024_0423(1713868861),0.8145,[384],3,0.004,l2,0.01 +8,2024_0423(1713872052),0.8193,[384],3,0.004,l2,0.005 +9,2024_0423(1713875333),0.8198,[384],3,0.004,l2,0.001 +10,2024_0423(1713878617),0.8059,[384],3,0.006,None,0.0 +11,2024_0423(1713881314),0.8093,[384],3,0.006,l2,0.1 +12,2024_0423(1713883910),0.8107,[384],3,0.006,l2,0.01 +13,2024_0423(1713886610),0.8064,[384],3,0.006,l2,0.005 +14,2024_0424(1713889485),0.8148,[384],3,0.006,l2,0.001 +15,2024_0424(1713892116),0.8141,[384],3,0.008,None,0.0 
+16,2024_0424(1713894704),0.7973,[384],3,0.008,l2,0.1 +17,2024_0424(1713897199),0.8001,[384],3,0.008,l2,0.01 +18,2024_0424(1713899689),0.8094,[384],3,0.008,l2,0.005 +19,2024_0424(1713902266),0.8041,[384],3,0.008,l2,0.001 +20,2024_0424(1713904845),0.7907,[384],3,0.01,None,0.0 +21,2024_0424(1713907329),0.7873,[384],3,0.01,l2,0.1 +22,2024_0424(1713909749),0.8035,[384],3,0.01,l2,0.01 +23,2024_0424(1713912184),0.8095,[384],3,0.01,l2,0.005 +24,2024_0424(1713914600),0.7893,[384],3,0.01,l2,0.001 +25,2024_0424(1713916963),0.7864,[384],3,0.012,None,0.0 +26,2024_0424(1713919315),0.7911,[384],3,0.012,l2,0.1 +27,2024_0424(1713921700),0.7967,[384],3,0.012,l2,0.01 +28,2024_0424(1713924029),0.7886,[384],3,0.012,l2,0.005 +29,2024_0424(1713926421),0.8024,[384],3,0.012,l2,0.001 +30,2024_0424(1713928781),0.789,[384],3,0.014,None,0.0 +31,2024_0424(1713931121),0.7819,[384],3,0.014,l2,0.1 +32,2024_0424(1713933420),0.7812,[384],3,0.014,l2,0.01 +33,2024_0424(1713935761),0.7869,[384],3,0.014,l2,0.005 +34,2024_0424(1713938099),0.7859,[384],3,0.014,l2,0.001 +35,2024_0424(1713940443),0.7619,[384],3,0.016,None,0.0 +36,2024_0424(1713942672),0.7668,[384],3,0.016,l2,0.1 +37,2024_0424(1713944942),0.7579,[384],3,0.016,l2,0.01 +38,2024_0424(1713947243),0.7777,[384],3,0.016,l2,0.005 +39,2024_0424(1713949533),0.7754,[384],3,0.016,l2,0.001 +40,2024_0424(1713951847),0.7765,[384],3,0.018000000000000002,None,0.0 +41,2024_0424(1713954106),0.7671,[384],3,0.018000000000000002,l2,0.1 +42,2024_0424(1713956375),0.7632,[384],3,0.018000000000000002,l2,0.01 +43,2024_0424(1713958620),0.7807,[384],3,0.018000000000000002,l2,0.005 +44,2024_0424(1713960875),0.7581,[384],3,0.018000000000000002,l2,0.001 +45,2024_0424(1713963130),0.7725,[384],3,0.02,None,0.0 +46,2024_0424(1713965554),0.7469,[384],3,0.02,l2,0.1 +47,2024_0424(1713967908),0.7445,[384],3,0.02,l2,0.01 +48,2024_0424(1713970247),0.733,[384],3,0.02,l2,0.005 +49,2024_0424(1713972509),0.7591,[384],3,0.02,l2,0.001 
+50,2024_0425(1713974752),0.8111,[256],8,0.002,None,0.0 +51,2024_0425(1713976375),0.8099,[256],8,0.002,l2,0.1 +52,2024_0425(1713978611),0.8261,[256],8,0.002,l2,0.01 +53,2024_0425(1713980873),0.8253,[256],8,0.002,l2,0.005 +54,2024_0425(1713983095),0.8225,[256],8,0.002,l2,0.001 +55,2024_0425(1713985533),0.8257,[256],8,0.004,None,0.0 +56,2024_0425(1713988329),0.8123,[256],8,0.004,l2,0.1 +57,2024_0425(1713993322),0.8314,[256],8,0.004,l2,0.01 +58,2024_0425(1714001028),0.826,[256],8,0.004,l2,0.005 +59,2024_0425(1714003908),0.8243,[256],8,0.004,l2,0.001 +60,2024_0425(1714009492),0.8246,[256],8,0.006,None,0.0 +61,2024_0425(1714012546),0.8173,[256],8,0.006,l2,0.1 +62,2024_0425(1714018966),0.8124,[256],8,0.006,l2,0.01 +63,2024_0425(1714021889),0.8186,[256],8,0.006,l2,0.005 +64,2024_0425(1714025725),0.8227,[256],8,0.006,l2,0.001 +65,2024_0425(1714029498),0.819,[256],8,0.008,None,0.0 +66,2024_0425(1714033373),0.7978,[256],8,0.008,l2,0.1 +67,2024_0425(1714037214),0.8188,[256],8,0.008,l2,0.01 +68,2024_0425(1714041060),0.8128,[256],8,0.008,l2,0.005 +69,2024_0425(1714044880),0.8195,[256],8,0.008,l2,0.001 +70,2024_0425(1714046886),0.8196,[256],8,0.01,None,0.0 +71,2024_0425(1714049564),0.8026,[256],8,0.01,l2,0.1 +72,2024_0425(1714051169),0.8208,[256],8,0.01,l2,0.01 +73,2024_0425(1714055998),0.8104,[256],8,0.01,l2,0.005 +74,2024_0425(1714058167),0.8091,[256],8,0.01,l2,0.001 +75,2024_0426(1714061643),0.8065,[256],8,0.012,None,0.0 +76,2024_0426(1714064253),0.8001,[256],8,0.012,l2,0.1 +77,2024_0426(1714065976),0.8017,[256],8,0.012,l2,0.01 +78,2024_0426(1714068981),0.8094,[256],8,0.012,l2,0.005 +79,2024_0426(1714070623),0.8018,[256],8,0.012,l2,0.001 +80,2024_0426(1714072808),0.802,[256],8,0.014,None,0.0 +81,2024_0426(1714074398),0.7979,[256],8,0.014,l2,0.1 +82,2024_0426(1714077263),0.811,[256],8,0.014,l2,0.01 +83,2024_0426(1714079748),0.8037,[256],8,0.014,l2,0.005 +84,2024_0426(1714082616),0.8071,[256],8,0.014,l2,0.001 +85,2024_0426(1714084784),0.7924,[256],8,0.016,None,0.0 
+86,2024_0426(1714086277),0.7952,[256],8,0.016,l2,0.1 +87,2024_0426(1714088025),0.804,[256],8,0.016,l2,0.01 +88,2024_0426(1714090742),0.7948,[256],8,0.016,l2,0.005 +89,2024_0426(1714092226),0.7985,[256],8,0.016,l2,0.001 +90,2024_0426(1714093704),0.8078,[256],8,0.018000000000000002,None,0.0 +91,2024_0426(1714095674),0.7921,[256],8,0.018000000000000002,l2,0.1 +92,2024_0426(1714097406),0.7948,[256],8,0.018000000000000002,l2,0.01 +93,2024_0426(1714098739),0.7808,[256],8,0.018000000000000002,l2,0.005 +94,2024_0426(1714101139),0.79,[256],8,0.018000000000000002,l2,0.001 +95,2024_0426(1714103605),0.798,[256],8,0.02,None,0.0 +96,2024_0426(1714106089),0.7787,[256],8,0.02,l2,0.1 +97,2024_0426(1714107357),0.785,[256],8,0.02,l2,0.01 +98,2024_0426(1714108682),0.7867,[256],8,0.02,l2,0.005 +99,2024_0426(1714111012),0.7761,[256],8,0.02,l2,0.001 +100,2024_0426(1714113280),0.8187,[128],16,0.002,None,0.0 +101,2024_0426(1714113700),0.8095,[128],16,0.002,l2,0.1 +102,2024_0426(1714114012),0.7948,[128],16,0.002,l2,0.01 +103,2024_0426(1714114321),0.798,[128],16,0.002,l2,0.005 +104,2024_0426(1714114769),0.8141,[128],16,0.002,l2,0.001 +105,2024_0426(1714115192),0.8054,[128],16,0.004,None,0.0 +106,2024_0426(1714115525),0.8271,[128],16,0.004,l2,0.1 +107,2024_0426(1714115872),0.8124,[128],16,0.004,l2,0.01 +108,2024_0426(1714116252),0.8234,[128],16,0.004,l2,0.005 +109,2024_0426(1714116650),0.8365,[128],16,0.004,l2,0.001 +110,2024_0426(1714117224),0.8214,[128],16,0.006,None,0.0 +111,2024_0426(1714117615),0.8125,[128],16,0.006,l2,0.1 +112,2024_0426(1714117963),0.8182,[128],16,0.006,l2,0.01 +113,2024_0426(1714118389),0.8247,[128],16,0.006,l2,0.005 +114,2024_0426(1714118831),0.8258,[128],16,0.006,l2,0.001 +115,2024_0426(1714119206),0.8224,[128],16,0.008,None,0.0 +116,2024_0426(1714119604),0.8075,[128],16,0.008,l2,0.1 +117,2024_0426(1714120032),0.8175,[128],16,0.008,l2,0.01 +118,2024_0426(1714120488),0.8173,[128],16,0.008,l2,0.005 +119,2024_0426(1714120880),0.8158,[128],16,0.008,l2,0.001 
+120,2024_0426(1714121354),0.8205,[128],16,0.01,None,0.0 +121,2024_0426(1714121758),0.8096,[128],16,0.01,l2,0.1 +122,2024_0426(1714122109),0.8284,[128],16,0.01,l2,0.01 +123,2024_0426(1714122496),0.8159,[128],16,0.01,l2,0.005 +124,2024_0426(1714122891),0.8125,[128],16,0.01,l2,0.001 +125,2024_0426(1714123326),0.8018,[128],16,0.012,None,0.0 +126,2024_0426(1714123743),0.7958,[128],16,0.012,l2,0.1 +127,2024_0426(1714124100),0.8107,[128],16,0.012,l2,0.01 +128,2024_0426(1714124495),0.8178,[128],16,0.012,l2,0.005 +129,2024_0426(1714124836),0.807,[128],16,0.012,l2,0.001 +130,2024_0426(1714125220),0.8085,[128],16,0.014,None,0.0 +131,2024_0426(1714125609),0.788,[128],16,0.014,l2,0.1 +132,2024_0426(1714125963),0.8093,[128],16,0.014,l2,0.01 +133,2024_0426(1714126321),0.8115,[128],16,0.014,l2,0.005 +134,2024_0426(1714126702),0.8237,[128],16,0.014,l2,0.001 +135,2024_0426(1714127074),0.8035,[128],16,0.016,None,0.0 +136,2024_0426(1714127410),0.7992,[128],16,0.016,l2,0.1 +137,2024_0426(1714127783),0.8055,[128],16,0.016,l2,0.01 +138,2024_0426(1714128140),0.8102,[128],16,0.016,l2,0.005 +139,2024_0426(1714128528),0.8007,[128],16,0.016,l2,0.001 +140,2024_0426(1714128925),0.7936,[128],16,0.018000000000000002,None,0.0 +141,2024_0426(1714129238),0.8033,[128],16,0.018000000000000002,l2,0.1 +142,2024_0426(1714129593),0.812,[128],16,0.018000000000000002,l2,0.01 +143,2024_0426(1714129943),0.8032,[128],16,0.018000000000000002,l2,0.005 +144,2024_0426(1714130284),0.8055,[128],16,0.018000000000000002,l2,0.001 +145,2024_0426(1714130593),0.8042,[128],16,0.02,None,0.0 +146,2024_0426(1714130890),0.776,[128],16,0.02,l2,0.1 +147,2024_0426(1714131196),0.8023,[128],16,0.02,l2,0.01 +148,2024_0426(1714131500),0.8049,[128],16,0.02,l2,0.005 +149,2024_0426(1714131818),0.8099,[128],16,0.02,l2,0.001 diff --git a/test.py b/test.py index 978aa12..97a5055 100644 --- a/test.py +++ b/test.py @@ -26,6 +26,6 @@ def test_model(model, dataset="val"): if __name__ == "__main__": - best_model = 
load_model(r"checkpoints\2024_0420(1713624674)\best_model.xz") + best_model = load_model(r"checkpoints\2024_0421(1713707080)\best_model.xz") metric = test_model(best_model, dataset="test") print(f"test done, metric = {metric}") diff --git a/train.py b/train.py index 7659db4..f5e561e 100644 --- a/train.py +++ b/train.py @@ -1,8 +1,9 @@ """train model""" -import logging import os import time +import json +import logging from typing import Literal import npnn.functional as F @@ -17,7 +18,7 @@ IMAGE_SIZE = 28 * 28 NUM_CLASS = 10 -TOTAL_EPOCH = 1 +TOTAL_EPOCH = 4 class Trainer: @@ -43,6 +44,7 @@ def __init__( regularization=regularization, regular_strength=regular_strength, ) + # load training data train_images, train_labels = load_mnist("./data", "train") self.images = train_images # trun into one hot @@ -50,10 +52,15 @@ def __init__( # model's last layer is LogSoftmax, so we use NLL Loss function here # this is equivalent to CrossEntropy Loss self.criterion = F.NLL() + + # setup logger date = time.strftime(r"%Y_%m%d") self.train_hashcode = f"{date}({int(time.time())})" self.logger = self.setup_logger() + self.train_loss = [] + self.valid_metric = [] + def setup_logger(self): logger = logging.getLogger() if not os.path.exists("./logs"): @@ -104,6 +111,9 @@ def train_epoch(self, epoch): # do validation each 50 batch if batch % 50 == 1: metric = test_model(self.model, dataset="val") + loss = total_loss / batch + self.train_loss.append(loss) + self.valid_metric.append(metric) if metric > self.best_metric: early_stop_count = 0 self.best_metric = metric @@ -114,35 +124,66 @@ def train_epoch(self, epoch): ) self.best_model_path = file_name self.logger.info( - f"{epoch=}, {batch=}, train loss={total_loss/batch : .4f}, valid metric={metric: .4f}.\n" + f"{epoch=}, {batch=}, train loss={loss : .4f}, valid metric={metric: .4f}.\n" f"Find better model, saved to {file_name}.", ) else: early_stop_count += 1 self.logger.info( - f"{epoch=}, {batch=}, train loss={total_loss/batch : 
.4f}, valid metric={metric: .4f}" + f"{epoch=}, {batch=}, train loss={loss : .4f}, valid metric={metric: .4f}" ) - if early_stop_count > (20000 // self.batch_size // 50): + if early_stop_count > (15000 // self.batch_size // 50): f"{epoch=}, Early stop since metric have no improvement for {early_stop_count} consecutive batches." break return total_loss / batch def __exit__(self, exc_type, exc_value, traceback): + import gc + # close logging handler file_handler = self.logger.handlers[0] self.logger.removeHandler(file_handler) file_handler.close() - if np.__name__ == 'cupy': + if np.__name__ == "cupy": # cupy free memory np.get_default_memory_pool().free_all_blocks() + gc.collect() def __enter__(self): return self -if __name__ == "__main__": +def train( + hidden_size=[128], + batch_size=16, + learning_rate=0.001, + regularization=None, + regular_strength=0, +): with Trainer( - hidden_size=[256], activation=F.ReLU, regularization=None, batch_size=32 + hidden_size=hidden_size, + activation=F.ReLU, + regularization=regularization, + regular_strength=regular_strength, + lr=learning_rate, + batch_size=batch_size, ) as trainer: trainer.train() - trainer.test() + metric = trainer.test() + train_log = dict( + train_id=trainer.train_hashcode, + accuracy=metric, + hidden_size=hidden_size, + batch_size=batch_size, + learning_rate=learning_rate, + regularization=str(regularization), + regular_strength=regular_strength, + train_loss=trainer.train_loss, + valid_metric=trainer.valid_metric, + ) + with open(f"./logs/{trainer.train_hashcode}.json", "w+") as f: + json.dump(train_log, f) + + +if __name__ == "__main__": + train()