-
Notifications
You must be signed in to change notification settings - Fork 4
/
inference.py
38 lines (32 loc) · 1.41 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# -*- coding: utf-8 -*-
from os import path, listdir
from skimage.io import imread
from model import create_model
import numpy as np
import json
import re
def load_model(input_size,
               settings_path="unsharpDetectorSettings.json",
               weights_path="unsharpDetectorWeights.hdf5"):
    """Build the unsharp-detector model and load its trained weights.

    Hyperparameters are read from a JSON settings file and forwarded
    positionally to ``create_model`` after ``input_size``.

    Args:
        input_size: Passed as the first argument to ``create_model``. The
            ``__main__`` block passes the per-image shape ``data.shape[1:]``.
        settings_path: JSON file with the hyperparameter keys ``l1fc``,
            ``l1fs``, ``l1st``, ``l2fc``, ``l2fs``, ``l2st``, ``l3fc``,
            ``l3fs``, ``eac_size``, ``res_c``, ``res_fc`` and ``res_fs``.
            Relative paths resolve against the current working directory.
        weights_path: File handed to ``model.load_weights``. Relative paths
            resolve against the current working directory.

    Returns:
        The object returned by ``create_model``, with weights loaded.

    Raises:
        OSError: If the settings file cannot be opened.
        KeyError: If a required hyperparameter key is missing from the
            settings mapping.
    """
    with open(settings_path, "r", encoding="utf-8") as json_file:
        settings = json.load(json_file)
    # Argument order must match create_model's positional signature.
    model = create_model(
        input_size,
        settings["l1fc"], settings["l1fs"], settings["l1st"],
        settings["l2fc"], settings["l2fs"], settings["l2st"],
        settings["l3fc"], settings["l3fs"],
        settings["eac_size"],
        settings["res_c"], settings["res_fc"], settings["res_fs"],
    )
    model.load_weights(weights_path)
    return model
def inference(model, img_list, batch_size=None):
    """Run ``model.predict`` over a batch of inputs.

    Args:
        model: Object exposing ``predict(x, batch_size=...)``, such as the
            model returned by ``load_model``.
        img_list: Sized batch of inputs. The ``__main__`` block passes a
            NumPy array with a leading batch axis.
        batch_size: Samples per predict step. ``None`` (the default) keeps the
            original behaviour of predicting the whole input as one batch.
            Pass a smaller value to bound memory on large inputs.

    Returns:
        Whatever ``model.predict`` returns.
    """
    if batch_size is None:
        # One batch for the whole input, but never 0: with an empty input,
        # len(img_list) would otherwise give an invalid batch size.
        batch_size = max(len(img_list), 1)
    return model.predict(img_list, batch_size=batch_size)
if __name__ == "__main__":
    # Run the detector over every image file in img_path and print each
    # prediction. Extensions are matched case-insensitively so mixed-case
    # names such as "photo.Jpg" are included too.
    filename_regex = re.compile(r".*\.(jpg|jpeg|png|bmp)$", re.IGNORECASE)
    img_path = "validation_data/good/"
    img_dir = path.abspath(img_path)  # resolve once, not per file
    # load_model builds the network for a specific input shape and reads the
    # settings/weights files from disk, so reuse one model per distinct
    # image shape instead of rebuilding it for every file.
    models_by_shape = {}
    # sorted() makes the processing order deterministic across platforms.
    for filename in sorted(listdir(img_dir)):
        if not filename_regex.match(filename):
            continue
        file_path = path.join(img_dir, filename)
        print("reading " + str(file_path))
        # Wrap in a list to add a leading batch axis of size 1; dividing by
        # 255 presumably maps 8-bit pixel values to [0, 1] -- TODO confirm
        # the training pipeline used the same scaling.
        data = np.array([imread(file_path) / 255])
        input_shape = data.shape[1:]
        trained_model = models_by_shape.get(input_shape)
        if trained_model is None:
            trained_model = load_model(input_shape)
            models_by_shape[input_shape] = trained_model
        print(inference(trained_model, data))