From be8c0a70ba20b2defb8ecac64ad36d27f5171715 Mon Sep 17 00:00:00 2001 From: mousyball Date: Thu, 28 Jan 2021 17:36:24 +0800 Subject: [PATCH 1/2] [feat] Add general builder to trainer --- pytorch_trainer/utils/__init__.py | 6 ++++-- pytorch_trainer/utils/builder.py | 32 +++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 pytorch_trainer/utils/builder.py diff --git a/pytorch_trainer/utils/__init__.py b/pytorch_trainer/utils/__init__.py index 2f4a9ce..00fdd9e 100644 --- a/pytorch_trainer/utils/__init__.py +++ b/pytorch_trainer/utils/__init__.py @@ -1,7 +1,9 @@ from .config import get_cfg_defaults +from .builder import build from .registry import Registry __all__ = [ - "Registry", - "get_cfg_defaults" + 'Registry', + 'get_cfg_defaults', + 'build' ] diff --git a/pytorch_trainer/utils/builder.py b/pytorch_trainer/utils/builder.py new file mode 100644 index 0000000..1351f5e --- /dev/null +++ b/pytorch_trainer/utils/builder.py @@ -0,0 +1,32 @@ +from copy import deepcopy + + +def build(cfg, registry): + """Build a network. + + Args: + cfg (dict): The config of network. + registry (:obj:`Registry`): A registry the module belongs to. + + Returns: + nn.Module: A built nn module. + """ + + _cfg = deepcopy(cfg) + obj_name = _cfg.get('NAME') + + if isinstance(obj_name, str): + obj_cls = registry.get(obj_name) + if obj_cls is None: + raise KeyError( + f'{obj_name} is not in the {registry._name} registry') + else: + raise TypeError( + f'type must be a str, but got {type(obj_name)}') + + # [Case]: LOSS + if registry._name == 'loss': + _cfg.pop('NAME') + return obj_cls(**dict(_cfg)) + + return obj_cls(_cfg) From 2c931508e6e8950143e8ca6780a3c32f2b8c0239 Mon Sep 17 00:00:00 2001 From: mousyball Date: Thu, 28 Jan 2021 17:42:45 +0800 Subject: [PATCH 2/2] [feat] Loss builder Use loss builder for regular loss from torch and custom loss. --- configs/networks/classification/lenet.yaml | 9 +++-- configs/networks/classification/mynet.yaml | 9 ++++- example/build_lenet.py | 21 ++++++++---- networks/classification/backbones/lenet.py | 2 +- networks/classification/builder.py | 25 +------------- networks/classification/customs/lenet.py | 9 ++--- networks/classification/networks/base.py | 13 +++----- networks/classification/networks/lenet.py | 11 +----- networks/loss/__init__.py | 7 ++++ networks/loss/builder.py | 8 +++++ networks/loss/custom.py | 39 ++++++++++++++++++++++ networks/loss/regular.py | 12 +++++++ 12 files changed, 108 insertions(+), 57 deletions(-) create mode 100644 networks/loss/__init__.py create mode 100644 networks/loss/builder.py create mode 100644 networks/loss/custom.py create mode 100644 networks/loss/regular.py diff --git a/configs/networks/classification/lenet.yaml b/configs/networks/classification/lenet.yaml index 9fb27c0..bd7d09a 100644 --- a/configs/networks/classification/lenet.yaml +++ b/configs/networks/classification/lenet.yaml @@ -2,6 +2,9 @@ NETWORK: NAME: 'LeNet' BACKBONE: NAME: 'LeNet' - PARAM: None - BLA: - NAME: 'NO' + NUM_CLASS: 10 + LOSS: + NAME: 'CrossEntropyLoss' + weight: null + reduction: 'sum' + ignore_index: -87 diff --git a/configs/networks/classification/mynet.yaml b/configs/networks/classification/mynet.yaml index a85ba9f..777bf30 100644 --- a/configs/networks/classification/mynet.yaml +++ b/configs/networks/classification/mynet.yaml @@ -1,3 +1,10 @@ CUSTOM: NAME: 'MyLeNet' - N_CLASS: 8 + MODEL: + NAME: 'LeNet' + NUM_CLASS: 8 + LOSS: + NAME: 'MyLeNetLoss' + weight: null + reduction: 'mean' + ignore_index: -89 diff --git a/example/build_lenet.py b/example/build_lenet.py index eb39210..3036e9c 100644 --- a/example/build_lenet.py +++ b/example/build_lenet.py @@ -4,23 +4,32 @@ from networks.classification.builder import build_network -def test_inference(cfg_path): +def test_train_inference(cfg_path): # Case: Gerneral definition of network cfg = parse_yaml_config(cfg_path) + + if cfg.get('NETWORK'): + n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS') + elif cfg.get('CUSTOM'): + n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS') + else: + assert False + net = build_network(cfg=cfg) - print(net) - net.eval() + net.train() x = torch.rand(4, 3, 32, 32) - print(net(x), '\n') + y = torch.randint(low=0, high=n_class, size=(4,)) + print('[LOSS][OUTPUT]', net.train_step((x, y))) + print('[LOSS][PARAMS]', net.criterion.__dict__, '\n') if __name__ == "__main__": # Case: Gerneral definition of network cfg_path = "./configs/networks/classification/lenet.yaml" print(f"[INFO] config: {cfg_path}") - test_inference(cfg_path) + test_train_inference(cfg_path) # Case: Custom definition of network cfg_path = "./configs/networks/classification/mynet.yaml" print(f"[INFO] config: {cfg_path}") - test_inference(cfg_path) + test_train_inference(cfg_path) diff --git a/networks/classification/backbones/lenet.py b/networks/classification/backbones/lenet.py index 0192de1..9325046 100644 --- a/networks/classification/backbones/lenet.py +++ b/networks/classification/backbones/lenet.py @@ -14,7 +14,7 @@ def __init__(self, cfg): self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) + self.fc3 = nn.Linear(84, cfg.NUM_CLASS) def init_weights(self): """Initialize the weights in your network. diff --git a/networks/classification/builder.py b/networks/classification/builder.py index 3bc0e4f..c1206a3 100644 --- a/networks/classification/builder.py +++ b/networks/classification/builder.py @@ -1,3 +1,4 @@ +from pytorch_trainer.utils.builder import build from pytorch_trainer.utils.registry import Registry BACKBONES = Registry('backbone') @@ -5,30 +6,6 @@ CUSTOMS = Registry('custom') -def build(cfg, registry): - """Build a network. - - Args: - cfg (dict): The config of network. - registry (:obj:`Registry`): A registry the module belongs to. - - Returns: - nn.Module: A built nn module. - """ - obj_name = cfg.get('NAME') - - if isinstance(obj_name, str): - obj_cls = registry.get(obj_name) - if obj_cls is None: - raise KeyError( - f'{obj_name} is not in the {registry.name} registry') - else: - raise TypeError( - f'type must be a str, but got {type(obj_name)}') - - return obj_cls(cfg) - - def build_backbone(cfg): return build(cfg, BACKBONES) diff --git a/networks/classification/customs/lenet.py b/networks/classification/customs/lenet.py index 10f9d1e..301af29 100644 --- a/networks/classification/customs/lenet.py +++ b/networks/classification/customs/lenet.py @@ -1,6 +1,7 @@ import torch.nn as nn import torch.nn.functional as F +from ...loss import build_loss from ..builder import CUSTOMS from ..networks.base import BaseNetwork @@ -17,8 +18,8 @@ def _construct_network(self, cfg): * Overwrite the parent method if needed. * Parameters checking isn't involved if customization is utilized. """ - self.model = LeNet(cfg) - self.criterion = nn.CrossEntropyLoss() + self.model = LeNet(cfg.MODEL) + self.criterion = build_loss(cfg.LOSS) def get_lr_params(self, group_list): """Get LR group for optimizer.""" @@ -40,14 +41,14 @@ def forward(self, x): class LeNet(nn.Module): def __init__(self, cfg): super(LeNet, self).__init__() - n_class = cfg.get('N_CLASS') + num_class = cfg.get('NUM_CLASS') self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, n_class) + self.fc3 = nn.Linear(84, num_class) def init_weights(self): """Initialize the weights in your network. diff --git a/networks/classification/networks/base.py b/networks/classification/networks/base.py index 526d56d..b7ab37e 100644 --- a/networks/classification/networks/base.py +++ b/networks/classification/networks/base.py @@ -1,6 +1,7 @@ import torch.nn as nn -from ..builder import BACKBONES, build +from ...loss import build_loss +from ..builder import build_backbone class INetwork(nn.Module): @@ -34,18 +35,14 @@ def forward(self, x): class BaseNetwork(INetwork): def __init__(self): - """[NOTE] Define your network in submodule.""" super(BaseNetwork, self).__init__() def _construct_network(self, cfg): - """Construct network from builder. - TODO: - * Loss builder? - """ + """Construct network from builder.""" if 'BACKBONE' not in cfg: raise KeyError("Key 'BACKBONE' is not in config.") - self.backbone = build(cfg.BACKBONE, BACKBONES) - self.criterion = None + self.backbone = build_backbone(cfg.BACKBONE) + self.criterion = build_loss(cfg.LOSS) def freeze(self): """Freeze components or layers. diff --git a/networks/classification/networks/lenet.py b/networks/classification/networks/lenet.py index c5c2d4d..aaf74a5 100644 --- a/networks/classification/networks/lenet.py +++ b/networks/classification/networks/lenet.py @@ -1,18 +1,9 @@ -import torch.nn as nn - from .base import BaseNetwork from ..builder import NETWORKS @NETWORKS.register() class LeNet(BaseNetwork): - """ - TODO: - * double instantiation: super, here - * Remove criterion after loss builder is ready. - """ - - def __init__(self, cfg, **kwargs): + def __init__(self, cfg): super(LeNet, self).__init__() self._construct_network(cfg) - self.criterion = nn.CrossEntropyLoss() diff --git a/networks/loss/__init__.py b/networks/loss/__init__.py new file mode 100644 index 0000000..a3bc34a --- /dev/null +++ b/networks/loss/__init__.py @@ -0,0 +1,7 @@ +from .custom import * # noqa: F401,F403 +from .builder import LOSSES, build_loss +from .regular import * # noqa: F401,F403 + +__all__ = [ + 'LOSSES', 'build_loss' +] diff --git a/networks/loss/builder.py b/networks/loss/builder.py new file mode 100644 index 0000000..f4901df --- /dev/null +++ b/networks/loss/builder.py @@ -0,0 +1,8 @@ +from pytorch_trainer.utils.builder import build +from pytorch_trainer.utils.registry import Registry + +LOSSES = Registry('loss') + + +def build_loss(cfg): + return build(cfg, LOSSES) diff --git a/networks/loss/custom.py b/networks/loss/custom.py new file mode 100644 index 0000000..97a90a7 --- /dev/null +++ b/networks/loss/custom.py @@ -0,0 +1,39 @@ +from .builder import LOSSES + + +class ILoss: + def __init__(self, **kwargs): + raise NotImplementedError() + + def __call__(self, output, label): + raise NotImplementedError() + + +class BaseLoss(ILoss): + def __init__(self, **kwargs): + """Parse parameters from config. + + Args: + cfg (:obj:`CfgNode`): config dictionary + """ + raise NotImplementedError() + + def __call__(self, output, label): + """Self-defined loss calculation. + + Args: + output (torch.Tensor): model prediction + label (torch.Tensor): ground truth + """ + raise NotImplementedError() + + +@LOSSES.register() +class MyLeNetLoss(BaseLoss): + def __init__(self, **kwargs): + from torch.nn import CrossEntropyLoss + self.__dict__ = kwargs + self.criterion = CrossEntropyLoss(**kwargs) + + def __call__(self, output, label): + return self.criterion(output, label) diff --git a/networks/loss/regular.py b/networks/loss/regular.py new file mode 100644 index 0000000..8f04172 --- /dev/null +++ b/networks/loss/regular.py @@ -0,0 +1,12 @@ +import torch.nn as nn + +from .builder import LOSSES + +# [NOTE] Pytorch official API +torch_loss = { + 'CrossEntropyLoss': nn.CrossEntropyLoss, + 'BCEWithLogitsLoss': nn.BCEWithLogitsLoss +} + +for k, v in torch_loss.items(): + LOSSES._do_register(k, v)