From cf10c91ded43e6d30a3969e318726ab8d0c48f27 Mon Sep 17 00:00:00 2001 From: zhe chen Date: Sun, 5 Jun 2022 13:13:55 +0800 Subject: [PATCH] single image inference --- segmentation/image_demo.py | 58 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 segmentation/image_demo.py diff --git a/segmentation/image_demo.py b/segmentation/image_demo.py new file mode 100644 index 000000000..7df0cbbd4 --- /dev/null +++ b/segmentation/image_demo.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser + +import mmcv + +import mmcv_custom # noqa: F401,F403 +import mmseg_custom # noqa: F401,F403 +from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot +from mmseg.core.evaluation import get_palette +from mmcv.runner import load_checkpoint +from mmseg.core import get_classes +import cv2 +import os.path as osp + + +def main(): + parser = ArgumentParser() + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument('img', help='Image file') + parser.add_argument('--out', type=str, default="demo", help='out dir') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--palette', + default='cityscapes', + help='Color palette used for segmentation map') + parser.add_argument( + '--opacity', + type=float, + default=0.5, + help='Opacity of painted segmentation map. In (0, 1] range.') + args = parser.parse_args() + + # build the model from a config file and a checkpoint file + + model = init_segmentor(args.config, checkpoint=None, device=args.device) + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') + if 'CLASSES' in checkpoint.get('meta', {}): + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + model.CLASSES = get_classes(args.palette) + + # test a single image + result = inference_segmentor(model, args.img) + # show the results + if hasattr(model, 'module'): + model = model.module + img = model.show_result(args.img, result, + palette=get_palette(args.palette), + show=False, opacity=args.opacity) + mmcv.mkdir_or_exist(args.out) + out_path = osp.join(args.out, osp.basename(args.img)) + cv2.imwrite(out_path, img) + print(f"Result is save at {out_path}") + +if __name__ == '__main__': + main() \ No newline at end of file