Skip to content

Commit

Permalink
[add] search result
Browse files Browse the repository at this point in the history
  • Loading branch information
AIboy996 committed Apr 26, 2024
1 parent c39ac63 commit d340af9
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 46 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ checkpoints

build
dist
.vscode
.vscode

venv
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
Expand Down
18 changes: 17 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ check [cupy documentation](https://docs.cupy.dev/en/stable/install.html#installi

See [npnn WIKI](https://github.com/AIboy996/npnn/wiki).

### Known issues

See [npnn known-issues](https://github.com/AIboy996/npnn/wiki#known-issues).

## Work with npnn!
> Here we will construct a image classification neural network with npnn.
Expand All @@ -51,4 +55,16 @@ Construct and Train a neural network on [Fashion-MNIST](https://github.com/zalan
- `search.py`: parameters searching
- `test.py`: model testing
- `viz.py`: visualization
- `utils.py`: some misc function, such as `save_model`
- `utils.py`: some misc function, such as `save_model`

run `search.py`, you can get a table like:

no|train_id|accuracy|hidden_size|batch_size|learning_rate|regularization|regular_strength
--|--|--|--|--|--|--|--
0|2024_0423(1713841292)|0.8306|[384]|3|0.002|None|0.0
1|2024_0423(1713845802)|0.8145|[384]|3|0.002|l2|0.1
2|2024_0423(1713849349)|0.8269|[384]|3|0.002|l2|0.01
3|2024_0423(1713853939)|0.8255|[384]|3|0.002|l2|0.005
4|2024_0423(1713857657)|0.8373|[384]|3|0.002|l2|0.001

train log file and saved model weights can be found in `./logs` and `./checkpoints` folder.
63 changes: 29 additions & 34 deletions search.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,48 @@
"""search hyperparameter"""

import json
from pathlib import Path
import pandas as pd

from train import Trainer
from train import train

lst = []
for lr in range(1, 11):
lr *= 0.005
for hidden_size, batch_size in [
([128], 32),
([256], 8),
([128, 64], 32),
([256, 64], 8),
([256, 128], 8),
([512, 128], 2),
([512, 256], 2),
([512, 384], 1),
]:
# control batch due to GPU memory limit
# control batch size due to GPU memory limit
for hidden_size, batch_size in [
# ([512, 256], 2),
# ([256, 128], 4),
# ([256, 64], 8),
# ([128, 64], 8),
([384], 3),
([256], 8),
([128], 16),
]:
for learning_rate in range(1, 11):
learning_rate *= 0.002
for regular in [
(None, 0),
("l2", 0.5),
("l2", 0.1),
("l2", 0.01),
("l2", 0.005),
("l2", 0.001),
]:
print(f"searching {lr=}, {hidden_size=}, {regular=}")

print(f"searching {learning_rate=}, {hidden_size=}, {regular=}")
regularization, regular_strength = regular
with Trainer(
hidden_size=hidden_size,
regularization=regularization,
regular_strength=regular_strength,
lr=lr,
batch_size=batch_size,
) as trainer:
train_hashcode = trainer.train_hashcode
trainer.train()
metric = trainer.test()
train_log = dict(
train_id=train_hashcode,
accuracy=metric,
learning_rate=lr,
train(
hidden_size=hidden_size,
batch_size=batch_size,
regularization=str(regularization),
learning_rate=learning_rate,
regularization=regularization,
regular_strength=regular_strength,
)
lst.append(train_log)
print(train_hashcode, metric)
l = []
for json_file in Path("./logs").glob("*.json"):
with open(json_file) as f:
train_log = json.load(f)
del train_log["train_loss"]
del train_log["valid_metric"]
l.append(train_log)

df = pd.DataFrame(lst)
df = pd.DataFrame(l)
df.to_csv("./search_result.csv")
151 changes: 151 additions & 0 deletions search_result.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
,train_id,accuracy,hidden_size,batch_size,learning_rate,regularization,regular_strength
0,2024_0423(1713841292),0.8306,[384],3,0.002,None,0.0
1,2024_0423(1713845802),0.8145,[384],3,0.002,l2,0.1
2,2024_0423(1713849349),0.8269,[384],3,0.002,l2,0.01
3,2024_0423(1713853939),0.8255,[384],3,0.002,l2,0.005
4,2024_0423(1713857657),0.8373,[384],3,0.002,l2,0.001
5,2024_0423(1713862340),0.8148,[384],3,0.004,None,0.0
6,2024_0423(1713865622),0.8189,[384],3,0.004,l2,0.1
7,2024_0423(1713868861),0.8145,[384],3,0.004,l2,0.01
8,2024_0423(1713872052),0.8193,[384],3,0.004,l2,0.005
9,2024_0423(1713875333),0.8198,[384],3,0.004,l2,0.001
10,2024_0423(1713878617),0.8059,[384],3,0.006,None,0.0
11,2024_0423(1713881314),0.8093,[384],3,0.006,l2,0.1
12,2024_0423(1713883910),0.8107,[384],3,0.006,l2,0.01
13,2024_0423(1713886610),0.8064,[384],3,0.006,l2,0.005
14,2024_0424(1713889485),0.8148,[384],3,0.006,l2,0.001
15,2024_0424(1713892116),0.8141,[384],3,0.008,None,0.0
16,2024_0424(1713894704),0.7973,[384],3,0.008,l2,0.1
17,2024_0424(1713897199),0.8001,[384],3,0.008,l2,0.01
18,2024_0424(1713899689),0.8094,[384],3,0.008,l2,0.005
19,2024_0424(1713902266),0.8041,[384],3,0.008,l2,0.001
20,2024_0424(1713904845),0.7907,[384],3,0.01,None,0.0
21,2024_0424(1713907329),0.7873,[384],3,0.01,l2,0.1
22,2024_0424(1713909749),0.8035,[384],3,0.01,l2,0.01
23,2024_0424(1713912184),0.8095,[384],3,0.01,l2,0.005
24,2024_0424(1713914600),0.7893,[384],3,0.01,l2,0.001
25,2024_0424(1713916963),0.7864,[384],3,0.012,None,0.0
26,2024_0424(1713919315),0.7911,[384],3,0.012,l2,0.1
27,2024_0424(1713921700),0.7967,[384],3,0.012,l2,0.01
28,2024_0424(1713924029),0.7886,[384],3,0.012,l2,0.005
29,2024_0424(1713926421),0.8024,[384],3,0.012,l2,0.001
30,2024_0424(1713928781),0.789,[384],3,0.014,None,0.0
31,2024_0424(1713931121),0.7819,[384],3,0.014,l2,0.1
32,2024_0424(1713933420),0.7812,[384],3,0.014,l2,0.01
33,2024_0424(1713935761),0.7869,[384],3,0.014,l2,0.005
34,2024_0424(1713938099),0.7859,[384],3,0.014,l2,0.001
35,2024_0424(1713940443),0.7619,[384],3,0.016,None,0.0
36,2024_0424(1713942672),0.7668,[384],3,0.016,l2,0.1
37,2024_0424(1713944942),0.7579,[384],3,0.016,l2,0.01
38,2024_0424(1713947243),0.7777,[384],3,0.016,l2,0.005
39,2024_0424(1713949533),0.7754,[384],3,0.016,l2,0.001
40,2024_0424(1713951847),0.7765,[384],3,0.018000000000000002,None,0.0
41,2024_0424(1713954106),0.7671,[384],3,0.018000000000000002,l2,0.1
42,2024_0424(1713956375),0.7632,[384],3,0.018000000000000002,l2,0.01
43,2024_0424(1713958620),0.7807,[384],3,0.018000000000000002,l2,0.005
44,2024_0424(1713960875),0.7581,[384],3,0.018000000000000002,l2,0.001
45,2024_0424(1713963130),0.7725,[384],3,0.02,None,0.0
46,2024_0424(1713965554),0.7469,[384],3,0.02,l2,0.1
47,2024_0424(1713967908),0.7445,[384],3,0.02,l2,0.01
48,2024_0424(1713970247),0.733,[384],3,0.02,l2,0.005
49,2024_0424(1713972509),0.7591,[384],3,0.02,l2,0.001
50,2024_0425(1713974752),0.8111,[256],8,0.002,None,0.0
51,2024_0425(1713976375),0.8099,[256],8,0.002,l2,0.1
52,2024_0425(1713978611),0.8261,[256],8,0.002,l2,0.01
53,2024_0425(1713980873),0.8253,[256],8,0.002,l2,0.005
54,2024_0425(1713983095),0.8225,[256],8,0.002,l2,0.001
55,2024_0425(1713985533),0.8257,[256],8,0.004,None,0.0
56,2024_0425(1713988329),0.8123,[256],8,0.004,l2,0.1
57,2024_0425(1713993322),0.8314,[256],8,0.004,l2,0.01
58,2024_0425(1714001028),0.826,[256],8,0.004,l2,0.005
59,2024_0425(1714003908),0.8243,[256],8,0.004,l2,0.001
60,2024_0425(1714009492),0.8246,[256],8,0.006,None,0.0
61,2024_0425(1714012546),0.8173,[256],8,0.006,l2,0.1
62,2024_0425(1714018966),0.8124,[256],8,0.006,l2,0.01
63,2024_0425(1714021889),0.8186,[256],8,0.006,l2,0.005
64,2024_0425(1714025725),0.8227,[256],8,0.006,l2,0.001
65,2024_0425(1714029498),0.819,[256],8,0.008,None,0.0
66,2024_0425(1714033373),0.7978,[256],8,0.008,l2,0.1
67,2024_0425(1714037214),0.8188,[256],8,0.008,l2,0.01
68,2024_0425(1714041060),0.8128,[256],8,0.008,l2,0.005
69,2024_0425(1714044880),0.8195,[256],8,0.008,l2,0.001
70,2024_0425(1714046886),0.8196,[256],8,0.01,None,0.0
71,2024_0425(1714049564),0.8026,[256],8,0.01,l2,0.1
72,2024_0425(1714051169),0.8208,[256],8,0.01,l2,0.01
73,2024_0425(1714055998),0.8104,[256],8,0.01,l2,0.005
74,2024_0425(1714058167),0.8091,[256],8,0.01,l2,0.001
75,2024_0426(1714061643),0.8065,[256],8,0.012,None,0.0
76,2024_0426(1714064253),0.8001,[256],8,0.012,l2,0.1
77,2024_0426(1714065976),0.8017,[256],8,0.012,l2,0.01
78,2024_0426(1714068981),0.8094,[256],8,0.012,l2,0.005
79,2024_0426(1714070623),0.8018,[256],8,0.012,l2,0.001
80,2024_0426(1714072808),0.802,[256],8,0.014,None,0.0
81,2024_0426(1714074398),0.7979,[256],8,0.014,l2,0.1
82,2024_0426(1714077263),0.811,[256],8,0.014,l2,0.01
83,2024_0426(1714079748),0.8037,[256],8,0.014,l2,0.005
84,2024_0426(1714082616),0.8071,[256],8,0.014,l2,0.001
85,2024_0426(1714084784),0.7924,[256],8,0.016,None,0.0
86,2024_0426(1714086277),0.7952,[256],8,0.016,l2,0.1
87,2024_0426(1714088025),0.804,[256],8,0.016,l2,0.01
88,2024_0426(1714090742),0.7948,[256],8,0.016,l2,0.005
89,2024_0426(1714092226),0.7985,[256],8,0.016,l2,0.001
90,2024_0426(1714093704),0.8078,[256],8,0.018000000000000002,None,0.0
91,2024_0426(1714095674),0.7921,[256],8,0.018000000000000002,l2,0.1
92,2024_0426(1714097406),0.7948,[256],8,0.018000000000000002,l2,0.01
93,2024_0426(1714098739),0.7808,[256],8,0.018000000000000002,l2,0.005
94,2024_0426(1714101139),0.79,[256],8,0.018000000000000002,l2,0.001
95,2024_0426(1714103605),0.798,[256],8,0.02,None,0.0
96,2024_0426(1714106089),0.7787,[256],8,0.02,l2,0.1
97,2024_0426(1714107357),0.785,[256],8,0.02,l2,0.01
98,2024_0426(1714108682),0.7867,[256],8,0.02,l2,0.005
99,2024_0426(1714111012),0.7761,[256],8,0.02,l2,0.001
100,2024_0426(1714113280),0.8187,[128],16,0.002,None,0.0
101,2024_0426(1714113700),0.8095,[128],16,0.002,l2,0.1
102,2024_0426(1714114012),0.7948,[128],16,0.002,l2,0.01
103,2024_0426(1714114321),0.798,[128],16,0.002,l2,0.005
104,2024_0426(1714114769),0.8141,[128],16,0.002,l2,0.001
105,2024_0426(1714115192),0.8054,[128],16,0.004,None,0.0
106,2024_0426(1714115525),0.8271,[128],16,0.004,l2,0.1
107,2024_0426(1714115872),0.8124,[128],16,0.004,l2,0.01
108,2024_0426(1714116252),0.8234,[128],16,0.004,l2,0.005
109,2024_0426(1714116650),0.8365,[128],16,0.004,l2,0.001
110,2024_0426(1714117224),0.8214,[128],16,0.006,None,0.0
111,2024_0426(1714117615),0.8125,[128],16,0.006,l2,0.1
112,2024_0426(1714117963),0.8182,[128],16,0.006,l2,0.01
113,2024_0426(1714118389),0.8247,[128],16,0.006,l2,0.005
114,2024_0426(1714118831),0.8258,[128],16,0.006,l2,0.001
115,2024_0426(1714119206),0.8224,[128],16,0.008,None,0.0
116,2024_0426(1714119604),0.8075,[128],16,0.008,l2,0.1
117,2024_0426(1714120032),0.8175,[128],16,0.008,l2,0.01
118,2024_0426(1714120488),0.8173,[128],16,0.008,l2,0.005
119,2024_0426(1714120880),0.8158,[128],16,0.008,l2,0.001
120,2024_0426(1714121354),0.8205,[128],16,0.01,None,0.0
121,2024_0426(1714121758),0.8096,[128],16,0.01,l2,0.1
122,2024_0426(1714122109),0.8284,[128],16,0.01,l2,0.01
123,2024_0426(1714122496),0.8159,[128],16,0.01,l2,0.005
124,2024_0426(1714122891),0.8125,[128],16,0.01,l2,0.001
125,2024_0426(1714123326),0.8018,[128],16,0.012,None,0.0
126,2024_0426(1714123743),0.7958,[128],16,0.012,l2,0.1
127,2024_0426(1714124100),0.8107,[128],16,0.012,l2,0.01
128,2024_0426(1714124495),0.8178,[128],16,0.012,l2,0.005
129,2024_0426(1714124836),0.807,[128],16,0.012,l2,0.001
130,2024_0426(1714125220),0.8085,[128],16,0.014,None,0.0
131,2024_0426(1714125609),0.788,[128],16,0.014,l2,0.1
132,2024_0426(1714125963),0.8093,[128],16,0.014,l2,0.01
133,2024_0426(1714126321),0.8115,[128],16,0.014,l2,0.005
134,2024_0426(1714126702),0.8237,[128],16,0.014,l2,0.001
135,2024_0426(1714127074),0.8035,[128],16,0.016,None,0.0
136,2024_0426(1714127410),0.7992,[128],16,0.016,l2,0.1
137,2024_0426(1714127783),0.8055,[128],16,0.016,l2,0.01
138,2024_0426(1714128140),0.8102,[128],16,0.016,l2,0.005
139,2024_0426(1714128528),0.8007,[128],16,0.016,l2,0.001
140,2024_0426(1714128925),0.7936,[128],16,0.018000000000000002,None,0.0
141,2024_0426(1714129238),0.8033,[128],16,0.018000000000000002,l2,0.1
142,2024_0426(1714129593),0.812,[128],16,0.018000000000000002,l2,0.01
143,2024_0426(1714129943),0.8032,[128],16,0.018000000000000002,l2,0.005
144,2024_0426(1714130284),0.8055,[128],16,0.018000000000000002,l2,0.001
145,2024_0426(1714130593),0.8042,[128],16,0.02,None,0.0
146,2024_0426(1714130890),0.776,[128],16,0.02,l2,0.1
147,2024_0426(1714131196),0.8023,[128],16,0.02,l2,0.01
148,2024_0426(1714131500),0.8049,[128],16,0.02,l2,0.005
149,2024_0426(1714131818),0.8099,[128],16,0.02,l2,0.001
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ def test_model(model, dataset="val"):


if __name__ == "__main__":
best_model = load_model(r"checkpoints\2024_0420(1713624674)\best_model.xz")
best_model = load_model(r"checkpoints\2024_0421(1713707080)\best_model.xz")
metric = test_model(best_model, dataset="test")
print(f"test done, metric = {metric}")
Loading

0 comments on commit d340af9

Please sign in to comment.