Skip to content

Commit

Permalink
[UnitTest] Detector unittest (open-mmlab#669)
Browse files Browse the repository at this point in the history
* resolve comments

* update changelog

* add unittest for detector

* resolve comments
  • Loading branch information
kennymckormick authored Mar 3, 2021
1 parent a5896a6 commit 59ad57f
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 4 deletions.
8 changes: 5 additions & 3 deletions tests/test_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .base import (check_norm_state, generate_backbone_demo_inputs,
generate_gradcam_inputs, generate_recognizer_demo_inputs,
get_audio_recognizer_cfg, get_cfg, get_localizer_cfg,
generate_detector_demo_inputs, generate_gradcam_inputs,
generate_recognizer_demo_inputs, get_audio_recognizer_cfg,
get_cfg, get_detector_cfg, get_localizer_cfg,
get_recognizer_cfg)

__all__ = [
'check_norm_state', 'generate_backbone_demo_inputs',
'generate_recognizer_demo_inputs', 'generate_gradcam_inputs', 'get_cfg',
'get_recognizer_cfg', 'get_audio_recognizer_cfg', 'get_localizer_cfg'
'get_recognizer_cfg', 'get_audio_recognizer_cfg', 'get_localizer_cfg',
'get_detector_cfg', 'generate_detector_demo_inputs'
]
51 changes: 50 additions & 1 deletion tests/test_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,50 @@ def generate_recognizer_demo_inputs(
return inputs


def generate_detector_demo_inputs(
input_shape=(1, 3, 4, 224, 224), num_classes=81, train=True,
device='cpu'):
num_samples = input_shape[0]
if not train:
assert num_samples == 1

def random_box(n):
box = torch.rand(n, 4) * 0.5
box[:, 2:] += 0.5
box[:, 0::2] *= input_shape[3]
box[:, 1::2] *= input_shape[4]
if device == 'cuda':
box = box.cuda()
return box

def random_label(n):
label = torch.randn(n, num_classes)
label = (label > 0.8).type(torch.float32)
label[:, 0] = 0
if device == 'cuda':
label = label.cuda()
return label

img = torch.FloatTensor(np.random.random(input_shape))
if device == 'cuda':
img = img.cuda()

proposals = [random_box(2) for i in range(num_samples)]
gt_bboxes = [random_box(2) for i in range(num_samples)]
gt_labels = [random_label(2) for i in range(num_samples)]
img_metas = [dict(img_shape=input_shape[-2:]) for i in range(num_samples)]

if train:
return dict(
img=img,
proposals=proposals,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
img_metas=img_metas)
else:
return dict(img=[img], proposals=[proposals], img_metas=[img_metas])


def generate_gradcam_inputs(input_shape=(1, 3, 3, 224, 224), model_type='2D'):
"""Create a superset of inputs needed to run gradcam.
Expand Down Expand Up @@ -89,7 +133,8 @@ def get_cfg(config_type, fname):
These are deep copied to allow for safe modification of parameters without
influencing other tests.
"""
config_types = ('recognition', 'recognition_audio', 'localization')
config_types = ('recognition', 'recognition_audio', 'localization',
'detection')
assert config_type in config_types

repo_dpath = osp.dirname(osp.dirname(osp.dirname(__file__)))
Expand All @@ -111,3 +156,7 @@ def get_audio_recognizer_cfg(fname):

def get_localizer_cfg(fname):
return get_cfg('localization', fname)


def get_detector_cfg(fname):
return get_cfg('detection', fname)
Empty file.
41 changes: 41 additions & 0 deletions tests/test_models/test_detectors/test_detectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
import torch

from ..base import generate_detector_demo_inputs, get_detector_cfg

try:
from mmaction.models import build_detector
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
mmdet_imported = False


@pytest.mark.skipif(not mmdet_imported, reason='requires mmdet')
def test_ava_detector():
config = get_detector_cfg('ava/slowonly_kinetics_pretrained_r50_'
'4x16x1_20e_ava_rgb.py')
detector = build_detector(config.model)

if torch.__version__ == 'parrots':
if torch.cuda.is_available():
train_demo_inputs = generate_detector_demo_inputs(
train=True, device='cuda')
test_demo_inputs = generate_detector_demo_inputs(
train=False, device='cuda')
detector = detector.cuda()

losses = detector(**train_demo_inputs)
assert isinstance(losses, dict)

# Test forward test
with torch.no_grad():
_ = detector(**test_demo_inputs, return_loss=False)
else:
train_demo_inputs = generate_detector_demo_inputs(train=True)
test_demo_inputs = generate_detector_demo_inputs(train=False)
losses = detector(**train_demo_inputs)
assert isinstance(losses, dict)

# Test forward test
with torch.no_grad():
_ = detector(**test_demo_inputs, return_loss=False)

0 comments on commit 59ad57f

Please sign in to comment.