Skip to content

Commit

Permalink
added test dataset support
Browse files Browse the repository at this point in the history
  • Loading branch information
rostyslavhereha committed Feb 1, 2024
1 parent 613e70f commit dc7c991
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 7 deletions.
6 changes: 4 additions & 2 deletions auto_training/config_factories/pvt_config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def make_pvt_cfg(data_path: str,
cfg.train_img_prefix = f"{cfg.data_root}/train/image_2"
cfg.val_ann_file = f"{cfg.data_root}/val/coco_val.json"
cfg.val_img_prefix = f"{cfg.data_root}/val/image_2/"
cfg.test_ann_file = f"{cfg.data_root}/test/coco_test.json"
cfg.test_img_prefix = f"{cfg.data_root}/test/image_2/"

min_res = (input_res[0], input_res[1] * 0.8)
training_classes = parse_training_data_classes(cfg.train_ann_file)
Expand Down Expand Up @@ -127,8 +129,8 @@ def make_pvt_cfg(data_path: str,
]),
test=dict(
type=cfg.dataset_type,
ann_file=cfg.val_ann_file,
img_prefix=cfg.val_img_prefix,
ann_file=cfg.test_ann_file,
img_prefix=cfg.test_img_prefix,
pipeline=[
dict(type='LoadImageFromFile'),
dict(
Expand Down
8 changes: 5 additions & 3 deletions auto_training/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,25 @@ def prepare_folder(image_path, coco_path, coco, mode):
json.dump(coco, fp)


def make_coco_folder(cocos, coco_path, train_image_path, val_image_path):
def make_coco_folder(cocos, coco_path, train_image_path, val_image_path, test_image_path):
prepare_folder(train_image_path, coco_path, cocos[0], mode="train")
prepare_folder(val_image_path, coco_path, cocos[1], mode="val")
prepare_folder(test_image_path, coco_path, cocos[2], mode="test")



def main():
args = parse_args()
target_class_map = json.loads(args.target_class_map)
cocos = convert_kitti_files(args.kitti_train, args.kitti_val, target_class_map)
make_coco_folder(cocos, args.coco_folder, args.kitti_train, args.kitti_val)
cocos = convert_kitti_files(args.kitti_train, args.kitti_val, args.kitti_test, target_class_map)
make_coco_folder(cocos, args.coco_folder, args.kitti_train, args.kitti_val, args.kitti_test)


def parse_args():
parser = argparse.ArgumentParser(description='Convert kitti to coco dataset')
parser.add_argument('kitti_train', type=str, help='train input data path, kitti dataset')
parser.add_argument('kitti_val', type=str, help='val input data path, kitti dataset')
parser.add_argument('kitti_test', type=str, help='test input data path, kitti dataset')
parser.add_argument('coco_folder', type=str, help='output folder')
parser.add_argument('--target-class-map', type=str, default="{}", help='target class mapping, json strin format. Map to None if class should not be used.')
return parser.parse_args()
Expand Down
4 changes: 2 additions & 2 deletions auto_training/utils/kitti_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ def read_files_make_dict(path):
path_dict = {os.path.basename(file): file for file in files}
return files, path_dict

def convert_kitti_files(train_path, val_path, target_class_map={}):
def convert_kitti_files(train_path, val_path, test_path, target_class_map={}):
image_ids = 0
category_ids = 0
annot_ids = 0
categories = {}
cocos = []
for path in train_path, val_path:
for path in train_path, val_path, test_path:
image_dict = {}
annot_dict = {}
files, path_dict = read_files_make_dict(path)
Expand Down
5 changes: 5 additions & 0 deletions custom/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ def __init__(

def run(self, threshold=0.5):
results = []
num_images = len(os.listdir(self.inf_dir))
counter = 0
for imn in os.listdir(self.inf_dir):
img = f"{self.inf_dir}/{imn}"
if counter % 100 == 0:
print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - mmdet - INFO - Predict({counter}/{num_images})")
self.fps_logger.start_record()
result = inference_detector(self.model, img)
results.append({img: extract_bounding_boxes(result, threshold)})
Expand All @@ -74,6 +78,7 @@ def run(self, threshold=0.5):
out_file=f"{self.inf_out_dir}/{imn}",
score_thr=threshold,
)
counter += 1
if self.output_file:
with open(self.output_file, "w") as f:
json.dump(results, f)
Expand Down

0 comments on commit dc7c991

Please sign in to comment.