diff --git a/pyha_analyzer/train.py b/pyha_analyzer/train.py index 5d4d26b..2249c96 100644 --- a/pyha_analyzer/train.py +++ b/pyha_analyzer/train.py @@ -224,7 +224,7 @@ def valid(model: Any, # softmax predictions - log_pred = F.sigmoid(torch.cat(log_pred)).to(cfg.device) + log_pred = F.sigmoid(torch.cat(log_pred)).cpu() #.to(cfg.device) dataset = data_loader.dataset # type: ignore cmap, smap = map_metric(log_pred, torch.cat(log_label), dataset.class_dist)