Skip to content

Commit 2172862

Browse files
committed
Log top-k accuracy.
1 parent 0d66cdc commit 2172862

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

spliceai_pytorch/train.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,28 @@
1717
def shuffle(arr):
1818
return np.random.choice(arr, size=len(arr), replace=False)
1919

20+
def top_k_accuracy(pred_probs, labels):
21+
pred_probs, labels = map(lambda x: x.view(-1), [pred_probs, labels]) # Flatten
22+
k = (labels == 1.0).sum().item()
23+
24+
top_k_values, top_k_indices = pred_probs.topk(k)
25+
correct = top_k_values.eq(labels[top_k_indices])
26+
return correct.float().mean()
27+
2028
def train(model, h5f, train_shard_idxs, batch_size, optimizer, criterion):
2129
model.train()
2230
running_output, running_label = [], []
2331

32+
batch_idx = 0
2433
for i, shard_idx in enumerate(shuffle(train_shard_idxs), 1):
2534
X = h5f[f'X{shard_idx}'][:].transpose(0, 2, 1)
2635
Y = h5f[f'Y{shard_idx}'][0, ...]
2736

2837
ds = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(Y).float())
29-
loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
38+
loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True) # TODO: Check whether drop_last=True?
3039

3140
bar = tqdm.tqdm(loader, leave=False, total=len(loader), desc=f'Shard {i}/{len(train_shard_idxs)}')
32-
for idx, batch in enumerate(bar):
41+
for batch in bar:
3342
X, Y = batch[0].cuda(), batch[1].cuda()
3443
optimizer.zero_grad()
3544
out = model(X) # (batch_size, 5000, 3)
@@ -40,18 +49,26 @@ def train(model, h5f, train_shard_idxs, batch_size, optimizer, criterion):
4049
running_output.append(out.detach().cpu())
4150
running_label.append(Y.detach().cpu())
4251

43-
if idx % 100 == 0:
52+
if batch_idx % 100 == 0:
4453
running_output = torch.cat(running_output, dim=0)
4554
running_label = torch.cat(running_label, dim=0)
4655

56+
running_pred_probs = F.softmax(running_output, dim=-1)
57+
top_k_acc_1 = top_k_accuracy(running_pred_probs[:, :, 1], running_label[:, :, 1])
58+
top_k_acc_2 = top_k_accuracy(running_pred_probs[:, :, 2], running_label[:, :, 2])
59+
4760
loss = criterion(running_output, running_label)
48-
bar.set_postfix(loss=f'{loss.item():.4f}')
61+
bar.set_postfix(loss=f'{loss.item():.4f}', topk_acceptor=f'{top_k_acc_1.item():.4f}', topk_donor=f'{top_k_acc_2.item():.4f}')
4962

5063
running_output, running_label = [], []
5164

5265
wandb.log({
5366
'train/loss': loss.item(),
67+
'train/topk_acceptor': top_k_acc_1.item(),
68+
'train/topk_donor': top_k_acc_2.item(),
5469
})
70+
71+
batch_idx += 1
5572

5673

5774
def validate(model, h5f, val_shard_idxs, batch_size, criterion):
@@ -74,9 +91,20 @@ def validate(model, h5f, val_shard_idxs, batch_size, criterion):
7491
out.append(_out)
7592
label.append(_label)
7693

77-
loss = criterion(torch.cat(out, dim=0), torch.cat(label, dim=0))
94+
out = torch.cat(out, dim=0)
95+
out_pred_probs = F.softmax(out, dim=-1)
96+
label = torch.cat(label, dim=0)
97+
98+
loss = criterion(out, label)
99+
top_k_acc_1 = top_k_accuracy(out_pred_probs[:, :, 1], label[:, :, 1])
100+
top_k_acc_2 = top_k_accuracy(out_pred_probs[:, :, 2], label[:, :, 2])
101+
102+
print(f'Val loss: {loss.item():.4f}, topk_acceptor: {top_k_acc_1.item():.4f}, topk_donor: {top_k_acc_2.item():.4f}')
103+
78104
wandb.log({
79105
'val/loss': loss.item(),
106+
'val/topk_acceptor': top_k_acc_1.item(),
107+
'val/topk_donor': top_k_acc_2.item(),
80108
})
81109

82110
return loss.item()

0 commit comments

Comments
 (0)