From 01439d1f2eab87e19f39cb2cb2a5b5c4afa3f6b6 Mon Sep 17 00:00:00 2001 From: sltlls <50494500+sltlls@users.noreply.github.com> Date: Wed, 22 Jun 2022 20:34:28 +0800 Subject: [PATCH] Fix inference issue with large image (#368) --- mmrotate/apis/inference.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/mmrotate/apis/inference.py b/mmrotate/apis/inference.py index fabaa8c0c..ce1cc5e65 100644 --- a/mmrotate/apis/inference.py +++ b/mmrotate/apis/inference.py @@ -49,19 +49,20 @@ def inference_detector_by_patches(model, sizes, steps = get_multiscale_patch(sizes, steps, ratios) windows = slide_window(width, height, sizes, steps) - # prepare patch data - patch_datas = [] - for window in windows: - data = dict(img=img, win=window.tolist()) - # build the data pipeline - data = test_pipeline(data) - patch_datas.append(data) - results = [] start = 0 while True: - data = patch_datas[start:start + bs] - data = collate(data, samples_per_gpu=len(data)) + # prepare patch data + patch_datas = [] + if (start + bs) > len(windows): + end = len(windows) + else: + end = start + bs + for window in windows[start:end]: + data = dict(img=img, win=window.tolist()) + data = test_pipeline(data) + patch_datas.append(data) + data = collate(patch_datas, samples_per_gpu=len(patch_datas)) # just get the actual data from DataContainer data['img_metas'] = [ img_metas.data[0] for img_metas in data['img_metas'] @@ -80,7 +81,7 @@ def inference_detector_by_patches(model, with torch.no_grad(): results.extend(model(return_loss=False, rescale=True, **data)) - if start + bs >= len(patch_datas): + if end >= len(windows): break start += bs