From 3d8fedee01ca2ba9a4d8cd824d0b817830f9038d Mon Sep 17 00:00:00 2001
From: ziheng-zhao <565295081@qq.com>
Date: Fri, 24 May 2024 13:57:17 +0800
Subject: [PATCH] fix bugs in medcpt and basebert

---
 docs/.DS_Store               | Bin 8196 -> 8196 bytes
 evaluate/inference_engine.py |   2 ++
 model/base_bert.py           |   3 ++-
 model/maskformer.py          |   2 ++
 model/med_cpt.py             |   3 ++-
 5 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/docs/.DS_Store b/docs/.DS_Store
index 6efa9d954762e57f1b46f36026b2995cbd5d4d11..24828a13aa10449f854fe017756db8bb79047826 100644
GIT binary patch
literal 8196
[base85-encoded binary data omitted]

literal 8196
[base85-encoded binary data omitted]

diff --git a/evaluate/inference_engine.py b/evaluate/inference_engine.py
index c7a8676..e270b15 100644
--- a/evaluate/inference_engine.py
+++ b/evaluate/inference_engine.py
@@ -88,6 +88,8 @@ def inference(model, text_encoder, device, testset, testloader, nib_dir):
         queries_ls = []
         for labels_ls, n1n2 in zip(split_labels, split_n1n2):    # convert list of texts to list of embeds
             queries_ls.append(text_encoder(labels_ls, modality))
+
+        torch.cuda.empty_cache()
 
         # for each batch of patches, query with all labels
         for patches, y1y2_x1x2_z1z2_ls in zip(batched_patches, batched_y1y2_x1x2_z1z2):    # [b, c, h, w, d]
diff --git a/model/base_bert.py b/model/base_bert.py
index f18f926..96edabc 100644
--- a/model/base_bert.py
+++ b/model/base_bert.py
@@ -1,4 +1,5 @@
 import torch.nn as nn
+import torch
 
 from transformers import BertModel, AutoTokenizer
@@ -16,7 +17,7 @@ def forward(self, text, modality):
             padding=True,
             return_tensors='pt',
             max_length=64,
-            )
+            ).to(device=torch.cuda.current_device())
         text_feature = self.model(**encoded).last_hidden_state[:, 0, :]
         modality_feature = self.modality_embed(modality)
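
Note on the base_bert.py and med_cpt.py hunks: the tokenizer returns its tensors on the CPU, so feeding them straight into a CUDA-resident encoder raises a device-mismatch error; chaining .to(device=torch.cuda.current_device()) onto the returned BatchEncoding moves input_ids / attention_mask onto the same GPU as the model. A minimal standalone sketch of the pattern (the checkpoint name and example labels are illustrative assumptions, not taken from the repo):

    import torch
    from transformers import AutoModel, AutoTokenizer

    # checkpoint name is an assumption for illustration; the repo wraps its own
    # encoders in model/base_bert.py and model/med_cpt.py
    tokenizer = AutoTokenizer.from_pretrained('ncbi/MedCPT-Query-Encoder')
    model = AutoModel.from_pretrained('ncbi/MedCPT-Query-Encoder').cuda()

    text = ['liver', 'left kidney']          # illustrative label names
    encoded = tokenizer(
        text,
        truncation=True,
        padding=True,
        return_tensors='pt',
        max_length=64,
    ).to(device=torch.cuda.current_device())  # move input_ids / attention_mask to the model's GPU
    text_feature = model(**encoded).last_hidden_state[:, 0, :]   # [CLS] embedding per label
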
diff --git a/model/maskformer.py b/model/maskformer.py
index 392514b..8e279ea 100644
--- a/model/maskformer.py
+++ b/model/maskformer.py
@@ -323,6 +323,8 @@ def forward(self, queries, image_input):
         # Infer / Evaluate Forward ------------------------------------------------------------
 
         if isinstance(queries, List):
+            del image_input
+            torch.cuda.empty_cache()
             logits = self.infer_forward(queries, image_embedding, pos, per_pixel_embedding_ls)
 
         # Train Forward -----------------------------------------------------------------------
diff --git a/model/med_cpt.py b/model/med_cpt.py
index 49fade5..1d3d4b7 100644
--- a/model/med_cpt.py
+++ b/model/med_cpt.py
@@ -1,4 +1,5 @@
 import torch.nn as nn
+import torch
 
 from transformers import AutoModel, AutoTokenizer
@@ -16,7 +17,7 @@ def forward(self, text, modality):
             padding=True,
             return_tensors='pt',
             max_length=64,
-            )
+            ).to(device=torch.cuda.current_device())
         text_feature = self.model(**encoded).last_hidden_state[:, 0, :]
         modality_feature = self.modality_embed(modality)
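
Note on the inference_engine.py and maskformer.py hunks: both free GPU memory ahead of the memory-heavy per-patch inference loop. del drops the local reference to a tensor that is no longer needed (image_input, once image_embedding, pos and per_pixel_embedding_ls have been computed), and torch.cuda.empty_cache() hands cached, unreferenced blocks from PyTorch's caching allocator back to the driver, making a later out-of-memory error less likely. A minimal sketch of the idea, with a hypothetical helper name, an assumed encoder callable, and illustrative shapes:

    import torch

    def encode_then_free(image_input, encoder):
        # hypothetical helper: compute the embedding, then release the raw volume
        image_embedding = encoder(image_input)    # e.g. [b, c', h', w', d']
        del image_input                           # drop the local reference to the large input tensor
        torch.cuda.empty_cache()                  # return cached, unreferenced blocks to the CUDA driver
        return image_embedding

empty_cache() can only release memory that no live tensor still references, so the call helps most when the dropped tensor has no remaining references elsewhere (for example, when the caller has also let go of image_input).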