Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test on series of images #113

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions lib/tracker/siamrpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ def init(self, im, target_pos, target_sz, model, hp=None):
p.renew()

# for vot17 or vot18: from siamrpn released
if '2017' in self.info.dataset:
if ((target_sz[0] * target_sz[1]) / float(state['im_h'] * state['im_w'])) < 0.004:
p.instance_size = 287
p.renew()
else:
p.instance_size = 271
p.renew()
if self.info.dataset:
if '2017' in self.info.dataset:
if ((target_sz[0] * target_sz[1]) / float(state['im_h'] * state['im_w'])) < 0.004:
p.instance_size = 287
p.renew()
else:
p.instance_size = 271
p.renew()

# param tune
if hp:
Expand Down
6 changes: 6 additions & 0 deletions lib/tutorials/test.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ python siamese_tracking/run_video.py --arch SiamRPNRes22 --resume snapshot/CIRes
- The opencv version here is 4.1.0.25, and older versions may be not friendly to some functions.
- If you try to conduct this project on a specific tracking task, eg. pedestrian tracking, it's suggested that you can tuning hyper-parameters on your collected data with our tuning toolkit detailed below.

## Test on images dir

```python
python siamese_tracking/run_video.py --arch SiamRPNRes22 --resume snapshot/CIResNet22_RPN.pth --images path/to/jpg/images
```

## Test through webcam
eg,
```
Expand Down
14 changes: 10 additions & 4 deletions lib/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,19 +226,25 @@ def remove_prefix(state_dict, prefix):
return {f(key): value for key, value in state_dict.items()}


def load_pretrain(model, pretrained_path):
def load_pretrain(model, pretrained_path, to_cpu=False):
print('load pretrained model from {}'.format(pretrained_path))

device = torch.cuda.current_device()
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
if not to_cpu:
device = torch.cuda.current_device()
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
else:
pretrained_dict = torch.load(pretrained_path, map_location='cpu')

if "state_dict" in pretrained_dict.keys():
pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
else:
pretrained_dict = remove_prefix(pretrained_dict, 'module.')
check_keys(model, pretrained_dict)
model.load_state_dict(pretrained_dict, strict=False)
return model

if not to_cpu:
return model
return model.to('cpu')


Corner = namedtuple('Corner', 'x1 y1 x2 y2')
Expand Down
105 changes: 100 additions & 5 deletions siamese_tracking/run_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Email: [email protected]
# Detail: test siamese on a specific video (provide init bbox and video file)
# ------------------------------------------------------------------------------

import time
import _init_paths
import os
import cv2
Expand All @@ -22,15 +22,15 @@
from easydict import EasyDict as edict
from utils.utils import load_pretrain, cxy_wh_2_rect, get_axis_aligned_bbox, load_dataset, poly_iou


def parse_args():
"""
args for fc testing.
"""
parser = argparse.ArgumentParser(description='PyTorch SiamFC Tracking Test')
parser.add_argument('--arch', default='SiamRPNRes22', type=str, help='backbone architecture')
parser.add_argument('--resume', default='/data/zpzhang/project4/siamese/Siamese/snapshot/CIResNet22RPN.model', type=str, help='pretrained model')
parser.add_argument('--video', default='/data/zpzhang/project4/siamese/Siamese/videos/bag.mp4', type=str, help='video file path')
parser.add_argument('--video', default=None, help='video file path')
parser.add_argument('--images', default=None, help='images directory')
parser.add_argument('--init_bbox', default=None, help='bbox in the first frame None or [lx, ly, w, h]')
args = parser.parse_args()

Expand All @@ -46,7 +46,7 @@ def track_video(tracker, model, video_path, init_box=None):
cv2.namedWindow(display_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
cv2.resizeWindow(display_name, 960, 720)
success, frame = cap.read()
cv2.imshow(display_name, frame)
#cv2.imshow(display_name, frame)

if success is not True:
print("Read failed.")
Expand Down Expand Up @@ -121,6 +121,96 @@ def track_video(tracker, model, video_path, init_box=None):
cv2.destroyAllWindows()



def track_images(tracker, model, images_path, init_box=None):

assert os.path.isdir(images_path), "please provide a valid folder name"

display_name = 'Video: {}'.format(images_path.split('/')[-1])
cv2.namedWindow(display_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
cv2.resizeWindow(display_name, 960, 720)

im_paths = [(images_path + '/' + f) for f in os.listdir(images_path) if '.jpg' in f]
if len(im_paths) == 0:
print("no jpg images found in dir")
exit(-1)

frame = cv2.imread(im_paths[0])
#cv2.imshow(im_paths[0].split('/')[-1], frame)


# init
if init_box is not None:
lx, ly, w, h = init_box
target_pos = np.array([lx + w/2, ly + h/2])
target_sz = np.array([w, h])
state = tracker.init(frame, target_pos, target_sz, model) # init tracker

else:
while True:

frame_disp = frame.copy()

cv2.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL,
1, (0, 0, 255), 1)

lx, ly, w, h = cv2.selectROI(display_name, frame_disp, fromCenter=False)
target_pos = np.array([lx + w / 2, ly + h / 2])
target_sz = np.array([w, h])
state = tracker.init(frame_disp, target_pos, target_sz, model) # init tracker

break

path_idx = 0
while path_idx < len(im_paths):
time.sleep(2)
path_idx += 1
frame = cv2.imread(im_paths[path_idx])

if frame is None:
return

frame_disp = frame.copy()

# Draw box
state = tracker.track(state, frame_disp) # track
location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
x1, y1, x2, y2 = int(location[0]), int(location[1]), int(location[0] + location[2]), int(location[1] + location[3])

cv2.rectangle(frame_disp, (x1, y1), (x2, y2), (0, 255, 0), 5)

font_color = (0, 0, 0)
cv2.putText(frame_disp, 'Tracking!', (20, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1,
font_color, 1)
cv2.putText(frame_disp, 'Press r to reset', (20, 55), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1,
font_color, 1)
cv2.putText(frame_disp, 'Press q to quit', (20, 80), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1,
font_color, 1)

# Display the resulting frame
cv2.imshow(display_name, frame_disp)
key = cv2.waitKey(1)
if key == ord('q'):
break
elif key == ord('r'):
path_idx += 1
frame = cv2.imread(im_paths[path_idx])
frame_disp = frame.copy()

cv2.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL,
1.5,
(0, 0, 0), 1)

cv2.imshow(display_name, frame_disp)
lx, ly, w, h = cv2.selectROI(display_name, frame_disp, fromCenter=False)
target_pos = np.array([lx + w / 2, ly + h / 2])
target_sz = np.array([w, h])
state = tracker.init(frame_disp, target_pos, target_sz, model)

# When everything done, release the capture
cv2.destroyAllWindows()


def main():
args = parse_args()

Expand Down Expand Up @@ -152,7 +242,12 @@ def main():
else:
pass

track_video(tracker, net, args.video, init_box=args.init_bbox)
if args.video is not None:
track_video(tracker, net, args.video, init_box=args.init_bbox)
elif args.images is not None:
track_images(tracker, net, args.images, init_box=args.init_bbox)
else:
print('Please give path to image dir or video file')

if __name__ == '__main__':
main()