
Commit

fix bugs in medcpt and basebert
zhaoziheng committed May 24, 2024
1 parent 9fa0f68 · commit 3d8fede
Showing 5 changed files with 8 additions and 2 deletions.
docs/.DS_Store: binary file modified (contents not shown)
evaluate/inference_engine.py (2 additions, 0 deletions)
@@ -88,6 +88,8 @@ def inference(model, text_encoder, device, testset, testloader, nib_dir):
         queries_ls = []
         for labels_ls, n1n2 in zip(split_labels, split_n1n2): # convert list of texts to list of embeds
             queries_ls.append(text_encoder(labels_ls, modality))
+
+        torch.cuda.empty_cache()
 
         # for each batch of patches, query with all labels
         for patches, y1y2_x1x2_z1z2_ls in zip(batched_patches, batched_y1y2_x1x2_z1z2): # [b, c, h, w, d]
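The added torch.cuda.empty_cache() releases the allocator's cached-but-unused blocks once the text encoder has produced the query embeddings, before the memory-heavy per-patch loop. A minimal, self-contained illustration of the effect (not repository code; the tensor shape is arbitrary):

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, 64, device='cuda')   # stand-in for text-encoder activations
    del x                                            # tensor is freed, but its block stays in the cache
    print(torch.cuda.memory_reserved())              # reserved memory still includes the cached block
    torch.cuda.empty_cache()                         # hand unused cached blocks back to the driver
    print(torch.cuda.memory_reserved())              # reserved memory drops

Within a single process the cache would be reused anyway, so the practical benefit is lower peak reserved memory and less fragmentation before the large patch-wise forward passes.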
model/base_bert.py (2 additions, 1 deletion)
@@ -1,4 +1,5 @@
 import torch.nn as nn
+import torch
 
 from transformers import BertModel, AutoTokenizer
 
@@ -16,7 +17,7 @@ def forward(self, text, modality):
             padding=True,
             return_tensors='pt',
             max_length=64,
-            )
+            ).to(device=torch.cuda.current_device())
 
         text_feature = self.model(**encoded).last_hidden_state[:, 0, :]
         modality_feature = self.modality_embed(modality)
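The substantive change is the .to(...) on the tokenizer output: the HuggingFace tokenizer always returns CPU tensors, so feeding them into a BERT model that lives on the GPU fails with a device-mismatch RuntimeError. BatchEncoding.to() moves input_ids, attention_mask, etc. onto the model's device. A short sketch of the failure mode and the fix (the checkpoint name and example labels are placeholders, not taken from this repository):

import torch
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')   # placeholder checkpoint
model = BertModel.from_pretrained('bert-base-uncased').cuda()

encoded = tokenizer(['liver', 'left kidney'], padding=True, return_tensors='pt', max_length=64)
# encoded.input_ids / encoded.attention_mask are CPU tensors at this point;
# model(**encoded) would raise a device-mismatch RuntimeError.
encoded = encoded.to(device=torch.cuda.current_device())          # the fix applied in this commit
cls_feature = model(**encoded).last_hidden_state[:, 0, :]         # [batch, hidden], CLS pooling as in the hunk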
model/maskformer.py (2 additions, 0 deletions)
@@ -323,6 +323,8 @@ def forward(self, queries, image_input):
 
         # Infer / Evaluate Forward ------------------------------------------------------------
         if isinstance(queries, List):
+            del image_input
+            torch.cuda.empty_cache()
             logits = self.infer_forward(queries, image_embedding, pos, per_pixel_embedding_ls)
 
         # Train Forward -----------------------------------------------------------------------
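The ordering here matters: torch.cuda.empty_cache() can only release blocks whose tensors have no remaining Python references, so image_input (no longer needed once the image embedding exists) is deleted first. A sketch with assumed, illustrative shapes, not repository code:

import torch

if torch.cuda.is_available():
    image_input = torch.randn(1, 1, 288, 288, 96, device='cuda')    # raw volume batch, shape assumed
    image_embedding = torch.randn(1, 768, 9, 9, 3, device='cuda')   # stand-in for the encoder output
    torch.cuda.empty_cache()   # no effect on image_input: it is still referenced
    del image_input            # drop the last reference to the raw volume
    torch.cuda.empty_cache()   # its buffer can now be released, freeing headroom for infer_forward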
model/med_cpt.py (2 additions, 1 deletion)
@@ -1,4 +1,5 @@
 import torch.nn as nn
+import torch
 
 from transformers import AutoModel, AutoTokenizer
 
@@ -16,7 +17,7 @@ def forward(self, text, modality):
             padding=True,
             return_tensors='pt',
             max_length=64,
-            )
+            ).to(device=torch.cuda.current_device())
 
         text_feature = self.model(**encoded).last_hidden_state[:, 0, :]
         modality_feature = self.modality_embed(modality)
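This mirrors the base_bert.py fix: the MedCPT tokenizer also returns CPU tensors, so the encoding is moved to the current CUDA device before the forward pass. For orientation, a minimal sketch of the encoder pattern visible in this hunk; the checkpoint name, modality-embedding size, and the way the two features are combined are assumptions for illustration, not values read from the repository:

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class MedCPTTextEncoderSketch(nn.Module):
    def __init__(self, checkpoint='ncbi/MedCPT-Query-Encoder', num_modalities=4, hidden=768):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModel.from_pretrained(checkpoint)
        self.modality_embed = nn.Embedding(num_modalities, hidden)

    def forward(self, text, modality):
        # Tokenize on CPU, then move every tensor in the BatchEncoding onto the
        # GPU the model lives on (the bug fixed by this commit).
        encoded = self.tokenizer(
            text,
            padding=True,
            return_tensors='pt',
            max_length=64,
        ).to(device=torch.cuda.current_device())
        text_feature = self.model(**encoded).last_hidden_state[:, 0, :]   # CLS token as text feature
        modality_feature = self.modality_embed(modality)
        return text_feature, modality_feature   # how the repo combines these is not shown in the hunk

Usage would assume the module itself has been moved to the same CUDA device, e.g. encoder = MedCPTTextEncoderSketch().cuda(), with modality passed as a LongTensor of indices on that device.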
