-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathartifact_detector_model.py
115 lines (100 loc) · 3.77 KB
/
artifact_detector_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchvision import models
from torchmetrics.functional import auroc
from sklearn.metrics import balanced_accuracy_score
import numpy as np
# Display names for the five artifact labels. Its length matches the default
# ``num_classes=5`` of the detector below; presumably entry i names output
# column i of the model / label vector — confirm against the dataset.
MARKER_NAMES = [
    "circle marker", "triangle marker", "breast implant",
    "devices", "compression",
]
class Multilabel_ArtifactDetector(pl.LightningModule):
    """ResNet-50 multilabel classifier for image artifacts.

    A torchvision ResNet-50 (default pretrained weights) has its final fully
    connected layer replaced by a ``num_classes``-way linear head. Training
    minimises ``binary_cross_entropy_with_logits`` on the raw logits, while
    ``forward`` and the accumulated predictions apply a sigmoid, giving one
    independent probability per label.

    Sigmoid predictions and targets are collected on the CPU for each
    train/val/test epoch, so that macro multilabel AUROC (all stages) and
    per-label balanced accuracy (validation only) are computed over the whole
    epoch rather than averaged per batch.

    Batches are read as mappings with an ``"image"`` tensor and a ``"label"``
    tensor, plus ``"image_id"`` in ``test_step``. The label tensor is
    presumably (batch, num_classes) with 0/1 entries — confirm against the
    data module.

    Args:
        num_classes: Number of independent binary labels to predict.
        learning_rate: Learning rate passed to Adam.
    """

    # Probability cut-off used to binarise predictions for balanced accuracy.
    _BAL_ACC_THRESHOLD = 0.5

    def __init__(self, num_classes=5, learning_rate=0.0001):
        super().__init__()
        # Store the constructor arguments in self.hparams / checkpoints.
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.lr = learning_rate
        # Default torchvision weights; the ImageNet head is swapped for one
        # logit per label.
        self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, self.num_classes)

    def forward(self, x):
        """Return per-label sigmoid probabilities for the image batch ``x``."""
        return torch.sigmoid(self.model(x))

    def configure_optimizers(self):
        """Optimise all backbone and head parameters with Adam at ``self.lr``."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def process_batch(self, batch):
        """Run the network on one batch.

        Returns:
            Tuple ``(loss, prd, lab)``: the BCE-with-logits loss, the sigmoid
            probabilities, and the label tensor exactly as read from the batch.
        """
        img, lab = batch["image"], batch["label"]
        out = self.model(img)
        prd = torch.sigmoid(out)
        # The loss takes the logits (numerically stable fused sigmoid + BCE);
        # the targets must be floating point for this loss.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            out, lab.float()
        )
        return loss, prd, lab

    def _macro_auroc(self, preds, trgts):
        """Macro-averaged multilabel AUROC of ``preds`` against ``trgts``."""
        return auroc(
            preds,
            trgts,
            average="macro",
            task="multilabel",
            num_labels=self.num_classes,
        )

    def _per_label_balanced_accuracy(self, preds, trgts):
        """Balanced accuracy of each label column after thresholding ``preds``.

        Args:
            preds: CPU tensor of sigmoid probabilities, one column per label.
            trgts: CPU tensor of targets with the same layout.

        Returns:
            List of ``self.num_classes`` balanced-accuracy values.
        """
        y_pred = (preds > self._BAL_ACC_THRESHOLD).numpy()
        y_true = trgts.numpy()
        # sklearn's signature is (y_true, y_pred) and the metric is not
        # symmetric: swapping the arguments yields macro precision instead.
        return [
            balanced_accuracy_score(y_true[:, i], y_pred[:, i])
            for i in range(self.num_classes)
        ]

    def on_train_epoch_start(self):
        # Fresh per-epoch accumulators.
        self.train_preds = []
        self.train_trgts = []

    def training_step(self, batch, batch_idx):
        loss, prd, lab = self.process_batch(batch)
        self.log("train_loss", loss, batch_size=lab.shape[0])
        # Detached CPU copies so the epoch buffers hold no graph or GPU memory.
        self.train_preds.append(prd.detach().cpu())
        self.train_trgts.append(lab.detach().cpu())
        return loss

    def on_train_epoch_end(self):
        # The concatenated tensors replace the lists and remain on self.
        self.train_preds = torch.cat(self.train_preds, dim=0)
        self.train_trgts = torch.cat(self.train_trgts, dim=0)
        auc = self._macro_auroc(self.train_preds, self.train_trgts)
        self.log("train_auc", auc)

    def on_validation_epoch_start(self):
        self.val_preds = []
        self.val_trgts = []

    def validation_step(self, batch, batch_idx):
        loss, prd, lab = self.process_batch(batch)
        self.log("val_loss", loss, batch_size=lab.shape[0])
        self.val_preds.append(prd.detach().cpu())
        self.val_trgts.append(lab.detach().cpu())

    def on_validation_epoch_end(self):
        self.val_preds = torch.cat(self.val_preds, dim=0)
        self.val_trgts = torch.cat(self.val_trgts, dim=0)
        auc = self._macro_auroc(self.val_preds, self.val_trgts)
        self.log("val_auc", auc)

        # One balanced accuracy per label, plus their unweighted mean.
        all_bal_acc = self._per_label_balanced_accuracy(
            self.val_preds, self.val_trgts
        )
        for i, bal_acc in enumerate(all_bal_acc):
            self.log(f"val_bal_acc_{i}", bal_acc)
        self.log("val_bal_acc", np.asarray(all_bal_acc).mean())

    def on_test_epoch_start(self):
        self.test_preds = []
        self.test_trgts = []
        self.test_image_ids = []

    def test_step(self, batch, batch_idx):
        loss, prd, lab = self.process_batch(batch)
        self.log("test_loss", loss, batch_size=lab.shape[0])
        self.test_preds.append(prd.detach().cpu())
        self.test_trgts.append(lab.detach().cpu())
        # One entry per batch; these are not concatenated at epoch end.
        self.test_image_ids.append(batch["image_id"])

    def on_test_epoch_end(self):
        self.test_preds = torch.cat(self.test_preds, dim=0)
        self.test_trgts = torch.cat(self.test_trgts, dim=0)
        auc = self._macro_auroc(self.test_preds, self.test_trgts)
        self.log("test_auc", auc)