diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..e9d660b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,14 @@ +[run] +omit = + +[report] +exclude_lines = + pragma: no cover + def __repr__ + if self.debug: + if settings.DEBUG + raise AssertionError + raise NotImplementedError + if 0: + if __name__ == .__main__.: + assert False \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71f8d12..4fc7f46 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,6 +6,8 @@ repos: - id: check-merge-conflict - id: check-json - id: check-yaml + args: + - "--unsafe" - id: trailing-whitespace - repo: https://gitlab.com/pycqa/flake8 rev: 3.8.4 diff --git a/example/build_lenet.py b/examples/build_lenet.py similarity index 100% rename from example/build_lenet.py rename to examples/build_lenet.py diff --git a/example/train_cifar10.py b/examples/train_cifar10.py similarity index 100% rename from example/train_cifar10.py rename to examples/train_cifar10.py diff --git a/example/train_cifar10_by_iter.py b/examples/train_cifar10_by_iter.py similarity index 100% rename from example/train_cifar10_by_iter.py rename to examples/train_cifar10_by_iter.py diff --git a/networks/classification/backbones/base.py b/networks/classification/backbones/base.py index 95c9376..3cf4fe3 100644 --- a/networks/classification/backbones/base.py +++ b/networks/classification/backbones/base.py @@ -17,7 +17,7 @@ def forward(self, x): class BaseBackbone(IBackbone): def __init__(self): """[NOTE] Define your network in submodule.""" - super(IBackbone, self).__init__() + super(BaseBackbone, self).__init__() def init_weights(self): """Initialize the weights in your network. diff --git a/pytorch_trainer/utils/__init__.py b/pytorch_trainer/utils/__init__.py index 00fdd9e..d6cde43 100644 --- a/pytorch_trainer/utils/__init__.py +++ b/pytorch_trainer/utils/__init__.py @@ -1,9 +1,10 @@ -from .config import get_cfg_defaults +from .config import get_cfg_defaults, parse_yaml_config from .builder import build from .registry import Registry __all__ = [ 'Registry', 'get_cfg_defaults', + 'parse_yaml_config', 'build' ] diff --git a/pytorch_trainer/utils/builder.py b/pytorch_trainer/utils/builder.py index 1351f5e..841a647 100644 --- a/pytorch_trainer/utils/builder.py +++ b/pytorch_trainer/utils/builder.py @@ -11,22 +11,16 @@ def build(cfg, registry): Returns: nn.Module: A built nn module. """ - _cfg = deepcopy(cfg) - obj_name = _cfg.get('NAME') - if isinstance(obj_name, str): - obj_cls = registry.get(obj_name) - if obj_cls is None: - raise KeyError( - f'{obj_name} is not in the {registry._name} registry') - else: - raise TypeError( - f'type must be a str, but got {type(obj_name)}') + obj_name = _cfg.pop('NAME') + assert isinstance(obj_name, str) + + # [NOTE] 'KeyError' is handled in registry. + obj_cls = registry.get(obj_name) # [Case]: LOSS if registry._name == 'loss': - _cfg.pop('NAME') return obj_cls(**dict(_cfg)) return obj_cls(_cfg) diff --git a/pytorch_trainer/utils/config.py b/pytorch_trainer/utils/config.py index 5ed721b..80cc213 100644 --- a/pytorch_trainer/utils/config.py +++ b/pytorch_trainer/utils/config.py @@ -12,8 +12,8 @@ def get_cfg_defaults(): return _C.clone() -def parse_yaml_config(config_path): +def parse_yaml_config(config_path, allow_unsafe=False): cfg = get_cfg_defaults() - cfg.merge_from_file(config_path) + cfg.merge_from_file(config_path, allow_unsafe) cfg.freeze() return cfg diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index a1d0abd..0000000 --- a/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -torch>=1.7.1 -fvcore \ No newline at end of file diff --git a/requirements/dev.txt b/requirements/dev.txt new file mode 100644 index 0000000..512c2ca --- /dev/null +++ b/requirements/dev.txt @@ -0,0 +1,6 @@ +torch>=1.7.1 +fvcore +autopep8 +pre-commit +pytest +pytest-cov \ No newline at end of file diff --git a/requirements-dev.txt b/requirements/runtime.txt similarity index 51% rename from requirements-dev.txt rename to requirements/runtime.txt index 409732d..756f209 100644 --- a/requirements-dev.txt +++ b/requirements/runtime.txt @@ -1,4 +1,3 @@ torch>=1.7.1 fvcore -autopep8 -pre-commit \ No newline at end of file +tensorboard \ No newline at end of file diff --git a/tests/components/configs/base.yaml b/tests/components/configs/base.yaml new file mode 100644 index 0000000..793532d --- /dev/null +++ b/tests/components/configs/base.yaml @@ -0,0 +1,2 @@ +KEY1: "base" +KEY2: "base" diff --git a/tests/components/configs/config.yaml b/tests/components/configs/config.yaml new file mode 100644 index 0000000..fa06651 --- /dev/null +++ b/tests/components/configs/config.yaml @@ -0,0 +1,3 @@ +_BASE_: "base.yaml" +KEY2: "config" +EXPRESSION: !!python/object/apply:eval ["[x ** 2 for x in [1, 2, 3]]"] diff --git a/tests/components/test_config.py b/tests/components/test_config.py new file mode 100644 index 0000000..b9937ea --- /dev/null +++ b/tests/components/test_config.py @@ -0,0 +1,23 @@ +import os.path as osp + +from pytorch_trainer.utils.config import parse_yaml_config + +ROOT_PATH = './test/components/' + + +class TestConfig: + def test_yaml_parser(self): + # [Case] base.yaml + PATH = './configs/base.yaml' + cfg_base = parse_yaml_config(osp.join(ROOT_PATH, PATH)) + + assert cfg_base.KEY1 == 'base' + assert cfg_base.KEY2 == 'base' + + # [Case] config.yaml inherits from base.yaml + PATH = './configs/config.yaml' + cfg = parse_yaml_config(osp.join(ROOT_PATH, PATH), allow_unsafe=True) + + assert cfg.KEY1 == 'base' + assert cfg.KEY2 == 'config' + assert cfg.EXPRESSION == [1, 4, 9] diff --git a/tests/components/test_loss_builder.py b/tests/components/test_loss_builder.py new file mode 100644 index 0000000..06ee792 --- /dev/null +++ b/tests/components/test_loss_builder.py @@ -0,0 +1,10 @@ +from networks.loss import LOSSES +from networks.loss.regular import torch_loss + + +class TestLoss: + def test_loss_list(self): + print(LOSSES._obj_map.keys()) + loss_registry = set(LOSSES._obj_map.keys()) + loss_table = set(torch_loss.keys()) + assert loss_table.issubset(loss_registry) diff --git a/tests/components/test_network_builder.py b/tests/components/test_network_builder.py new file mode 100644 index 0000000..1ee354c --- /dev/null +++ b/tests/components/test_network_builder.py @@ -0,0 +1,110 @@ +import os.path as osp + +import torch +import pytest +import torch.nn as nn + +from pytorch_trainer.utils.config import parse_yaml_config +from networks.classification.builder import build_network + +ROOT_PATH = './configs/networks/classification/' + + +class TestClassification: + @pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml']) + def test_network_builder_with_cfg(self, filename): + """Network tests include backbone/loss builder + """ + FILENAME = filename + cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) + net = build_network(cfg) + assert isinstance(net, nn.Module) + + def test_network_builder_with_keyerror(self): + FILENAME = 'lenet.yaml' + cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) + # Remove keyword: 'NETWORK' + cfg.pop('NETWORK') + with pytest.raises(KeyError) as excinfo: + _ = build_network(cfg) + assert "KeyError" in str(excinfo) + + @pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml']) + def test_network_forward(self, filename): + FILENAME = filename + # Construct network + cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) + net = build_network(cfg) + net.eval() + # Initialize input + if cfg.get('NETWORK'): + n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS') + elif cfg.get('CUSTOM'): + n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS') + else: + assert False + x = torch.rand(4, 3, 32, 32) + # Inference + output_size = net(x).shape + assert output_size == torch.Size([4, n_class]) + + @pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml']) + def test_network_train_step(self, filename): + FILENAME = filename + # Construct network + cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) + net = build_network(cfg) + net.train() + # Initialize input + if cfg.get('NETWORK'): + n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS') + elif cfg.get('CUSTOM'): + n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS') + else: + assert False + x = torch.rand(4, 3, 32, 32) + y = torch.randint(low=0, high=n_class, size=(4,)) + # Training Step + output_loss = net.train_step((x, y)) + assert 'loss' in output_loss + assert 'multi_loss' in output_loss + + @pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml']) + def test_network_val_step(self, filename): + FILENAME = filename + # Construct network + cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) + net = build_network(cfg) + net.eval() + # Initialize input + if cfg.get('NETWORK'): + n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS') + elif cfg.get('CUSTOM'): + n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS') + else: + assert False + x = torch.rand(4, 3, 32, 32) + y = torch.randint(low=0, high=n_class, size=(4,)) + # Training Step + output_loss = net.val_step((x, y)) + assert 'loss' in output_loss + assert 'multi_loss' in output_loss + + def test_backbone_builder_with_keyerror(self): + FILENAME = 'lenet.yaml' + cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) + # Remove keyword: 'BACKBONE' + cfg.NETWORK.pop('BACKBONE') + with pytest.raises(KeyError) as excinfo: + _ = build_network(cfg) + assert "KeyError" in str(excinfo) + + def test_utils_builder_with_keyerror(self): + FILENAME = 'lenet.yaml' + cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) + # Modify keyword name + cfg.defrost() + cfg.NETWORK.BACKBONE.NAME = 'WrongLeNet' + with pytest.raises(KeyError) as excinfo: + _ = build_network(cfg) + assert "KeyError" in str(excinfo) diff --git a/test/trainer/hook_test.py b/tests/trainer/hook_test.py similarity index 100% rename from test/trainer/hook_test.py rename to tests/trainer/hook_test.py diff --git a/test/trainer/trainer_test.py b/tests/trainer/trainer_test.py similarity index 100% rename from test/trainer/trainer_test.py rename to tests/trainer/trainer_test.py