forked from danielgordon10/re3-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathre3_tracker.py
261 lines (218 loc) · 11.7 KB
/
re3_tracker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import cv2
import glob
import numpy as np
import os
import tensorflow as tf
import time
import sys
import os.path
sys.path.append(os.path.abspath(os.path.join(
os.path.dirname(__file__), os.path.pardir)))
from tracker import network
from re3_utils.util import bb_util
from re3_utils.util import im_util
from re3_utils.tensorflow_util import tf_util
# Network Constants
from constants import CROP_SIZE
from constants import CROP_PAD
from constants import LSTM_SIZE
from constants import LOG_DIR
from constants import GPU_ID
from constants import MAX_TRACK_LENGTH
# When True, Re3Tracker.track / multi_track print current and mean tracking
# speed whenever the running forward-pass count is a multiple of 100.
SPEED_OUTPUT = True
class Re3Tracker(object):
    """Re3 recurrent tracker that runs a restored checkpoint in its own session.

    Per-track state lives in ``self.tracked_data`` keyed by ``unique_id`` as the
    tuple ``(lstmState, lastBox, lastImage, originalFeatures, forwardCount)``:

    * ``lstmState`` -- list of four arrays fed to ``self.prevLstmState``.
    * ``lastBox`` -- box returned for the previous frame of this track.
    * ``lastImage`` -- the previous frame; the first crop of each pair is taken
      from it.
    * ``originalFeatures`` -- LSTM state after the track's first forward pass,
      used to re-seed the state every ``MAX_TRACK_LENGTH`` passes.
    * ``forwardCount`` -- number of forward passes run for this track.
    """

    def __init__(self, gpu_id=GPU_ID):
        """Build the single-unroll inference graph and restore its weights.

        Args:
            gpu_id: Written to ``CUDA_VISIBLE_DEVICES`` before the session is
                created.

        Raises:
            IOError: If ``tf.train.get_checkpoint_state`` finds no checkpoint in
                ``os.path.join(dirname(__file__), '..', LOG_DIR, 'checkpoints')``.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        basedir = os.path.dirname(__file__)
        # The ops below are added to the current default graph.
        self.imagePlaceholder = tf.placeholder(tf.uint8, shape=(None, CROP_SIZE, CROP_SIZE, 3))
        self.prevLstmState = tuple(
            tf.placeholder(tf.float32, shape=(None, LSTM_SIZE)) for _ in range(4))
        self.batch_size = tf.placeholder(tf.int32, shape=())
        self.outputs, self.state1, self.state2 = network.inference(
            self.imagePlaceholder, num_unrolls=1, batch_size=self.batch_size, train=False,
            prevLstmState=self.prevLstmState)
        self.sess = tf_util.Session()
        self.sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(os.path.join(basedir, '..', LOG_DIR, 'checkpoints'))
        if ckpt is None:
            raise IOError(
                ('Checkpoint model could not be found. '
                 'Did you download the pretrained weights? '
                 'Download them here: http://bit.ly/2L5deYF and read the Model section of the Readme.'))
        tf_util.restore(self.sess, ckpt.model_checkpoint_path)
        self.tracked_data = {}
        # Accumulated tracking time in seconds (image reading excluded).
        self.time = 0
        # Starts at -1 so the very first forward pass is not added to self.time.
        self.total_forward_count = -1

    @staticmethod
    def _load_image(image):
        """Return an image array owned by the tracker.

        A ``str`` is treated as a path: it is read with ``cv2.imread`` and its
        channel axis reversed (OpenCV loads BGR). Anything else is assumed to be
        an image array and is copied, so later caller-side edits cannot alter
        the frame stored as a track's previous image.

        Raises:
            IOError: If ``image`` is a path that ``cv2.imread`` cannot read.
        """
        if isinstance(image, str):
            loaded = cv2.imread(image)
            if loaded is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise IOError('Could not read image %s' % (image,))
            return loaded[:, :, ::-1]
        return image.copy()

    def _initial_track_state(self, unique_id, image, starting_box):
        """Return the state tuple a track starts this frame from.

        A non-None ``starting_box`` (re)starts the track: zero LSTM state, the
        box itself, ``image`` as the previous frame, no original features and a
        forward count of 0. Otherwise the stored tuple for ``unique_id`` is
        returned.

        Raises:
            Exception: If ``unique_id`` is not tracked and no box is given.
        """
        if starting_box is not None:
            zero_state = [np.zeros((1, LSTM_SIZE)) for _ in range(4)]
            # np.array turns a list into an array and copies an array input.
            return zero_state, np.array(starting_box), image, None, 0
        if unique_id in self.tracked_data:
            return self.tracked_data[unique_id]
        raise Exception('Unique_id %s with no initial bounding box' % (unique_id,))

    def _reset_lstm_state(self, image, box, originalFeatures):
        """Re-seed a track's LSTM state from its first-pass features.

        Crops ``image`` around ``box``, feeds that crop as both inputs of a
        batch-of-one forward pass with ``originalFeatures`` as the previous
        state, and returns the new state as ``[s1[0], s1[1], s2[0], s2[1]]``.
        """
        croppedInput, _ = im_util.get_cropped_input(image, box, CROP_PAD, CROP_SIZE)
        pairedInput = np.tile(croppedInput[np.newaxis, ...], (2, 1, 1, 1))
        feed_dict = {
            self.imagePlaceholder: pairedInput,
            self.prevLstmState: originalFeatures,
            self.batch_size: 1,
        }
        _, s1, s2 = self.sess.run([self.outputs, self.state1, self.state2], feed_dict=feed_dict)
        return [s1[0], s1[1], s2[0], s2[1]]

    def track(self, unique_id, image, starting_box=None):
        """Advance one track by one frame and return its box.

        Args:
            unique_id: Key identifying the track in ``self.tracked_data``.
            image: Current frame, either a path (str) or an image array.
            starting_box: Optional 4-element box in X1, Y1, X2, Y2 format. When
                given, the track is (re)started and this box is returned for
                the frame instead of the network's prediction.

        Returns:
            The predicted box mapped to full-image coordinates by
            ``bb_util.from_crop_coordinate_system`` (or ``np.array(starting_box)``).
            The same object is stored as the track's previous box, so callers
            should copy it before mutating.

        Raises:
            IOError: If ``image`` is a path that cannot be read.
            Exception: If ``unique_id`` is unknown and no ``starting_box`` is given.
        """
        start_time = time.time()
        image = self._load_image(image)
        image_read_time = time.time() - start_time

        lstmState, pastBBox, prevImage, originalFeatures, forwardCount = \
            self._initial_track_state(unique_id, image, starting_box)

        # Both crops use the previous box: one from the previous frame, one from the current.
        croppedInput0, pastBBoxPadded = im_util.get_cropped_input(prevImage, pastBBox, CROP_PAD, CROP_SIZE)
        croppedInput1, _ = im_util.get_cropped_input(image, pastBBox, CROP_PAD, CROP_SIZE)
        feed_dict = {
            self.imagePlaceholder: [croppedInput0, croppedInput1],
            self.prevLstmState: lstmState,
            self.batch_size: 1,
        }
        rawOutput, s1, s2 = self.sess.run([self.outputs, self.state1, self.state2], feed_dict=feed_dict)
        lstmState = [s1[0], s1[1], s2[0], s2[1]]
        if forwardCount == 0:
            originalFeatures = list(lstmState)

        # Shift output box to full image coordinate system. The /10.0 presumably
        # undoes a training-time scaling of the regression targets -- confirm.
        outputBox = bb_util.from_crop_coordinate_system(rawOutput.squeeze() / 10.0, pastBBoxPadded, 1, 1)

        # Periodically re-seed the LSTM from the first-pass features.
        if forwardCount > 0 and forwardCount % MAX_TRACK_LENGTH == 0:
            lstmState = self._reset_lstm_state(image, outputBox, originalFeatures)

        forwardCount += 1
        self.total_forward_count += 1

        if starting_box is not None:
            # Use label if it's given
            outputBox = np.array(starting_box)

        self.tracked_data[unique_id] = (lstmState, outputBox, image, originalFeatures, forwardCount)

        end_time = time.time()
        tracking_time = end_time - start_time - image_read_time
        if self.total_forward_count > 0:
            self.time += tracking_time
        if SPEED_OUTPUT and self.total_forward_count % 100 == 0:
            # Clamp denominators: a coarse clock can report 0 elapsed seconds.
            print('Current tracking speed: %.3f FPS' % (1 / max(.00001, tracking_time)))
            print('Current image read speed: %.3f FPS' % (1 / max(.00001, image_read_time)))
            print('Mean tracking speed: %.3f FPS\n' % (self.total_forward_count / max(.00001, self.time)))
        return outputBox

    def multi_track(self, unique_ids, image, starting_boxes=None):
        """Advance several tracks on the same frame with one batched forward pass.

        Args:
            unique_ids: List of at least two track ids.
            image: Current frame, either a path (str) or an image array.
            starting_boxes: Optional dict mapping ``unique_id`` to a 4-element
                X1, Y1, X2, Y2 box. Only needed for new (or restarted) tracks;
                for those ids the given box is returned for this frame.

        Returns:
            ``(len(unique_ids), 4)`` float array of boxes, row ``i`` for
            ``unique_ids[i]``.

        Raises:
            AssertionError: If ``unique_ids`` is not a list of length > 1.
            IOError: If ``image`` is a path that cannot be read.
            Exception: If an id is unknown and has no starting box.
        """
        start_time = time.time()
        assert isinstance(unique_ids, list), 'unique_ids must be a list for multi_track'
        assert len(unique_ids) > 1, 'unique_ids must be at least 2 elements'
        image = self._load_image(image)
        image_read_time = time.time() - start_time

        if starting_boxes is None:
            starting_boxes = dict()

        # Gather the crop pair and previous LSTM state of every track into one batch.
        images = []
        lstmStates = [[] for _ in range(4)]
        pastBBoxesPadded = []
        for unique_id in unique_ids:
            starting_box = starting_boxes.get(unique_id)
            trackState = self._initial_track_state(unique_id, image, starting_box)
            if starting_box is not None:
                # Register new tracks now; the update loop below reads them back.
                self.tracked_data[unique_id] = trackState
            lstmState, pastBBox, prevImage, _, _ = trackState
            croppedInput0, pastBBoxPadded = im_util.get_cropped_input(prevImage, pastBBox, CROP_PAD, CROP_SIZE)
            croppedInput1, _ = im_util.get_cropped_input(image, pastBBox, CROP_PAD, CROP_SIZE)
            pastBBoxesPadded.append(pastBBoxPadded)
            images.extend([croppedInput0, croppedInput1])
            for ss, state in enumerate(lstmState):
                lstmStates[ss].append(state.squeeze())

        feed_dict = {
            self.imagePlaceholder: images,
            self.prevLstmState: [np.array(state) for state in lstmStates],
            # One crop pair per track; an int, unlike len(images) / 2 on Python 3.
            self.batch_size: len(unique_ids),
        }
        rawOutput, s1, s2 = self.sess.run([self.outputs, self.state1, self.state2], feed_dict=feed_dict)

        outputBoxes = np.zeros((len(unique_ids), 4))
        for uu, unique_id in enumerate(unique_ids):
            _, _, _, originalFeatures, forwardCount = self.tracked_data[unique_id]
            # [[uu], :] keeps the batch axis, so each state array is (1, LSTM_SIZE).
            lstmState = [s1[0][[uu], :], s1[1][[uu], :], s2[0][[uu], :], s2[1][[uu], :]]
            if forwardCount == 0:
                originalFeatures = list(lstmState)

            # Shift output box to full image coordinate system.
            outputBox = bb_util.from_crop_coordinate_system(
                rawOutput[uu, :].squeeze() / 10.0, pastBBoxesPadded[uu], 1, 1)

            # Periodically re-seed the LSTM from the first-pass features.
            if forwardCount > 0 and forwardCount % MAX_TRACK_LENGTH == 0:
                lstmState = self._reset_lstm_state(image, outputBox, originalFeatures)

            forwardCount += 1
            self.total_forward_count += 1

            starting_box = starting_boxes.get(unique_id)
            if starting_box is not None:
                # Use label if it's given
                outputBox = np.array(starting_box)
            outputBoxes[uu, :] = outputBox
            self.tracked_data[unique_id] = (lstmState, outputBox, image, originalFeatures, forwardCount)

        end_time = time.time()
        tracking_time = end_time - start_time - image_read_time
        if self.total_forward_count > 0:
            self.time += tracking_time
        if SPEED_OUTPUT and self.total_forward_count % 100 == 0:
            # Clamp denominators: a coarse clock can report 0 elapsed seconds.
            print('Current tracking speed per object: %.3f FPS' % (len(unique_ids) / max(.00001, tracking_time)))
            print('Current tracking speed per frame: %.3f FPS' % (1 / max(.00001, tracking_time)))
            print('Current image read speed: %.3f FPS' % (1 / max(.00001, image_read_time)))
            print('Mean tracking speed per object: %.3f FPS\n' % (self.total_forward_count / max(.00001, self.time)))
        return outputBoxes
class CopiedRe3Tracker(Re3Tracker):
    """Re3 tracker that runs in a caller-supplied session with copied weights.

    It builds its own inference ops under the ``'test_network'`` variable scope.
    It creates a sync op that assigns ``copy_vars`` into that scope's trainable
    variables; ``reset`` runs the op. Tracking itself reuses
    ``Re3Tracker.track`` / ``multi_track``.
    """

    def __init__(self, sess, copy_vars, gpu=None):
        """Build the network copy and its weight-sync op.

        ``Re3Tracker.__init__`` is not called, because it would create its own
        session and restore a checkpoint.

        Args:
            sess: Session used for every later ``run`` call.
            copy_vars: Source variables, paired by position with this copy's
                trainable variables (see ``sync_from``).
            gpu: Optional GPU index. When given, the network ops are created
                under ``tf.device('/gpu:<gpu>')``.
        """
        self.sess = sess
        self.imagePlaceholder = tf.placeholder(tf.uint8, shape=(None, CROP_SIZE, CROP_SIZE, 3))
        self.prevLstmState = tuple(
            tf.placeholder(tf.float32, shape=(None, LSTM_SIZE)) for _ in range(4))
        self.batch_size = tf.placeholder(tf.int32, shape=())
        network_scope = 'test_network'
        if gpu is not None:
            with tf.device('/gpu:' + str(gpu)):
                self._build_network(network_scope)
        else:
            self._build_network(network_scope)
        local_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=network_scope)
        self.sync_op = self.sync_from(copy_vars, local_vars)
        self.tracked_data = {}
        self.time = 0
        # Starts at -1 so the very first forward pass is not added to self.time.
        self.total_forward_count = -1

    def _build_network(self, network_scope):
        """Create the single-unroll inference ops under ``network_scope`` and store their outputs."""
        with tf.variable_scope(network_scope):
            self.outputs, self.state1, self.state2 = network.inference(
                self.imagePlaceholder, num_unrolls=1, batch_size=self.batch_size, train=False,
                prevLstmState=self.prevLstmState)

    def reset(self):
        """Forget all tracks and copy the current source weights into this network."""
        self.tracked_data = {}
        self.sess.run(self.sync_op)

    def sync_from(self, src_vars, dst_vars):
        """Return one grouped op that assigns ``src_vars[i]`` into ``dst_vars[i]``.

        The two lists are paired with ``zip``, so they must list matching
        variables in the same order. If the lengths differ, the extra trailing
        variables of the longer list are silently left out
        (NOTE(review): a mismatch likely indicates a caller bug -- confirm
        whether it should raise).
        """
        with tf.name_scope('Sync'):
            sync_ops = [tf.assign(dst_var, src_var) for src_var, dst_var in zip(src_vars, dst_vars)]
            return tf.group(*sync_ops)