-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
37 lines (26 loc) · 1 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
from src.DataLoader import FashionMNISTDataloader
from src.Loss import CrossEntropyLoss
from src.MLPModel import MLPModel
# Keyword arguments forwarded to FashionMNISTDataloader(...) in main().
dataloaders_kwargs = dict(
    path_dir="dataset",  # dataset location handed to the dataloader
    batch_size=32,  # batch size used when iterating the test split
)

# Checkpoint restored via MLPModel.load_model_dict() before evaluation.
ckpt_path = "models/model_epoch_100.pkl"
def main():
    """Evaluate a trained checkpoint on the Fashion-MNIST test split and print the metrics.

    Builds the dataloader from ``dataloaders_kwargs`` and restores an ``MLPModel``
    from ``ckpt_path``. It then runs the model over every batch produced by
    ``generate_test_batch()`` and prints the sample-weighted mean cross-entropy
    loss and the accuracy.

    Raises:
        ValueError: if the dataloader yields no test samples (the averages would
            otherwise divide by zero).
    """
    dataloader = FashionMNISTDataloader(**dataloaders_kwargs)
    model = MLPModel()
    model.load_model_dict(ckpt_path)  # load the model from already-trained weights
    loss = CrossEntropyLoss()

    total_loss = 0.0
    total_acc = 0
    n_samples = 0  # samples actually evaluated; used as the averaging denominator
    for x_batch, y_batch in dataloader.generate_test_batch():
        y_pred = model.forward(x_batch)
        # Labels are compared via argmax over axis 1, i.e. presumably one-hot
        # encoded -- TODO confirm against FashionMNISTDataloader.
        total_acc += np.sum(np.argmax(y_pred, axis=1) == np.argmax(y_batch, axis=1))
        ce_loss = loss.forward(y_pred, y_batch)
        # Assumes ce_loss is a per-batch mean; weighting by the batch size turns
        # the final value into a per-sample mean even if the last batch is smaller.
        total_loss += ce_loss * len(x_batch)
        n_samples += len(x_batch)

    # Divide by the number of samples actually seen rather than len(dataloader.y_test).
    # The two agree when every test sample is yielded exactly once. Counting keeps the
    # denominator consistent with the numerators, e.g. if a trailing partial batch is dropped.
    if n_samples == 0:
        raise ValueError("the test dataloader yielded no samples")
    test_loss = total_loss / n_samples
    test_acc = total_acc / n_samples
    print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f} | Checkpoint: {ckpt_path} | ")
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()