diff --git a/.gitignore b/.gitignore index c9b568f..ff1b06b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ *.pyc *.swp + +#folder +*/__pycache__ +data \ No newline at end of file diff --git a/README.md b/README.md index 53ac19d..4c26fd7 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +* Add *: +`Modify the code to make it run sucessfully on Windows10 (python3.5+,tenforflow,opencv3)` + ## YOLO_tensorflow Tensorflow implementation of [YOLO](https://arxiv.org/pdf/1506.02640.pdf), including training and test phase. diff --git a/test.py b/test.py index 0c791fb..420ff84 100644 --- a/test.py +++ b/test.py @@ -27,7 +27,7 @@ def __init__(self, net, weight_file): self.sess = tf.Session() self.sess.run(tf.global_variables_initializer()) - print 'Restoring weights from: ' + self.weights_file + print ('Restoring weights from: ' + self.weights_file) self.saver = tf.train.Saver() self.saver.restore(self.sess, self.weights_file) @@ -40,7 +40,7 @@ def draw_result(self, img, result): cv2.rectangle(img, (x - w, y - h), (x + w, y + h), (0, 255, 0), 2) cv2.rectangle(img, (x - w, y - h - 20), (x + w, y - h), (125, 125, 125), -1) - cv2.putText(img, result[i][0] + ' : %.2f' % result[i][5], (x - w + 5, y - h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.CV_AA) + cv2.putText(img, result[i][0] + ' : %.2f' % result[i][5], (x - w + 5, y - h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA) def detect(self, img): img_h, img_w, _ = img.shape @@ -169,6 +169,7 @@ def main(): parser.add_argument('--weight_dir', default='weights', type=str) parser.add_argument('--data_dir', default="data", type=str) parser.add_argument('--gpu', default='', type=str) + parser.add_argument('--file_name',default="person.jpg",type=str) args = parser.parse_args() os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu @@ -182,7 +183,7 @@ def main(): # detector.camera_detector(cap) # detect from image file - imname = 'test/person.jpg' + imname = 'test/'+args.file_name detector.image_detector(imname) diff --git a/train.py b/train.py index 8777b60..d458da7 100644 --- a/train.py +++ b/train.py @@ -63,7 +63,7 @@ def train(self): train_timer = Timer() load_timer = Timer() - for step in xrange(1, self.max_iter + 1): + for step in range(1, self.max_iter + 1): load_timer.tic() images, labels = self.data.get() diff --git a/utils/pascal_voc.py b/utils/pascal_voc.py index ed105bc..de24cc4 100644 --- a/utils/pascal_voc.py +++ b/utils/pascal_voc.py @@ -2,7 +2,7 @@ import xml.etree.ElementTree as ET import numpy as np import cv2 -import cPickle +import pickle import copy import yolo.config as cfg @@ -16,7 +16,7 @@ def __init__(self, phase, rebuild=False): self.image_size = cfg.IMAGE_SIZE self.cell_size = cfg.CELL_SIZE self.classes = cfg.CLASSES - self.class_to_ind = dict(zip(self.classes, xrange(len(self.classes)))) + self.class_to_ind = dict(zip(self.classes, range(len(self.classes)))) self.flipped = cfg.FLIPPED self.phase = phase self.rebuild = rebuild @@ -59,8 +59,8 @@ def prepare(self): for idx in range(len(gt_labels_cp)): gt_labels_cp[idx]['flipped'] = True gt_labels_cp[idx]['label'] = gt_labels_cp[idx]['label'][:, ::-1, :] - for i in xrange(self.cell_size): - for j in xrange(self.cell_size): + for i in range(self.cell_size): + for j in range(self.cell_size): if gt_labels_cp[idx]['label'][i, j, 0] == 1: gt_labels_cp[idx]['label'][i, j, 1] = self.image_size - 1 - gt_labels_cp[idx]['label'][i, j, 1] gt_labels += gt_labels_cp @@ -74,7 +74,7 @@ def load_labels(self): if os.path.isfile(cache_file) and not self.rebuild: print('Loading gt_labels from: ' + cache_file) with open(cache_file, 'rb') as f: - gt_labels = cPickle.load(f) + gt_labels = pickle.load(f) return gt_labels print('Processing gt_labels from: ' + self.data_path) @@ -100,7 +100,7 @@ def load_labels(self): gt_labels.append({'imname': imname, 'label': label, 'flipped': False}) print('Saving gt_labels to: ' + cache_file) with open(cache_file, 'wb') as f: - cPickle.dump(gt_labels, f) + pickle.dump(gt_labels, f) return gt_labels def load_pascal_annotation(self, index): diff --git a/yolo/config.py b/yolo/config.py index 98fd7ff..c348654 100644 --- a/yolo/config.py +++ b/yolo/config.py @@ -14,8 +14,8 @@ WEIGHTS_DIR = os.path.join(PASCAL_PATH, 'weight') -WEIGHTS_FILE = None -# WEIGHTS_FILE = os.path.join(DATA_PATH, 'weights', 'YOLO_small.ckpt') +#WEIGHTS_FILE = None +WEIGHTS_FILE = os.path.join(DATA_PATH, 'weights', 'YOLO_small.ckpt') CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', @@ -59,7 +59,7 @@ STAIRCASE = True -BATCH_SIZE = 45 +BATCH_SIZE = 4 MAX_ITER = 15000