forked from TDay1/ConeNet-v2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
68 lines (50 loc) · 1.65 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from numpy import number
import torch
from torch.utils.data import DataLoader
import os.path
import time
import cv2
import numpy as np
from utils.dataset import ConeSet
from utils.model import ConeNet
from utils.display import display_bboxes
import torchvision
from PIL import Image
def infer(
    checkpoint_path='C:\\Users\\tday\\code\\racing\\checkpoints\\1625.pt',
    input_path='./samples/test_clip2.mp4',
    output_path='./samples/test_output.mp4',
    conf_thresh=0.2,
    fps=24,
):
    """Run ConeNet over every frame of a video and write the annotated result.

    Each frame is converted from BGR to RGB, resized to 624x624, scaled to
    the [0, 1] range and passed through the network on the CPU. The network
    output is handed to ``display_bboxes`` and the image it returns is
    appended to the output video.

    Args:
        checkpoint_path: Path to a checkpoint file containing a
            ``'model_state_dict'`` entry.
        input_path: Path of the video to read frames from.
        output_path: Path of the ``mp4v``-encoded video to write.
        conf_thresh: Confidence threshold forwarded to ``display_bboxes``.
        fps: Frame rate declared for the output video.

    Raises:
        IOError: If the input video cannot be opened.
    """
    # The writer expects frames of exactly this size; it is also the size the
    # frames are resized to before being fed to the network.
    frame_size = (624, 624)

    # Setup neural net
    net = ConeNet()

    # Load checkpoint. map_location remaps tensors that were saved from a
    # CUDA device, so no GPU is needed just to read the weights.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])

    # Analysis runs on the CPU.
    net.to('cpu')
    # Put layers such as dropout / batch norm into inference behaviour.
    net.eval()

    vidcap = cv2.VideoCapture(input_path)
    writer = None
    try:
        # Fail loudly instead of silently producing an empty output file.
        if not vidcap.isOpened():
            raise IOError(f'Could not open input video: {input_path}')

        writer = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc('m', 'p', '4', 'v'),
            fps,
            frame_size,
        )

        count = 0
        success, image = vidcap.read()
        while success:
            # Prep image: OpenCV delivers BGR, the PIL/torchvision path
            # below works on RGB.
            frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            im_pil = Image.fromarray(frame)
            frame = torchvision.transforms.functional.resize(im_pil, frame_size)
            tensor = torchvision.transforms.functional.pil_to_tensor(frame)
            tensor = tensor / 255
            # Add the leading batch dimension.
            tensor = torch.unsqueeze(tensor, 0)

            # Inference; gradients are never used here, so skip autograd
            # bookkeeping.
            with torch.no_grad():
                net_out = net(tensor)

            # Visualise
            bbox_image = display_bboxes(
                tensor[0], None, net_out[0], conf_thresh=conf_thresh
            )

            # Save: convert back to BGR for the OpenCV writer.
            image_out = cv2.cvtColor(np.array(bbox_image), cv2.COLOR_RGB2BGR)
            writer.write(image_out)

            # Log to console
            count += 1
            if count % 100 == 0:
                print(count)

            # Get next frame
            success, image = vidcap.read()
    finally:
        # Release input and output even if a frame fails part-way through,
        # so the output container is finalised and file handles are closed.
        vidcap.release()
        if writer is not None:
            writer.release()
# Run inference only when this file is executed as a script, so that
# importing the module does not start processing a video.
if __name__ == "__main__":
    infer()