-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathshrink_model.py
75 lines (59 loc) · 2.02 KB
/
shrink_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import tensorflow as tf
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from skimage.transform import resize
# --- Convert the SavedModel to a TFLite flatbuffer ---------------------------
# SavedModel directory to convert (relative to the current working directory).
filepath = './data/weights/RIWA_fullmodel_model'

converter = tf.lite.TFLiteConverter.from_saved_model(filepath)
# Request the converter's DEFAULT optimization set (see the TF Lite
# post-training quantization docs for what this enables).
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Persist the converted model to ./tflite_model/RIWA_feb2023.tflite,
# creating the output directory (and any parents) when it is missing.
tflite_models_dir = pathlib.Path("./tflite_model/")
tflite_models_dir.mkdir(parents=True, exist_ok=True)
tflite_model_file = tflite_models_dir.joinpath("RIWA_feb2023.tflite")
tflite_model_file.write_bytes(tflite_model)

# Re-load the file just written into a TFLite interpreter so a sanity-check
# inference can be run against exactly what was saved to disk.
interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()
# Doodleverse-style standardization using an adjusted standard deviation.
def standardize(img):
    """Return *img* shifted to zero mean and divided by an adjusted std.

    The divisor is ``max(std(img), 1 / sqrt(H * W))``, where H and W are the
    first two dimensions of *img*. The floor keeps a constant image from
    triggering a division by zero. Mean and std are computed over every
    element at once (all channels pooled together).

    A 2-D (single-channel) input is stacked three times along a new last
    axis, giving an (H, W, 3) result. Inputs of any other rank keep their
    shape.
    """
    dims = np.shape(img)
    pixel_count = dims[0] * dims[1]
    spread = np.std(img)
    centre = np.mean(img)
    divisor = np.maximum(spread, 1.0 / np.sqrt(pixel_count))
    standardized = (img - centre) / divisor

    # Grayscale input: replicate into three identical channels.
    if np.ndim(standardized) == 2:
        standardized = np.dstack([standardized] * 3)
    return standardized
# --- Sanity-check inference on a single image --------------------------------
# Square size the test image is resized to before inference.
# NOTE(review): presumably must match the converted model's input shape;
# confirm against interpreter.get_input_details()[0]["shape"].
pix_dim = 512
imsize = (pix_dim, pix_dim)

# Prep the input: load + resize, convert to an array, standardize.
imgp = "./data/images/img_0139.jpg"
img = tf.keras.preprocessing.image.load_img(imgp, target_size=imsize)
img = tf.keras.preprocessing.image.img_to_array(img)
Simg = standardize(img)

# Look up the TFLite model's input/output tensor details.
input_details = interpreter.get_input_details()[0]
input_index = input_details["index"]
output_index = interpreter.get_output_details()[0]["index"]

# Add the batch axis and cast to the dtype the model declares. standardize()
# can yield float64 (a float32 array divided by an np.float64 scalar promotes
# under NumPy 2 / NEP 50 rules), and set_tensor rejects any dtype that differs
# from the input tensor's.
test_image = np.expand_dims(Simg, axis=0).astype(input_details["dtype"])

# Run the prediction.
interpreter.set_tensor(input_index, test_image)
interpreter.invoke()
predictions = interpreter.get_tensor(output_index)

# Plot the results: the original image, then the predicted label map.
img1 = plt.imread(imgp)
fig, axs = plt.subplots(1, 2)
axs[0].imshow(img1)
axs[0].grid(False)

pred_sq = predictions.squeeze()
# Per-pixel class index via argmax over the last axis.
# NOTE(review): assumes the output is channels-last class scores; confirm.
label = np.argmax(pred_sq, -1)
# Labels are categorical, so resize with nearest-neighbour (order=0) and no
# anti-aliasing. Bilinear or smoothed resizing would blend neighbouring class
# ids into values that are not real classes. Cast back to the label dtype,
# because preserve_range returns floats.
label_resized = resize(
    label,
    img1.shape[:2],
    order=0,
    preserve_range=True,
    anti_aliasing=False,
).astype(label.dtype)
axs[1].imshow(label_resized)
axs[1].grid(False)

# Show the plot
plt.show()