def shuffle(arr):
    return np.random.choice(arr, size=len(arr), replace=False)  # uniform random permutation of arr
+def top_k_accuracy(pred_probs, labels):
+    # For k = number of true sites, take the k highest-probability positions
+    # and report the fraction of them that land on a true site.
+    pred_probs, labels = map(lambda x: x.view(-1), [pred_probs, labels])  # Flatten
+    k = (labels == 1.0).sum().item()  # NaN is returned below when there are no positives
+
+    _, top_k_indices = pred_probs.topk(k)
+    correct = labels[top_k_indices].eq(1.0)  # a top-k position is correct iff it is a true site
+    return correct.float().mean()
+
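As a quick sanity check of the metric, a tiny worked example with made-up tensors (the values below are illustrative, not from the dataset): two positions are labeled positive, so k = 2 and the two highest-probability positions are scored.

    probs = torch.tensor([0.9, 0.2, 0.8, 0.1])
    labels = torch.tensor([1.0, 0.0, 0.0, 1.0])
    top_k_accuracy(probs, labels)  # tensor(0.5000): top-2 picks indices 0 and 2, but only index 0 is a true site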

def train(model, h5f, train_shard_idxs, batch_size, optimizer, criterion):
    model.train()
    running_output, running_label = [], []

+    batch_idx = 0
    for i, shard_idx in enumerate(shuffle(train_shard_idxs), 1):
        X = h5f[f'X{shard_idx}'][:].transpose(0, 2, 1)  # swap the last two axes for the model input
        Y = h5f[f'Y{shard_idx}'][0, ...]

        ds = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(Y).float())
-        loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
+        loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)  # TODO: Check whether drop_last=True?

        bar = tqdm.tqdm(loader, leave=False, total=len(loader), desc=f'Shard {i}/{len(train_shard_idxs)}')
-        for idx, batch in enumerate(bar):
+        for batch in bar:
            X, Y = batch[0].cuda(), batch[1].cuda()
            optimizer.zero_grad()
            out = model(X)  # (batch_size, 5000, 3)
@@ -40,18 +49,26 @@ def train(model, h5f, train_shard_idxs, batch_size, optimizer, criterion):
            running_output.append(out.detach().cpu())
            running_label.append(Y.detach().cpu())

-            if idx % 100 == 0:
+            if batch_idx % 100 == 0:
                running_output = torch.cat(running_output, dim=0)
                running_label = torch.cat(running_label, dim=0)

+                running_pred_probs = F.softmax(running_output, dim=-1)
+                top_k_acc_1 = top_k_accuracy(running_pred_probs[:, :, 1], running_label[:, :, 1])  # acceptor channel
+                top_k_acc_2 = top_k_accuracy(running_pred_probs[:, :, 2], running_label[:, :, 2])  # donor channel
+
                loss = criterion(running_output, running_label)
-                bar.set_postfix(loss=f'{loss.item():.4f}')
+                bar.set_postfix(loss=f'{loss.item():.4f}', topk_acceptor=f'{top_k_acc_1.item():.4f}', topk_donor=f'{top_k_acc_2.item():.4f}')

                running_output, running_label = [], []

                wandb.log({
                    'train/loss': loss.item(),
+                    'train/topk_acceptor': top_k_acc_1.item(),
+                    'train/topk_donor': top_k_acc_2.item(),
                })
+
+            batch_idx += 1

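For context, a minimal sketch of how `train` and `validate` might be driven per epoch. The HDF5 filename, shard split, and epoch count below are assumptions for illustration; `model`, `optimizer`, and `criterion` are whatever the caller already constructed, and this commit does not define them.

    # Hypothetical driver -- dataset path, shard split, and epoch count are assumed.
    h5f = h5py.File('dataset_train.h5', 'r')      # assumed file exposing X{i}/Y{i} datasets
    shard_idxs = np.arange(num_shards)            # num_shards assumed known from the file
    train_idxs, val_idxs = shard_idxs[:-1], shard_idxs[-1:]

    for epoch in range(num_epochs):
        train(model, h5f, train_idxs, batch_size, optimizer, criterion)
        val_loss = validate(model, h5f, val_idxs, batch_size, criterion)
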
def validate(model, h5f, val_shard_idxs, batch_size, criterion):
@@ -74,9 +91,20 @@ def validate(model, h5f, val_shard_idxs, batch_size, criterion):
        out.append(_out)
        label.append(_label)

-    loss = criterion(torch.cat(out, dim=0), torch.cat(label, dim=0))
+    out = torch.cat(out, dim=0)
+    out_pred_probs = F.softmax(out, dim=-1)
+    label = torch.cat(label, dim=0)
+
+    loss = criterion(out, label)
+    top_k_acc_1 = top_k_accuracy(out_pred_probs[:, :, 1], label[:, :, 1])  # acceptor channel
+    top_k_acc_2 = top_k_accuracy(out_pred_probs[:, :, 2], label[:, :, 2])  # donor channel
+
+    print(f'Val loss: {loss.item():.4f}, topk_acceptor: {top_k_acc_1.item():.4f}, topk_donor: {top_k_acc_2.item():.4f}')
+
    wandb.log({
        'val/loss': loss.item(),
+        'val/topk_acceptor': top_k_acc_1.item(),
+        'val/topk_donor': top_k_acc_2.item(),
    })

    return loss.item()
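One property of this evaluation worth noting: because `top_k_accuracy` flattens its inputs with `view(-1)`, k is the number of positive labels across the entire concatenated validation set, not per 5,000-position window. A small illustration with random stand-in tensors (not real model output):

    # Illustrative only -- random tensors standing in for predictions and labels.
    probs = torch.rand(8, 5000)                     # (num_windows, seq_len) channel probabilities
    labels = (torch.rand(8, 5000) > 0.999).float()  # sparse binary site labels
    acc = top_k_accuracy(probs, labels)             # single global top-k over all 8 windows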