Skip to content

Commit

Permalink
Merge pull request #15 from mousyball/network
Browse files Browse the repository at this point in the history
[test] Test for config, network and loss builder
  • Loading branch information
JanLin0817 authored Feb 2, 2021
2 parents d9837a9 + a25be9d commit f09ebf7
Show file tree
Hide file tree
Showing 19 changed files with 181 additions and 19 deletions.
14 changes: 14 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[run]
omit =

[report]
exclude_lines =
pragma: no cover
def __repr__
if self.debug:
if settings.DEBUG
raise AssertionError
raise NotImplementedError
if 0:
if __name__ == .__main__.:
assert False
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ repos:
- id: check-merge-conflict
- id: check-json
- id: check-yaml
args:
- "--unsafe"
- id: trailing-whitespace
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.4
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion networks/classification/backbones/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def forward(self, x):
class BaseBackbone(IBackbone):
def __init__(self):
"""[NOTE] Define your network in submodule."""
super(IBackbone, self).__init__()
super(BaseBackbone, self).__init__()

def init_weights(self):
"""Initialize the weights in your network.
Expand Down
3 changes: 2 additions & 1 deletion pytorch_trainer/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .config import get_cfg_defaults
from .config import get_cfg_defaults, parse_yaml_config
from .builder import build
from .registry import Registry

__all__ = [
'Registry',
'get_cfg_defaults',
'parse_yaml_config',
'build'
]
16 changes: 5 additions & 11 deletions pytorch_trainer/utils/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,16 @@ def build(cfg, registry):
Returns:
nn.Module: A built nn module.
"""

_cfg = deepcopy(cfg)
obj_name = _cfg.get('NAME')

if isinstance(obj_name, str):
obj_cls = registry.get(obj_name)
if obj_cls is None:
raise KeyError(
f'{obj_name} is not in the {registry._name} registry')
else:
raise TypeError(
f'type must be a str, but got {type(obj_name)}')
obj_name = _cfg.pop('NAME')
assert isinstance(obj_name, str)

# [NOTE] 'KeyError' is handled in registry.
obj_cls = registry.get(obj_name)

# [Case]: LOSS
if registry._name == 'loss':
_cfg.pop('NAME')
return obj_cls(**dict(_cfg))

return obj_cls(_cfg)
4 changes: 2 additions & 2 deletions pytorch_trainer/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def get_cfg_defaults():
return _C.clone()


def parse_yaml_config(config_path):
def parse_yaml_config(config_path, allow_unsafe=False):
cfg = get_cfg_defaults()
cfg.merge_from_file(config_path)
cfg.merge_from_file(config_path, allow_unsafe)
cfg.freeze()
return cfg
2 changes: 0 additions & 2 deletions requirements.txt

This file was deleted.

6 changes: 6 additions & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
torch>=1.7.1
fvcore
autopep8
pre-commit
pytest
pytest-cov
3 changes: 1 addition & 2 deletions requirements-dev.txt → requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
torch>=1.7.1
fvcore
autopep8
pre-commit
tensorboard
2 changes: 2 additions & 0 deletions tests/components/configs/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
KEY1: "base"
KEY2: "base"
3 changes: 3 additions & 0 deletions tests/components/configs/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_BASE_: "base.yaml"
KEY2: "config"
EXPRESSION: !!python/object/apply:eval ["[x ** 2 for x in [1, 2, 3]]"]
23 changes: 23 additions & 0 deletions tests/components/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os.path as osp

from pytorch_trainer.utils.config import parse_yaml_config

ROOT_PATH = './test/components/'


class TestConfig:
def test_yaml_parser(self):
# [Case] base.yaml
PATH = './configs/base.yaml'
cfg_base = parse_yaml_config(osp.join(ROOT_PATH, PATH))

assert cfg_base.KEY1 == 'base'
assert cfg_base.KEY2 == 'base'

# [Case] config.yaml inherits from base.yaml
PATH = './configs/config.yaml'
cfg = parse_yaml_config(osp.join(ROOT_PATH, PATH), allow_unsafe=True)

assert cfg.KEY1 == 'base'
assert cfg.KEY2 == 'config'
assert cfg.EXPRESSION == [1, 4, 9]
10 changes: 10 additions & 0 deletions tests/components/test_loss_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from networks.loss import LOSSES
from networks.loss.regular import torch_loss


class TestLoss:
def test_loss_list(self):
print(LOSSES._obj_map.keys())
loss_registry = set(LOSSES._obj_map.keys())
loss_table = set(torch_loss.keys())
assert loss_table.issubset(loss_registry)
110 changes: 110 additions & 0 deletions tests/components/test_network_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import os.path as osp

import torch
import pytest
import torch.nn as nn

from pytorch_trainer.utils.config import parse_yaml_config
from networks.classification.builder import build_network

ROOT_PATH = './configs/networks/classification/'


class TestClassification:
@pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml'])
def test_network_builder_with_cfg(self, filename):
"""Network tests include backbone/loss builder
"""
FILENAME = filename
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME))
net = build_network(cfg)
assert isinstance(net, nn.Module)

def test_network_builder_with_keyerror(self):
FILENAME = 'lenet.yaml'
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME))
# Remove keyword: 'NETWORK'
cfg.pop('NETWORK')
with pytest.raises(KeyError) as excinfo:
_ = build_network(cfg)
assert "KeyError" in str(excinfo)

@pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml'])
def test_network_forward(self, filename):
FILENAME = filename
# Construct network
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME))
net = build_network(cfg)
net.eval()
# Initialize input
if cfg.get('NETWORK'):
n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS')
elif cfg.get('CUSTOM'):
n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS')
else:
assert False
x = torch.rand(4, 3, 32, 32)
# Inference
output_size = net(x).shape
assert output_size == torch.Size([4, n_class])

@pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml'])
def test_network_train_step(self, filename):
FILENAME = filename
# Construct network
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME))
net = build_network(cfg)
net.train()
# Initialize input
if cfg.get('NETWORK'):
n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS')
elif cfg.get('CUSTOM'):
n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS')
else:
assert False
x = torch.rand(4, 3, 32, 32)
y = torch.randint(low=0, high=n_class, size=(4,))
# Training Step
output_loss = net.train_step((x, y))
assert 'loss' in output_loss
assert 'multi_loss' in output_loss

@pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml'])
def test_network_val_step(self, filename):
FILENAME = filename
# Construct network
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME))
net = build_network(cfg)
net.eval()
# Initialize input
if cfg.get('NETWORK'):
n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS')
elif cfg.get('CUSTOM'):
n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS')
else:
assert False
x = torch.rand(4, 3, 32, 32)
y = torch.randint(low=0, high=n_class, size=(4,))
# Training Step
output_loss = net.val_step((x, y))
assert 'loss' in output_loss
assert 'multi_loss' in output_loss

def test_backbone_builder_with_keyerror(self):
FILENAME = 'lenet.yaml'
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME))
# Remove keyword: 'BACKBONE'
cfg.NETWORK.pop('BACKBONE')
with pytest.raises(KeyError) as excinfo:
_ = build_network(cfg)
assert "KeyError" in str(excinfo)

def test_utils_builder_with_keyerror(self):
FILENAME = 'lenet.yaml'
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME))
# Modify keyword name
cfg.defrost()
cfg.NETWORK.BACKBONE.NAME = 'WrongLeNet'
with pytest.raises(KeyError) as excinfo:
_ = build_network(cfg)
assert "KeyError" in str(excinfo)
File renamed without changes.
File renamed without changes.

0 comments on commit f09ebf7

Please sign in to comment.