From abdd95496af5fccfef61b1158b97b2ad2345ecd2 Mon Sep 17 00:00:00 2001 From: zhaoziheng <565295081@qq.com> Date: Fri, 24 May 2024 14:05:56 +0800 Subject: [PATCH] fix bugs in medcpt and basebert --- evaluate/inference_engine.py | 2 ++ model/base_bert.py | 3 ++- model/maskformer.py | 2 ++ model/med_cpt.py | 3 ++- 4 files changed, 8 insertions(+), 2 deletions(-) diff --git a/evaluate/inference_engine.py b/evaluate/inference_engine.py index c7a8676..e270b15 100644 --- a/evaluate/inference_engine.py +++ b/evaluate/inference_engine.py @@ -88,6 +88,8 @@ def inference(model, text_encoder, device, testset, testloader, nib_dir): queries_ls = [] for labels_ls, n1n2 in zip(split_labels, split_n1n2): # convert list of texts to list of embeds queries_ls.append(text_encoder(labels_ls, modality)) + + torch.cuda.empty_cache() # for each batch of patches, query with all labels for patches, y1y2_x1x2_z1z2_ls in zip(batched_patches, batched_y1y2_x1x2_z1z2): # [b, c, h, w, d] diff --git a/model/base_bert.py b/model/base_bert.py index f18f926..96edabc 100644 --- a/model/base_bert.py +++ b/model/base_bert.py @@ -1,4 +1,5 @@ import torch.nn as nn +import torch from transformers import BertModel, AutoTokenizer @@ -16,7 +17,7 @@ def forward(self, text, modality): padding=True, return_tensors='pt', max_length=64, - ) + ).to(device=torch.cuda.current_device()) text_feature = self.model(**encoded).last_hidden_state[:, 0, :] modality_feature = self.modality_embed(modality) diff --git a/model/maskformer.py b/model/maskformer.py index 392514b..8e279ea 100644 --- a/model/maskformer.py +++ b/model/maskformer.py @@ -323,6 +323,8 @@ def forward(self, queries, image_input): # Infer / Evaluate Forward ------------------------------------------------------------ if isinstance(queries, List): + del image_input + torch.cuda.empty_cache() logits = self.infer_forward(queries, image_embedding, pos, per_pixel_embedding_ls) # Train Forward ----------------------------------------------------------------------- diff --git a/model/med_cpt.py b/model/med_cpt.py index 49fade5..1d3d4b7 100644 --- a/model/med_cpt.py +++ b/model/med_cpt.py @@ -1,4 +1,5 @@ import torch.nn as nn +import torch from transformers import AutoModel, AutoTokenizer @@ -16,7 +17,7 @@ def forward(self, text, modality): padding=True, return_tensors='pt', max_length=64, - ) + ).to(device=torch.cuda.current_device()) text_feature = self.model(**encoded).last_hidden_state[:, 0, :] modality_feature = self.modality_embed(modality)