diff --git a/README.md b/README.md index 92c5604..6d8333c 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,11 @@ Step by step: 1. Open the `main/` directory 2. Insert the input images and videos in the folder **input/** - 3. Insert the classes in the file **class_list.txt** (one class name per line) + 3. Edit config.in setting the full filepath to your class_list (WITHOUT quotation marks). E.g. + [CLASSES] + MOST_RECENT_FILE = /home/user1/Desktop/class_list.txt + + Alternatively, just edit the contents of the example file ./main/class_list.txt 4. Run the code: 5. You can find the annotations in the folder **output/** diff --git a/main/__init__.py b/main/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/main/config.ini b/main/config.ini index fb305f9..a9b54ce 100644 --- a/main/config.ini +++ b/main/config.ini @@ -2,3 +2,6 @@ OBJECT_SCORE_THRESHOLD = 0.65 OBJECT_IDS = 1,2 CUDA_VISIBLE_DEVICES = '' + +[CLASSES] +MOST_RECENT_FILE = diff --git a/main/load_classes.py b/main/load_classes.py new file mode 100644 index 0000000..b66536a --- /dev/null +++ b/main/load_classes.py @@ -0,0 +1,32 @@ +import configparser + +from pathlib import Path + + +def non_blank_lines(file_object): + for l in file_object: + line = l.rstrip() + if line: + yield line + + +def get_class_list(): + """ + Uses the most recent classes source file defined in config.ini + otherwise defaults the example class_list.txt file provided in + this repository. + + """ + config = configparser.ConfigParser() + this_dir = Path(__file__).parent + config_path = this_dir / "config.ini" + config.read(str(config_path)) + most_recent_classes_file = config.get("CLASSES", "MOST_RECENT_FILE") + + if not most_recent_classes_file: + classes_src = str(this_dir / "class_list.txt") + else: + classes_src = most_recent_classes_file + + with open(classes_src) as f: + return list(non_blank_lines(file_object=f)) diff --git a/main/main.py b/main/main.py index 0b76409..05da04c 100755 --- a/main/main.py +++ b/main/main.py @@ -1,6 +1,5 @@ #!/bin/python import argparse -import glob import json import os import re @@ -12,7 +11,10 @@ from lxml import etree import xml.etree.cElementTree as ET +from load_classes import get_class_list + +CLASS_LIST = get_class_list() DELAY = 20 # keyboard delay (in milliseconds) WITH_QT = False try: @@ -718,13 +720,6 @@ def convert_video_to_images(video_path, n_frames, desired_img_format): return file_path, video_name_ext -def nonblank_lines(f): - for l in f: - line = l.rstrip() - if line: - yield line - - def get_annotation_paths(img_path, annotation_formats): annotation_paths = [] for ann_dir, ann_ext in annotation_formats.items(): @@ -1032,10 +1027,6 @@ def complement_bgr(color): elif '.xml' in ann_path: create_PASCAL_VOC_xml(ann_path, abs_path, folder_name, image_name, img_height, img_width, depth) - # load class list - with open('class_list.txt') as f: - CLASS_LIST = list(nonblank_lines(f)) - #print(CLASS_LIST) last_class_index = len(CLASS_LIST) - 1 # Make the class colors the same each session diff --git a/main/main_auto.py b/main/main_auto.py index b32e09f..b8be738 100755 --- a/main/main_auto.py +++ b/main/main_auto.py @@ -26,6 +26,12 @@ from lxml import etree import xml.etree.cElementTree as ET import sys + +from load_classes import get_class_list + + +CLASS_LIST = get_class_list() + sys.path.insert(0, "..") from object_detection.tf_object_detection import ObjectDetector import configparser @@ -313,13 +319,6 @@ def convert_video_to_images(video_path, n_frames, desired_img_format): return file_path, video_name_ext -def nonblank_lines(f): - for l in f: - line = l.rstrip() - if line: - yield line - - def get_annotation_paths(img_path, annotation_formats): annotation_paths = [] for ann_dir, ann_ext in annotation_formats.items(): @@ -695,10 +694,6 @@ def predict_next_frames(self,json_file_data,json_file_path): elif '.xml' in ann_path: create_PASCAL_VOC_xml(ann_path, abs_path, folder_name, image_name, img_height, img_width, depth) -# load class list -with open('class_list.txt') as f: - CLASS_LIST = list(nonblank_lines(f)) -#print(CLASS_LIST) last_class_index = len(CLASS_LIST) - 1 # Make the class colors the same each session diff --git a/requirements.txt b/requirements.txt index db49ed2..8286b81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -lxml==4.3.0 -numpy==1.16.0 +lxml==4.6.3 +numpy==1.21.2 opencv-contrib-python==3.4.9.33 tqdm==4.29.1 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_load_classes.py b/tests/test_load_classes.py new file mode 100644 index 0000000..9df76d7 --- /dev/null +++ b/tests/test_load_classes.py @@ -0,0 +1,6 @@ +from main.load_classes import get_class_list + + +def test_get_class_list(): + classes = get_class_list() + assert classes == ['person', 'billiard ball', 'donut']