Skip to content

Commit

Permalink
check loss_decode
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Apr 15, 2024
1 parent f2e2893 commit 0399b71
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions train/src/init_default_cfg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import supervisely as sly
import sly_globals as g

def init_default_cfg_params(state):
Expand Down Expand Up @@ -41,31 +42,44 @@ def init_default_cfg_params(state):

def init_default_cfg_args(cfg):
params = []
loss_decode_type, loss_decode_weight = None, None
if hasattr(cfg.model, "decode_head") and cfg.model.decode_head is not None:
decode_is_list = isinstance(cfg.model.decode_head, list)
if decode_is_list and hasattr(cfg.model.decode_head[0], "loss_decode") or not decode_is_list and hasattr(cfg.model.decode_head, "loss_decode"):
loss_decode = cfg.model.decode_head[0].loss_decode if decode_is_list else cfg.model.decode_head.loss_decode
try:
loss_decode_type = loss_decode.type
loss_decode_weight = loss_decode.loss_weight
except Exception:
try:
loss_decode_type = loss_decode[0].type
loss_decode_weight = loss_decode[0].loss_weight
except Exception:
sly.logger.warn("Can't get loss_decode type and weight")
if loss_decode_type is not None and loss_decode_weight is not None:
params.extend([
{"field": "state.decodeHeadLoss", "payload": loss_decode_type},
{"field": "state.decodeHeadLossWeight", "payload": loss_decode_weight}
])
loss_decode_type, loss_decode_weight = None, None
if hasattr(cfg.model, "auxiliary_head") and cfg.model.auxiliary_head is not None:
decode_is_list = isinstance(cfg.model.auxiliary_head, list)
loss_decode = cfg.model.auxiliary_head[0].loss_decode if decode_is_list else cfg.model.auxiliary_head.loss_decode
try:
loss_decode_type = loss_decode.type
loss_decode_weight = loss_decode.loss_weight
except Exception:
try:
loss_decode_type = loss_decode[0].type
loss_decode_weight = loss_decode[0].loss_weight
except Exception:
sly.logger.warn("Can't get loss_decode type and weight")
if loss_decode_type is not None and loss_decode_weight is not None:
params.extend([
{
"field": "state.decodeHeadLoss",
"payload": cfg.model.decode_head[0].loss_decode.type if decode_is_list else cfg.model.decode_head.loss_decode.type
},
{
"field": "state.decodeHeadLossWeight",
"payload": cfg.model.decode_head[0].loss_decode.loss_weight if decode_is_list else cfg.model.decode_head.loss_decode.loss_weight
},
{"field": "state.auxiliaryHeadLoss", "payload": loss_decode_type},
{"field": "state.auxiliaryHeadLossWeight", "payload": loss_decode_weight}
])
if hasattr(cfg.model, "auxiliary_head") and cfg.model.auxiliary_head is not None:
params.extend([
{
"field": "state.auxiliaryHeadLoss",
"payload": cfg.model.auxiliary_head[0].loss_decode.type if isinstance(cfg.model.auxiliary_head,
list) else cfg.model.auxiliary_head.loss_decode.type
},
{
"field": "state.auxiliaryHeadLossWeight",
"payload": cfg.model.auxiliary_head[0].loss_decode.loss_weight if isinstance(cfg.model.auxiliary_head, list) else cfg.model.auxiliary_head.loss_decode.loss_weight
}
])

if hasattr(cfg.data, "samples_per_gpu"):
params.extend([{
"field": "state.batchSizePerGPU",
Expand Down

0 comments on commit 0399b71

Please sign in to comment.