-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
92 lines (77 loc) · 3.08 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# import the necessary packages
from pyimagesearch import config
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2
import os
def prepare_plot(origImage, origMask, predMask, image_name, model):
    """Save a side-by-side figure of an image, its mask and the prediction.

    The figure is written to ``config.BASE_OUTPUT/<model>/<image_name>``.
    Nothing is shown on screen and nothing is returned.

    Args:
        origImage: Image drawn in the left panel.
        origMask: Ground-truth mask drawn in the middle panel.
        predMask: Predicted mask drawn in the right panel.
        image_name: File name used for the saved figure.
        model: Name of the sub-directory of ``config.BASE_OUTPUT`` that
            receives the figure.
    """
    # initialize our figure
    figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))
    try:
        # plot the original image, its mask, and the predicted mask
        ax[0].imshow(origImage)
        ax[1].imshow(origMask)
        ax[2].imshow(predMask)
        # set the titles of the subplots
        ax[0].set_title("Image")
        ax[1].set_title("Original Mask")
        ax[2].set_title("Predicted Mask")
        # set the layout of the figure and save it to disk
        figure.tight_layout()
        output_dir = os.path.join(config.BASE_OUTPUT, model)
        # savefig does not create missing directories
        os.makedirs(output_dir, exist_ok=True)
        figure.savefig(os.path.join(output_dir, image_name))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, one figure leaks per call
        plt.close(figure)
def make_predictions(model, imagePath):
    """Run ``model`` on one image and save a comparison plot.

    The image at ``imagePath`` is loaded, converted to RGB, scaled to
    [0, 1] and resized to 128x128. The matching ground-truth mask is
    looked up in ``config.MASK_DATASET_PATH`` as ``<stem>-mask.png``,
    where ``<stem>`` is the image file name up to its first dot. The
    sigmoid of the model output is scaled to 0-255 and passed, together
    with the image and the mask, to ``prepare_plot``.

    Args:
        model: Module called on a ``(1, 3, 128, 128)`` float32 tensor
            located on ``config.DEVICE``.
        imagePath: Path of the image to segment.

    Raises:
        FileNotFoundError: If the image or its mask cannot be read.
    """
    # load the image from disk; cv2.imread returns None instead of
    # raising when the file is missing or unreadable
    image = cv2.imread(imagePath)
    if image is None:
        raise FileNotFoundError(
            "could not read image: {}".format(imagePath))
    # swap its color channels, cast it to float data type, and scale
    # its pixel values
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image.astype("float32") / 255.0
    # resize the image and make a copy of it for visualization
    # NOTE(review): the size is hard-coded while the mask below uses
    # config.INPUT_IMAGE_HEIGHT -- confirm both are meant to agree
    image = cv2.resize(image, (128, 128))
    orig = image.copy()

    # find the filename and generate the path to ground truth mask
    filename = os.path.basename(imagePath)
    filename_mask = filename.split('.')[0] + '-mask.png'
    groundTruthPath = os.path.join(config.MASK_DATASET_PATH, filename_mask)
    # load the ground-truth segmentation mask (default imread flag, so
    # it is read as a 3-channel image, not grayscale) and resize it
    gtMask = cv2.imread(groundTruthPath)
    if gtMask is None:
        raise FileNotFoundError(
            "could not read ground-truth mask: {}".format(groundTruthPath))
    # NOTE(review): the height is used for both dimensions -- confirm
    # the input is meant to be square
    gtMask = cv2.resize(
        gtMask, (config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_HEIGHT))

    # make the channel axis to be the leading one, add a batch
    # dimension, create a PyTorch tensor, and flash it to the
    # current device
    image = np.transpose(image, (2, 0, 1))
    image = np.expand_dims(image, 0)
    image = torch.from_numpy(image).to(config.DEVICE)

    # set model to evaluation mode and turn off gradient tracking
    model.eval()
    with torch.no_grad():
        # make the prediction, pass the results through the sigmoid
        # function, and convert the result to a NumPy array
        predMask = model(image).squeeze()
        predMask = torch.sigmoid(predMask)
        predMask = predMask.cpu().numpy()

    # NOTE: the probabilities are plotted as-is (scaled to 0-255);
    # binarising them against config.THRESHOLD is currently disabled
    predMask = predMask * 255
    print(predMask)
    print('mean_mask =', np.mean(predMask))
    print(predMask.shape)
    predMask = predMask.astype(np.uint8)
    # prepare a plot for visualization
    prepare_plot(orig, gtMask, predMask, filename, model='unet')
# load the image paths in our testing file and randomly select up to 5
# distinct image paths
print("[INFO] loading up test image paths...")
# the context manager closes the file; blank lines are skipped so they
# are never handed to cv2.imread as paths
with open(config.TEST_PATHS) as paths_file:
    imagePaths = [line.strip() for line in paths_file if line.strip()]
# replace=False avoids predicting (and overwriting the plot of) the same
# image twice; the size is capped so short lists do not raise
imagePaths = np.random.choice(
    imagePaths, size=min(5, len(imagePaths)), replace=False)
# load our model from disk and flash it to the current device
print("[INFO] load up model...")
# NOTE(review): torch.load unpickles the file, so config.MODEL_PATH must
# point to a trusted checkpoint; newer PyTorch releases default to
# weights_only=True, which rejects a fully pickled module -- confirm
# against the installed version
unet = torch.load(config.MODEL_PATH).to(config.DEVICE)
# iterate over the randomly selected test image paths
for path in imagePaths:
    # make predictions and visualize the results
    make_predictions(unet, path)