Skip to content

Commit

Permalink
[Fix] update build loss api (#3587)
Browse files Browse the repository at this point in the history
## Motivation

Use `MODELS.build` instead of `build_loss`

## Modification

Please briefly describe what modification is made in this PR.
  • Loading branch information
xiexinch authored Mar 8, 2024
1 parent be687fc commit 5465118
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
6 changes: 3 additions & 3 deletions mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.structures import build_pixel_sampler
from mmseg.utils import ConfigType, SampleList
from ..builder import build_loss
from ..losses import accuracy
from ..utils import resize

Expand Down Expand Up @@ -140,11 +140,11 @@ def __init__(self,
self.threshold = threshold

if isinstance(loss_decode, dict):
self.loss_decode = build_loss(loss_decode)
self.loss_decode = MODELS.build(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
self.loss_decode.append(MODELS.build(loss))
else:
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
but got {type(loss_decode)}')
Expand Down
3 changes: 1 addition & 2 deletions mmseg/models/decode_heads/enc_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, SampleList
from ..builder import build_loss
from ..utils import Encoding, resize
from .decode_head import BaseDecodeHead

Expand Down Expand Up @@ -128,7 +127,7 @@ def __init__(self,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.use_se_loss:
self.loss_se_decode = build_loss(loss_se_decode)
self.loss_se_decode = MODELS.build(loss_se_decode)
self.se_layer = nn.Linear(self.channels, self.num_classes)

def forward(self, inputs):
Expand Down
5 changes: 2 additions & 3 deletions mmseg/models/decode_heads/vpd_depth_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..builder import build_loss
from ..utils import resize
from .decode_head import BaseDecodeHead

Expand Down Expand Up @@ -184,11 +183,11 @@ def __init__(

# build loss
if isinstance(loss_decode, dict):
self.loss_decode = build_loss(loss_decode)
self.loss_decode = MODELS.build(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
self.loss_decode.append(MODELS.build(loss))
else:
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
but got {type(loss_decode)}')
Expand Down

0 comments on commit 5465118

Please sign in to comment.