-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from mousyball/network
[test] Test for config, network and loss builder
- Loading branch information
Showing
19 changed files
with
181 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[run] | ||
omit = | ||
|
||
[report] | ||
exclude_lines = | ||
pragma: no cover | ||
def __repr__ | ||
if self.debug: | ||
if settings.DEBUG | ||
raise AssertionError | ||
raise NotImplementedError | ||
if 0: | ||
if __name__ == .__main__.: | ||
assert False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
from .config import get_cfg_defaults | ||
from .config import get_cfg_defaults, parse_yaml_config | ||
from .builder import build | ||
from .registry import Registry | ||
|
||
__all__ = [ | ||
'Registry', | ||
'get_cfg_defaults', | ||
'parse_yaml_config', | ||
'build' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
torch>=1.7.1 | ||
fvcore | ||
autopep8 | ||
pre-commit | ||
pytest | ||
pytest-cov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
torch>=1.7.1 | ||
fvcore | ||
autopep8 | ||
pre-commit | ||
tensorboard |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
KEY1: "base" | ||
KEY2: "base" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
_BASE_: "base.yaml" | ||
KEY2: "config" | ||
EXPRESSION: !!python/object/apply:eval ["[x ** 2 for x in [1, 2, 3]]"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os.path as osp | ||
|
||
from pytorch_trainer.utils.config import parse_yaml_config | ||
|
||
ROOT_PATH = './test/components/' | ||
|
||
|
||
class TestConfig: | ||
def test_yaml_parser(self): | ||
# [Case] base.yaml | ||
PATH = './configs/base.yaml' | ||
cfg_base = parse_yaml_config(osp.join(ROOT_PATH, PATH)) | ||
|
||
assert cfg_base.KEY1 == 'base' | ||
assert cfg_base.KEY2 == 'base' | ||
|
||
# [Case] config.yaml inherits from base.yaml | ||
PATH = './configs/config.yaml' | ||
cfg = parse_yaml_config(osp.join(ROOT_PATH, PATH), allow_unsafe=True) | ||
|
||
assert cfg.KEY1 == 'base' | ||
assert cfg.KEY2 == 'config' | ||
assert cfg.EXPRESSION == [1, 4, 9] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from networks.loss import LOSSES | ||
from networks.loss.regular import torch_loss | ||
|
||
|
||
class TestLoss: | ||
def test_loss_list(self): | ||
print(LOSSES._obj_map.keys()) | ||
loss_registry = set(LOSSES._obj_map.keys()) | ||
loss_table = set(torch_loss.keys()) | ||
assert loss_table.issubset(loss_registry) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import os.path as osp | ||
|
||
import torch | ||
import pytest | ||
import torch.nn as nn | ||
|
||
from pytorch_trainer.utils.config import parse_yaml_config | ||
from networks.classification.builder import build_network | ||
|
||
ROOT_PATH = './configs/networks/classification/' | ||
|
||
|
||
class TestClassification: | ||
@pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml']) | ||
def test_network_builder_with_cfg(self, filename): | ||
"""Network tests include backbone/loss builder | ||
""" | ||
FILENAME = filename | ||
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) | ||
net = build_network(cfg) | ||
assert isinstance(net, nn.Module) | ||
|
||
def test_network_builder_with_keyerror(self): | ||
FILENAME = 'lenet.yaml' | ||
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) | ||
# Remove keyword: 'NETWORK' | ||
cfg.pop('NETWORK') | ||
with pytest.raises(KeyError) as excinfo: | ||
_ = build_network(cfg) | ||
assert "KeyError" in str(excinfo) | ||
|
||
@pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml']) | ||
def test_network_forward(self, filename): | ||
FILENAME = filename | ||
# Construct network | ||
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) | ||
net = build_network(cfg) | ||
net.eval() | ||
# Initialize input | ||
if cfg.get('NETWORK'): | ||
n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS') | ||
elif cfg.get('CUSTOM'): | ||
n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS') | ||
else: | ||
assert False | ||
x = torch.rand(4, 3, 32, 32) | ||
# Inference | ||
output_size = net(x).shape | ||
assert output_size == torch.Size([4, n_class]) | ||
|
||
@pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml']) | ||
def test_network_train_step(self, filename): | ||
FILENAME = filename | ||
# Construct network | ||
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) | ||
net = build_network(cfg) | ||
net.train() | ||
# Initialize input | ||
if cfg.get('NETWORK'): | ||
n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS') | ||
elif cfg.get('CUSTOM'): | ||
n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS') | ||
else: | ||
assert False | ||
x = torch.rand(4, 3, 32, 32) | ||
y = torch.randint(low=0, high=n_class, size=(4,)) | ||
# Training Step | ||
output_loss = net.train_step((x, y)) | ||
assert 'loss' in output_loss | ||
assert 'multi_loss' in output_loss | ||
|
||
@pytest.mark.parametrize("filename", ['lenet.yaml', 'mynet.yaml']) | ||
def test_network_val_step(self, filename): | ||
FILENAME = filename | ||
# Construct network | ||
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) | ||
net = build_network(cfg) | ||
net.eval() | ||
# Initialize input | ||
if cfg.get('NETWORK'): | ||
n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS') | ||
elif cfg.get('CUSTOM'): | ||
n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS') | ||
else: | ||
assert False | ||
x = torch.rand(4, 3, 32, 32) | ||
y = torch.randint(low=0, high=n_class, size=(4,)) | ||
# Training Step | ||
output_loss = net.val_step((x, y)) | ||
assert 'loss' in output_loss | ||
assert 'multi_loss' in output_loss | ||
|
||
def test_backbone_builder_with_keyerror(self): | ||
FILENAME = 'lenet.yaml' | ||
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) | ||
# Remove keyword: 'BACKBONE' | ||
cfg.NETWORK.pop('BACKBONE') | ||
with pytest.raises(KeyError) as excinfo: | ||
_ = build_network(cfg) | ||
assert "KeyError" in str(excinfo) | ||
|
||
def test_utils_builder_with_keyerror(self): | ||
FILENAME = 'lenet.yaml' | ||
cfg = parse_yaml_config(osp.join(ROOT_PATH, FILENAME)) | ||
# Modify keyword name | ||
cfg.defrost() | ||
cfg.NETWORK.BACKBONE.NAME = 'WrongLeNet' | ||
with pytest.raises(KeyError) as excinfo: | ||
_ = build_network(cfg) | ||
assert "KeyError" in str(excinfo) |
File renamed without changes.
File renamed without changes.