Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add loss builder #11

Merged
merged 2 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions configs/networks/classification/lenet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ NETWORK:
NAME: 'LeNet'
BACKBONE:
NAME: 'LeNet'
PARAM: None
BLA:
NAME: 'NO'
NUM_CLASS: 10
LOSS:
NAME: 'CrossEntropyLoss'
weight: null
reduction: 'sum'
ignore_index: -87
9 changes: 8 additions & 1 deletion configs/networks/classification/mynet.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
CUSTOM:
NAME: 'MyLeNet'
N_CLASS: 8
MODEL:
NAME: 'LeNet'
NUM_CLASS: 8
LOSS:
NAME: 'MyLeNetLoss'
weight: null
reduction: 'mean'
ignore_index: -89
21 changes: 15 additions & 6 deletions example/build_lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,32 @@
from networks.classification.builder import build_network


def test_inference(cfg_path):
def test_train_inference(cfg_path):
# Case: General definition of network
cfg = parse_yaml_config(cfg_path)

if cfg.get('NETWORK'):
n_class = cfg.get('NETWORK').get('BACKBONE').get('NUM_CLASS')
elif cfg.get('CUSTOM'):
n_class = cfg.get('CUSTOM').get('MODEL').get('NUM_CLASS')
else:
assert False

net = build_network(cfg=cfg)
print(net)
net.eval()
net.train()
x = torch.rand(4, 3, 32, 32)
print(net(x), '\n')
y = torch.randint(low=0, high=n_class, size=(4,))
print('[LOSS][OUTPUT]', net.train_step((x, y)))
print('[LOSS][PARAMS]', net.criterion.__dict__, '\n')


if __name__ == "__main__":
# Case: General definition of network
cfg_path = "./configs/networks/classification/lenet.yaml"
print(f"[INFO] config: {cfg_path}")
test_inference(cfg_path)
test_train_inference(cfg_path)

# Case: Custom definition of network
cfg_path = "./configs/networks/classification/mynet.yaml"
print(f"[INFO] config: {cfg_path}")
test_inference(cfg_path)
test_train_inference(cfg_path)
2 changes: 1 addition & 1 deletion networks/classification/backbones/lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, cfg):
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
self.fc3 = nn.Linear(84, cfg.NUM_CLASS)

def init_weights(self):
"""Initialize the weights in your network.
Expand Down
25 changes: 1 addition & 24 deletions networks/classification/builder.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,11 @@
from pytorch_trainer.utils.builder import build
from pytorch_trainer.utils.registry import Registry

BACKBONES = Registry('backbone')
NETWORKS = Registry('network')
CUSTOMS = Registry('custom')


def build(cfg, registry):
"""Build a network.

Args:
cfg (dict): The config of network.
registry (:obj:`Registry`): A registry the module belongs to.

Returns:
nn.Module: A built nn module.
"""
obj_name = cfg.get('NAME')

if isinstance(obj_name, str):
obj_cls = registry.get(obj_name)
if obj_cls is None:
raise KeyError(
f'{obj_name} is not in the {registry.name} registry')
else:
raise TypeError(
f'type must be a str, but got {type(obj_name)}')

return obj_cls(cfg)


def build_backbone(cfg):
return build(cfg, BACKBONES)

Expand Down
9 changes: 5 additions & 4 deletions networks/classification/customs/lenet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ...loss import build_loss
from ..builder import CUSTOMS
from ..networks.base import BaseNetwork

Expand All @@ -17,8 +18,8 @@ def _construct_network(self, cfg):
* Overwrite the parent method if needed.
* Parameters checking isn't involved if customization is utilized.
"""
self.model = LeNet(cfg)
self.criterion = nn.CrossEntropyLoss()
self.model = LeNet(cfg.MODEL)
self.criterion = build_loss(cfg.LOSS)

def get_lr_params(self, group_list):
"""Get LR group for optimizer."""
Expand All @@ -40,14 +41,14 @@ def forward(self, x):
class LeNet(nn.Module):
def __init__(self, cfg):
super(LeNet, self).__init__()
n_class = cfg.get('N_CLASS')
num_class = cfg.get('NUM_CLASS')

self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, n_class)
self.fc3 = nn.Linear(84, num_class)

def init_weights(self):
"""Initialize the weights in your network.
Expand Down
13 changes: 5 additions & 8 deletions networks/classification/networks/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch.nn as nn

from ..builder import BACKBONES, build
from ...loss import build_loss
from ..builder import build_backbone


class INetwork(nn.Module):
Expand Down Expand Up @@ -34,18 +35,14 @@ def forward(self, x):

class BaseNetwork(INetwork):
def __init__(self):
"""[NOTE] Define your network in submodule."""
super(BaseNetwork, self).__init__()

def _construct_network(self, cfg):
"""Construct network from builder.
TODO:
* Loss builder?
"""
"""Construct network from builder."""
if 'BACKBONE' not in cfg:
raise KeyError("Key 'BACKBONE' is not in config.")
self.backbone = build(cfg.BACKBONE, BACKBONES)
self.criterion = None
self.backbone = build_backbone(cfg.BACKBONE)
self.criterion = build_loss(cfg.LOSS)

def freeze(self):
"""Freeze components or layers.
Expand Down
11 changes: 1 addition & 10 deletions networks/classification/networks/lenet.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
import torch.nn as nn

from .base import BaseNetwork
from ..builder import NETWORKS


@NETWORKS.register()
class LeNet(BaseNetwork):
"""
TODO:
* double instantiation: super, here
* Remove criterion after loss builder is ready.
"""

def __init__(self, cfg, **kwargs):
def __init__(self, cfg):
super(LeNet, self).__init__()
self._construct_network(cfg)
self.criterion = nn.CrossEntropyLoss()
7 changes: 7 additions & 0 deletions networks/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .custom import * # noqa: F401,F403
from .builder import LOSSES, build_loss
from .regular import * # noqa: F401,F403

__all__ = [
'LOSSES', 'build_loss'
]
8 changes: 8 additions & 0 deletions networks/loss/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from pytorch_trainer.utils.builder import build
from pytorch_trainer.utils.registry import Registry

LOSSES = Registry('loss')


def build_loss(cfg):
    """Build a loss object from the LOSS section of a config.

    Args:
        cfg (dict): loss config; must contain a ``NAME`` key naming an
            entry in the ``LOSSES`` registry. The remaining keys are
            forwarded to the loss constructor by the shared ``build()``
            helper (which special-cases the ``loss`` registry).

    Returns:
        A callable loss instance.
    """
    return build(cfg, LOSSES)
39 changes: 39 additions & 0 deletions networks/loss/custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from .builder import LOSSES


class ILoss:
    """Interface for loss objects.

    Concrete losses are constructed from config keyword arguments and
    invoked as ``loss(output, label)``.
    """

    def __init__(self, **kwargs):
        # Concrete subclasses must provide their own constructor.
        raise NotImplementedError

    def __call__(self, output, label):
        # Concrete subclasses must provide the loss computation.
        raise NotImplementedError


class BaseLoss(ILoss):
    """Template for custom losses registered in ``LOSSES``.

    Subclasses must override both ``__init__`` and ``__call__``; this base
    class only documents the expected contract.
    """

    def __init__(self, **kwargs):
        """Parse loss parameters from config.

        Args:
            **kwargs: loss parameters taken from the ``LOSS`` section of
                the config (e.g. ``weight``, ``reduction``,
                ``ignore_index``).
        """
        raise NotImplementedError()

    def __call__(self, output, label):
        """Self-defined loss calculation.

        Args:
            output (torch.Tensor): model prediction
            label (torch.Tensor): ground truth
        """
        raise NotImplementedError()


@LOSSES.register()
class MyLeNetLoss(BaseLoss):
    """CrossEntropy-based loss for MyLeNet, built from config kwargs.

    The config parameters (e.g. ``weight``, ``reduction``,
    ``ignore_index``) are forwarded verbatim to
    ``torch.nn.CrossEntropyLoss``.
    """

    def __init__(self, **kwargs):
        # No super().__init__() on purpose: BaseLoss.__init__ raises
        # NotImplementedError by design.
        from torch.nn import CrossEntropyLoss

        # Copy the params onto the instance so they stay introspectable
        # (the example script prints ``criterion.__dict__``). update()
        # fixes the original's rebinding of __dict__ to the kwargs dict
        # object itself.
        self.__dict__.update(kwargs)
        self.criterion = CrossEntropyLoss(**kwargs)

    def __call__(self, output, label):
        """Compute the cross-entropy between *output* and *label*."""
        return self.criterion(output, label)
12 changes: 12 additions & 0 deletions networks/loss/regular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch.nn as nn

from .builder import LOSSES

# [NOTE] Pytorch official API
# Map config LOSS.NAME strings to the corresponding torch.nn loss classes.
torch_loss = {
    'CrossEntropyLoss': nn.CrossEntropyLoss,
    'BCEWithLogitsLoss': nn.BCEWithLogitsLoss
}

# Register the torch losses under their class names so build_loss() can
# resolve them like any custom loss.
# NOTE(review): this reaches into the private Registry._do_register because
# register() appears to be a decorator factory — confirm the Registry API
# offers no public non-decorator registration hook.
for k, v in torch_loss.items():
    LOSSES._do_register(k, v)
6 changes: 4 additions & 2 deletions pytorch_trainer/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .config import get_cfg_defaults
from .builder import build
from .registry import Registry

__all__ = [
"Registry",
"get_cfg_defaults"
'Registry',
'get_cfg_defaults',
'build'
]
32 changes: 32 additions & 0 deletions pytorch_trainer/utils/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from copy import deepcopy


def build(cfg, registry):
    """Instantiate a registered module from a config node.

    Args:
        cfg (dict): module config; its ``NAME`` key selects the entry in
            *registry*. The config is deep-copied, so the caller's object
            is never mutated.
        registry (:obj:`Registry`): registry the module belongs to.

    Returns:
        nn.Module: the constructed module.

    Raises:
        TypeError: if ``NAME`` is missing or not a string.
        KeyError: if ``NAME`` is not registered in *registry*.
    """
    local_cfg = deepcopy(cfg)
    name = local_cfg.get('NAME')

    if not isinstance(name, str):
        raise TypeError(
            f'type must be a str, but got {type(name)}')

    cls = registry.get(name)
    if cls is None:
        raise KeyError(
            f'{name} is not in the {registry._name} registry')

    # Losses receive their remaining config entries as keyword arguments;
    # every other module type receives the whole config node.
    if registry._name != 'loss':
        return cls(local_cfg)

    local_cfg.pop('NAME')
    return cls(**dict(local_cfg))