diff --git a/lib/tracker/siamrpn.py b/lib/tracker/siamrpn.py index a2f21e8..64f570e 100644 --- a/lib/tracker/siamrpn.py +++ b/lib/tracker/siamrpn.py @@ -36,13 +36,14 @@ def init(self, im, target_pos, target_sz, model, hp=None): p.renew() # for vot17 or vot18: from siamrpn released - if '2017' in self.info.dataset: - if ((target_sz[0] * target_sz[1]) / float(state['im_h'] * state['im_w'])) < 0.004: - p.instance_size = 287 - p.renew() - else: - p.instance_size = 271 - p.renew() + if self.info.dataset: + if '2017' in self.info.dataset: + if ((target_sz[0] * target_sz[1]) / float(state['im_h'] * state['im_w'])) < 0.004: + p.instance_size = 287 + p.renew() + else: + p.instance_size = 271 + p.renew() # param tune if hp: diff --git a/lib/tutorials/test.md b/lib/tutorials/test.md index 13249ff..d5027df 100644 --- a/lib/tutorials/test.md +++ b/lib/tutorials/test.md @@ -11,6 +11,12 @@ python siamese_tracking/run_video.py --arch SiamRPNRes22 --resume snapshot/CIRes - The opencv version here is 4.1.0.25, and older versions may be not friendly to some functions. - If you try to conduct this project on a specific tracking task, eg. pedestrian tracking, it's suggested that you can tuning hyper-parameters on your collected data with our tuning toolkit detailed below. 
+## Test on images dir + +```bash +python siamese_tracking/run_video.py --arch SiamRPNRes22 --resume snapshot/CIResNet22_RPN.pth --images path/to/jpg/images +``` + ## Test through webcam eg, ``` diff --git a/lib/utils/utils.py b/lib/utils/utils.py index 15844c8..54def52 100644 --- a/lib/utils/utils.py +++ b/lib/utils/utils.py @@ -226,11 +226,14 @@ def remove_prefix(state_dict, prefix): return {f(key): value for key, value in state_dict.items()} -def load_pretrain(model, pretrained_path): +def load_pretrain(model, pretrained_path, to_cpu=False): print('load pretrained model from {}'.format(pretrained_path)) - device = torch.cuda.current_device() - pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) + if not to_cpu: + device = torch.cuda.current_device() + pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) + else: + pretrained_dict = torch.load(pretrained_path, map_location='cpu') if "state_dict" in pretrained_dict.keys(): pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') @@ -238,7 +241,10 @@ def load_pretrain(model, pretrained_path): pretrained_dict = remove_prefix(pretrained_dict, 'module.') check_keys(model, pretrained_dict) model.load_state_dict(pretrained_dict, strict=False) - return model + + if not to_cpu: + return model + return model.to('cpu') Corner = namedtuple('Corner', 'x1 y1 x2 y2') diff --git a/siamese_tracking/run_video.py b/siamese_tracking/run_video.py index 7ced81d..557e8c3 100644 --- a/siamese_tracking/run_video.py +++ b/siamese_tracking/run_video.py @@ -5,7 +5,7 @@ # Email: zhangzhipeng2017@ia.ac.cn # Detail: test siamese on a specific video (provide init bbox and video file) # ------------------------------------------------------------------------------ - +import time import _init_paths import os import cv2 @@ -22,7 +22,6 @@ from easydict import EasyDict as edict from utils.utils import load_pretrain, cxy_wh_2_rect, 
get_axis_aligned_bbox, load_dataset, poly_iou - def parse_args(): """ args for fc testing. @@ -30,7 +29,8 @@ def parse_args(): parser = argparse.ArgumentParser(description='PyTorch SiamFC Tracking Test') parser.add_argument('--arch', default='SiamRPNRes22', type=str, help='backbone architecture') parser.add_argument('--resume', default='/data/zpzhang/project4/siamese/Siamese/snapshot/CIResNet22RPN.model', type=str, help='pretrained model') - parser.add_argument('--video', default='/data/zpzhang/project4/siamese/Siamese/videos/bag.mp4', type=str, help='video file path') + parser.add_argument('--video', default=None, help='video file path') + parser.add_argument('--images', default=None, help='images directory') parser.add_argument('--init_bbox', default=None, help='bbox in the first frame None or [lx, ly, w, h]') args = parser.parse_args() @@ -46,7 +46,7 @@ def track_video(tracker, model, video_path, init_box=None): cv2.namedWindow(display_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) cv2.resizeWindow(display_name, 960, 720) success, frame = cap.read() - cv2.imshow(display_name, frame) + #cv2.imshow(display_name, frame) if success is not True: print("Read failed.") @@ -121,6 +121,96 @@ def track_video(tracker, model, video_path, init_box=None): cv2.destroyAllWindows() + +def track_images(tracker, model, images_path, init_box=None): + + assert os.path.isdir(images_path), "please provide a valid folder name" + + display_name = 'Video: {}'.format(images_path.split('/')[-1]) + cv2.namedWindow(display_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) + cv2.resizeWindow(display_name, 960, 720) + + im_paths = [(images_path + '/' + f) for f in os.listdir(images_path) if '.jpg' in f] + if len(im_paths) == 0: + print("no jpg images found in dir") + exit(-1) + + frame = cv2.imread(im_paths[0]) + #cv2.imshow(im_paths[0].split('/')[-1], frame) + + + # init + if init_box is not None: + lx, ly, w, h = init_box + target_pos = np.array([lx + w/2, ly + h/2]) + target_sz = 
np.array([w, h]) + state = tracker.init(frame, target_pos, target_sz, model) # init tracker + + else: + while True: + + frame_disp = frame.copy() + + cv2.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL, + 1, (0, 0, 255), 1) + + lx, ly, w, h = cv2.selectROI(display_name, frame_disp, fromCenter=False) + target_pos = np.array([lx + w / 2, ly + h / 2]) + target_sz = np.array([w, h]) + state = tracker.init(frame_disp, target_pos, target_sz, model) # init tracker + + break + + path_idx = 0 + while path_idx + 1 < len(im_paths): + time.sleep(2) + path_idx += 1 + frame = cv2.imread(im_paths[path_idx]) + + if frame is None: + return + + frame_disp = frame.copy() + + # Draw box + state = tracker.track(state, frame_disp) # track + location = cxy_wh_2_rect(state['target_pos'], state['target_sz']) + x1, y1, x2, y2 = int(location[0]), int(location[1]), int(location[0] + location[2]), int(location[1] + location[3]) + + cv2.rectangle(frame_disp, (x1, y1), (x2, y2), (0, 255, 0), 5) + + font_color = (0, 0, 0) + cv2.putText(frame_disp, 'Tracking!', (20, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, + font_color, 1) + cv2.putText(frame_disp, 'Press r to reset', (20, 55), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, + font_color, 1) + cv2.putText(frame_disp, 'Press q to quit', (20, 80), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, + font_color, 1) + + # Display the resulting frame + cv2.imshow(display_name, frame_disp) + key = cv2.waitKey(1) + if key == ord('q'): + break + elif key == ord('r'): + path_idx += 1 + frame = cv2.imread(im_paths[path_idx]) + frame_disp = frame.copy() + + cv2.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL, + 1.5, + (0, 0, 0), 1) + + cv2.imshow(display_name, frame_disp) + lx, ly, w, h = cv2.selectROI(display_name, frame_disp, fromCenter=False) + target_pos = np.array([lx + w / 2, ly + h / 2]) + target_sz = np.array([w, h]) + state = tracker.init(frame_disp, target_pos, target_sz, model) + + 
# When everything done, release the capture + cv2.destroyAllWindows() + + def main(): args = parse_args() @@ -152,7 +242,12 @@ def main(): else: pass - track_video(tracker, net, args.video, init_box=args.init_bbox) + if args.video is not None: + track_video(tracker, net, args.video, init_box=args.init_bbox) + elif args.images is not None: + track_images(tracker, net, args.images, init_box=args.init_bbox) + else: + print('Please give path to image dir or video file') if __name__ == '__main__': main()