From ab2792b10efdd4e488129303bcb67812802ce050 Mon Sep 17 00:00:00 2001
From: ziheng-zhao <565295081@qq.com>
Date: Thu, 5 Sep 2024 10:55:44 +0800
Subject: [PATCH 1/3] fix bugs

---
 inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
 mode change 100644 => 100755 inference.py

diff --git a/inference.py b/inference.py
old mode 100644
new mode 100755
index e4e34c3..78c9a0f
--- a/inference.py
+++ b/inference.py
@@ -35,7 +35,7 @@ def set_seed(config):
 def main(args):
     # set gpu
     if args.gpu:
-        os.environ['CUDA_VISIBLE_DEVICES'] = args.pin_memory
+        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     device=torch.device("cuda", int(os.environ["LOCAL_RANK"]))
 

From 57552d6e1ff716349c263cb9c44d206b262ba016 Mon Sep 17 00:00:00 2001
From: ziheng-zhao <565295081@qq.com>
Date: Sun, 13 Oct 2024 12:02:33 +0800
Subject: [PATCH 2/3] use online crop in evaluation

---
 data/evaluate_dataset.py | 177 ++++++++++++++++++++++++++++++++++++++-
 evaluate.py              |   4 +-
 2 files changed, 178 insertions(+), 3 deletions(-)

diff --git a/data/evaluate_dataset.py b/data/evaluate_dataset.py
index 3ea49ca..fffa29e 100644
--- a/data/evaluate_dataset.py
+++ b/data/evaluate_dataset.py
@@ -2,6 +2,7 @@ import random
 import json
 import traceback
+import math
 
 from einops import rearrange, repeat, reduce
 import numpy as np
@@ -22,7 +23,181 @@ def contains(text, key):
     for k in key:
         if k in text:
            return True
-    return False
+    return False
+
+def split_3d(image_tensor, crop_size=[288, 288, 96]):
+    # image_tensor: C H W D; slide a crop_size window over the volume, advancing half the crop size per axis (50% overlap)
+    interval_h, interval_w, interval_d = crop_size[0] // 2, crop_size[1] // 2, crop_size[2] // 2
+    split_idx = []
+    split_patch = []
+
+    c, h, w, d = image_tensor.shape
+    h_crop = max(math.ceil(h / interval_h) - 1, 1)
+    w_crop = max(math.ceil(w / interval_w) - 1, 1)
+    d_crop = max(math.ceil(d / interval_d) - 1, 1)
+
+    for i in range(h_crop):
+        h_s = i * interval_h
+        h_e = h_s + crop_size[0]
+        if h_e > h:
+            h_s = h - crop_size[0]
+            h_e = h
+            if h_s < 0:
+                h_s = 0
+        for j in range(w_crop):
+            w_s = j * interval_w
+            w_e = w_s + crop_size[1]
+            if w_e > w:
+                w_s = w - crop_size[1]
+                w_e = w
+                if w_s < 0:
+                    w_s = 0
+            for k in range(d_crop):
+                d_s = k * interval_d
+                d_e = d_s + crop_size[2]
+                if d_e > d:
+                    d_s = d - crop_size[2]
+                    d_e = d
+                    if d_s < 0:
+                        d_s = 0
+                split_idx.append([h_s, h_e, w_s, w_e, d_s, d_e])
+                split_patch.append(image_tensor[:, h_s:h_e, w_s:w_e, d_s:d_e])
+
+    return split_patch, split_idx
+
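A note on `split_3d`: since neighboring windows overlap by 50% along each axis, a voxel can fall into as many as eight patches, so whatever consumes `split_patch`/`split_idx` has to fuse the overlapping predictions when stitching the volume back together. A minimal round-trip sketch; the `stitch` helper here is hypothetical (not part of this patch) and averages the overlaps, which is one common choice:

```python
import torch

from data.evaluate_dataset import split_3d  # as added by this patch

def stitch(pred_patches, split_idx, out_shape):
    # Hypothetical fusion helper: average overlapping per-patch predictions.
    accum = torch.zeros(out_shape)        # C H W D
    count = torch.zeros(out_shape[1:])    # H W D
    for pred, (h_s, h_e, w_s, w_e, d_s, d_e) in zip(pred_patches, split_idx):
        accum[:, h_s:h_e, w_s:w_e, d_s:d_e] += pred
        count[h_s:h_e, w_s:w_e, d_s:d_e] += 1
    return accum / count                  # split_3d covers every voxel at least once

img = torch.randn(1, 300, 300, 100)       # C H W D
patches, split_idx = split_3d(img, crop_size=[288, 288, 96])
restored = stitch(patches, split_idx, img.shape)  # the patches stand in for "predictions"
assert torch.allclose(restored, img)      # averaging the overlaps reconstructs the input
```

With real model outputs, any padding added downstream (see `_pad_if_necessary` below) would have to be cropped off first, e.g. `pred[:, :h_e-h_s, :w_e-w_s, :d_e-d_s]`, before accumulating.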
+class Evaluate_Dataset_OnlineCrop(Dataset):
+    def __init__(self, jsonl_file, max_queries=256, batch_size=2, patch_size=[288, 288, 96], evaluated_samples=set()):
+        """
+        max_queries: number of queries in a batch; can be very large.
+        batch_size: number of image patches in a batch; be careful with this if you have limited GPU memory.
+        evaluated_samples: to resume from an interrupted evaluation
+        """
+        # load data info
+        self.jsonl_file = jsonl_file
+        with open(self.jsonl_file, 'r') as f:
+            lines = f.readlines()
+            lines = [json.loads(line) for line in lines]
+
+        self.lines = []
+
+        for sample in lines:
+            # if resuming, inherit the intermediate results of another evaluation and skip the samples it already finished
+            sample_id = sample['renorm_image'].split('/')[-1][:-4]  # abcd/x.npy --> x
+            dataset_name = sample['dataset']
+            if f'{dataset_name}_{sample_id}' not in evaluated_samples:
+                self.lines.append(sample)
+
+        self.max_queries = max_queries
+        self.batch_size = batch_size
+        self.patch_size = patch_size
+
+        if is_master():
+            print(f'** Online Crop DATASET ** : Skip {len(lines)-len(self.lines)} samples, {len(self.lines)} to be evaluated')
+            print(f'** Online Crop DATASET ** : Maximum {self.max_queries} queries, patch size {self.patch_size}')
+
+    def __len__(self):
+        return len(self.lines)
+
+    def _split_labels(self, label_list):
+        # split the labels into sub-lists of at most max_queries each
+        if len(label_list) < self.max_queries:
+            return [label_list], [[0, len(label_list)]]
+        else:
+            split_idx = []
+            split_label = []
+            query_num = len(label_list)
+            n_crop = (query_num // self.max_queries + 1) if (query_num % self.max_queries != 0) else (query_num // self.max_queries)
+            for n in range(n_crop):
+                n_s = n*self.max_queries
+                n_f = min((n+1)*self.max_queries, query_num)
+                split_label.append(label_list[n_s:n_f])
+                split_idx.append([n_s, n_f])
+            return split_label, split_idx
+
+    def _merge_modality(self, mod):
+        if contains(mod, ['t1', 't2', 'mri', 'mr', 'flair', 'dwi']):
+            return 'mri'
+        if contains(mod, 'ct'):
+            return 'ct'
+        if contains(mod, 'pet'):
+            return 'pet'
+        else:
+            return mod
+
+    def _pad_if_necessary(self, patch):
+        # NOTE: depth must be padded to 96
+        b, c, h, w, d = patch.shape
+        t_h, t_w, t_d = self.patch_size
+        pad_in_h = 0 if h >= t_h else t_h - h
+        pad_in_w = 0 if w >= t_w else t_w - w
+        pad_in_d = 0 if d >= t_d else t_d - d
+        if pad_in_h + pad_in_w + pad_in_d > 0:
+            pad = (0, pad_in_d, 0, pad_in_w, 0, pad_in_h)
+            patch = F.pad(patch, pad, 'constant', 0)  # b c h w d
+        return patch
+
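One detail worth calling out in `_pad_if_necessary`: `torch.nn.functional.pad` consumes the padding pairs starting from the *last* dimension, so the tuple `(0, pad_in_d, 0, pad_in_w, 0, pad_in_h)` pads depth first, then width, then height. A quick self-contained check:

```python
import torch
import torch.nn.functional as F

x = torch.zeros(2, 1, 10, 20, 30)   # b c h w d
pad = (0, 66, 0, 268, 0, 278)       # pairs for d, then w, then h
print(F.pad(x, pad, 'constant', 0).shape)  # torch.Size([2, 1, 288, 288, 96])
```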
+    def __getitem__(self, idx):
+        datum = self.lines[idx]
+        sample_id = datum['renorm_image'].split('/')[-1][:-4]  # abcd/x.npy --> x
+
+        # image to patches
+        img = torch.tensor(np.load(datum['renorm_image']))
+
+        patches, y1y2_x1x2_z1z2_ls = split_3d(img, crop_size=self.patch_size)
+
+        # divide patches into batches
+        batch_num = len(patches) // self.batch_size if len(patches) % self.batch_size == 0 else len(patches) // self.batch_size + 1
+        batched_patches = []
+        batched_y1y2_x1x2_z1z2 = []
+        for i in range(batch_num):
+            srt = i*self.batch_size
+            end = min(i*self.batch_size+self.batch_size, len(patches))
+            patch = torch.stack([patches[j] for j in range(srt, end)], dim=0)
+            # NOTE: depth must be padded to 96
+            patch = self._pad_if_necessary(patch)
+            # single-channel images (e.g. MRI and CT) are repeated to 3 channels
+            if patch.shape[1] == 1:
+                patch = repeat(patch, 'b c h w d -> b (c r) h w d', r=3)
+            batched_patches.append(patch)  # b, *patch_size
+            batched_y1y2_x1x2_z1z2.append([y1y2_x1x2_z1z2_ls[j] for j in range(srt, end)])
+
+        # split labels into batches
+        labels = datum['label']
+        split_labels, split_n1n2 = self._split_labels(labels)  # [xxx, ...] [[n1, n2], ...]
+        modality = datum['modality']
+        modality = self._merge_modality(modality.lower())
+        for i in range(len(split_labels)):
+            split_labels[i] = [label.lower() for label in split_labels[i]]
+
+        # load gt segmentations
+        c,h,w,d = datum['chwd']
+        # labels = [datum['label'][i] for i in datum['visible_label_idx']] # laryngeal cancer or hypopharyngeal cancer
+        mask_paths = [f"{datum['renorm_segmentation_dir']}/{label}.npy" for label in labels]
+        y1x1z1_y2x2z2_ls = datum['renorm_y1x1z1_y2x2z2']  # [datum['renorm_y1x1z1_y2x2z2'][i] for i in datum['visible_label_idx']]
+
+        mc_mask = []
+        for mask_path, y1x1z1_y2x2z2 in zip(mask_paths, y1x1z1_y2x2z2_ls):
+            mask = torch.zeros((h, w, d))
+            # if the gt is not empty, load the cropped non-zero volume and embed it back at its bounding box
+            if y1x1z1_y2x2z2 != False:
+                y1, x1, z1, y2, x2, z2 = y1x1z1_y2x2z2
+                mask[y1:y2, x1:x2, z1:z2] = torch.tensor(np.load(mask_path))
+            mc_mask.append(mask.float())
+        mc_mask = torch.stack(mc_mask, dim=0)  # n h w d
+
+        return {
+            'dataset_name':datum['dataset'],
+            'sample_id':sample_id,
+            'batched_patches':batched_patches,
+            'batched_y1y2_x1x2_z1z2':batched_y1y2_x1x2_z1z2,
+            'split_labels':split_labels,
+            'modality':modality,
+            'split_n1n2':split_n1n2,
+            'gt_segmentation':mc_mask,
+            'labels':labels,
+            'image_path':datum['renorm_image']
+        }
 
 class Evaluate_Dataset(Dataset):
     def __init__(self, jsonl_file, max_queries=256, batch_size=2, patch_size=[288, 288, 96], evaluated_samples=set()):
diff --git a/evaluate.py b/evaluate.py
index 6820f77..0e49269 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -11,7 +11,7 @@ from pathlib import Path
 
 import torch.distributed as dist
 
-from data.evaluate_dataset import Evaluate_Dataset, collate_fn
+from data.evaluate_dataset import Evaluate_Dataset, Evaluate_Dataset_OnlineCrop, collate_fn
 from model.build_model import build_maskformer, load_checkpoint
 from model.text_encoder import Text_Encoder
 from evaluate.evaluator import evaluate
@@ -67,7 +67,7 @@ def main(args):
             evaluated_samples.add(f'{line[0]}_{line[2]}')
     
     # dataset and loader
-    testset = Evaluate_Dataset(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
+    testset = Evaluate_Dataset_OnlineCrop(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
     sampler = DistributedSampler(testset)
     testloader = DataLoader(testset, sampler=sampler, batch_size=1, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn, shuffle=False)
     sampler.set_epoch(0)
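For reference, `Evaluate_Dataset_OnlineCrop` expects one JSON object per line of the jsonl, carrying at least the fields read in `__getitem__` above. The record below is purely illustrative (the values are invented; the real files come out of the SAT-DS pipeline). Note that `renorm_y1x1z1_y2x2z2` holds one crop box per label, or `false` when that label's ground truth is empty:

```json
{
  "dataset": "ExampleDataset",
  "modality": "CT",
  "renorm_image": "data/processed/case_0001.npy",
  "renorm_segmentation_dir": "data/processed/case_0001_seg",
  "label": ["liver", "left kidney"],
  "renorm_y1x1z1_y2x2z2": [[102, 88, 10, 180, 150, 52], false],
  "chwd": [1, 512, 512, 96]
}
```

With this record, the loader would read the liver mask from `data/processed/case_0001_seg/liver.npy` and embed it at its box inside a zero volume of shape `512 x 512 x 96`, while the empty `left kidney` mask stays all-zero.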
From 7a519f6062f03863243e98c700d5b68b6a110519 Mon Sep 17 00:00:00 2001
From: ziheng-zhao <565295081@qq.com>
Date: Wed, 6 Nov 2024 11:06:43 +0800
Subject: [PATCH 3/3] update the evaluation guidance

---
 README.md              | 14 ++++++------
 evaluate.py            |  9 ++++----
 evaluate/evaluator.py  | 13 ++++++-----
 evaluate/params.py     | 18 ++++++---------
 sh/evaluate_sat_pro.sh | 50 ++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 76 insertions(+), 28 deletions(-)
 create mode 100755 sh/evaluate_sat_pro.sh

diff --git a/README.md b/README.md
index 0e79c67..1aac38f 100644
--- a/README.md
+++ b/README.md
@@ -98,19 +98,19 @@ The input image should be with shape `H,W,D` Our data process code will normaliz
 
 ## Train Guidance:
 Some preparation before start the training:
- 1. you need to build your training data following this [repo](https://github.com/zhaoziheng/SAT-DS/tree/main), a jsonl containing all the training samples is required.
+ 1. you need to build your training data following this [repo](https://github.com/zhaoziheng/SAT-DS/tree/main), specifically steps 1 to 5. A jsonl containing all the training samples is required.
 2. you need to fetch the text encoder checkpoint from https://huggingface.co/zzh99/SAT to generate prompts.
 
 Our recommendation for training SAT-Nano is 8 or more A100-80G, for SAT-Pro is 16 or more A100-80G. Please use the slurm script in `sh/` to start the training process. Take SAT-Pro for example:
 ```
 sbatch sh/train_sat_pro.sh
 ```
 
-
-## TODO
-- [ ] Inference demo on website.
-- [x] Release the data preprocess code to build SAT-DS.
-- [x] Release the train guidance.
+## Evaluation Guidance:
+This also requires building test data following this [repo](https://github.com/zhaoziheng/SAT-DS/tree/main).
+You may refer to the slurm script `sh/evaluate_sat_pro.sh` to start the evaluation process:
+ ```
+ sbatch sh/evaluate_sat_pro.sh
+ ```
 
 ## Citation
 If you use this code for your research or project, please cite:
diff --git a/evaluate.py b/evaluate.py
index 0e49269..a804f09 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -67,7 +67,10 @@ def main(args):
             evaluated_samples.add(f'{line[0]}_{line[2]}')
     
     # dataset and loader
-    testset = Evaluate_Dataset_OnlineCrop(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
+    if args.online_crop:
+        testset = Evaluate_Dataset_OnlineCrop(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
+    else:
+        testset = Evaluate_Dataset(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
     sampler = DistributedSampler(testset)
     testloader = DataLoader(testset, sampler=sampler, batch_size=1, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn, shuffle=False)
     sampler.set_epoch(0)
@@ -106,9 +109,7 @@ def main(args):
                  save_interval=args.save_interval,
                  dice_score=args.dice,
                  nsd_score=args.nsd,
-                 visualization=args.visualization,
-                 region_split_json=args.region_split_json,
-                 label_statistic_json=args.label_statistic_json)
+                 visualization=args.visualization)
 
 if __name__ == '__main__':
     # get configs
diff --git a/evaluate/evaluator.py b/evaluate/evaluator.py
index 8e945a7..7961f0d 100644
--- a/evaluate/evaluator.py
+++ b/evaluate/evaluator.py
@@ -46,9 +46,7 @@ def evaluate(model,
              csv_path,
              resume,
              save_interval,
-             visualization,
-             region_split_json,
-             label_statistic_json):
+             visualization):
     
     # if to store pred、gt、img (as nii.gz
     if visualization:
@@ -376,13 +374,16 @@ def evaluate(model,
                     name = name[len(name)-31:]
                 df.to_excel(writer, sheet_name=name, index=True)
                 
-        avg_dice_over_merged_labels, avg_nsd_over_merged_labels = merge(region_split_json, label_statistic_json, xlsx_path, xlsx_path)
+        # avg_dice_over_merged_labels, avg_nsd_over_merged_labels = merge(region_split_json, label_statistic_json, xlsx_path, xlsx_path)
         
         os.remove(csv_path.replace('.csv', '.pkl'))
         
     else:
         
-        avg_dice_over_merged_labels = avg_nsd_over_merged_labels = 0
-    return avg_dice_over_merged_labels, avg_nsd_over_merged_labels
+        pass
+    
+    # avg_dice_over_merged_labels = avg_nsd_over_merged_labels = 0
+    
+    return # avg_dice_over_merged_labels, avg_nsd_over_merged_labels
\ No newline at end of file
diff --git a/evaluate/params.py b/evaluate/params.py
index f37a1a6..f63cfcb 100644
--- a/evaluate/params.py
+++ b/evaluate/params.py
@@ -64,16 +64,6 @@ def parse_args():
         type=str2bool,
         default=True,
     )
-    parser.add_argument(
-        "--region_split_json",
-        type=str,
-        default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab(72).json',
-    )
-    parser.add_argument(
-        "--label_statistic_json",
-        type=str,
-        default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab_accum_statis(72).json',
-    )
     
     # Med SAM Dataset
     
@@ -84,6 +74,12 @@
     
     # Sampler and Loader
     
+    parser.add_argument(
+        "--online_crop",
+        type=str2bool,
+        default='False',
+        help='crop patches online from the whole image (True), or load pre-cropped patches (False)',
+    )
     parser.add_argument(
         "--crop_size",
         type=int,
@@ -133,7 +129,7 @@
     parser.add_argument(
         "--vision_backbone",
        type=str,
-        help='UNETs UMamba or SwinUNETR'
+        help='UNET, UNET-L, UMamba or SwinUNETR'
     )
     parser.add_argument(
         "--patch_size",
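A subtlety in the new `--online_crop` argument: its default is the string `'False'`, not a boolean. That still works, because argparse applies the `type` callable to string defaults before storing them. A minimal sketch with a hypothetical `str2bool` (the repo's own helper may differ in detail):

```python
import argparse

def str2bool(v):
    # hypothetical stand-in for the repo's str2bool helper
    if isinstance(v, bool):
        return v
    if v.lower() in ('true', 't', 'yes', '1'):
        return True
    if v.lower() in ('false', 'f', 'no', '0'):
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {v!r}')

parser = argparse.ArgumentParser()
parser.add_argument('--online_crop', type=str2bool, default='False')
assert parser.parse_args([]).online_crop is False                       # string default is converted
assert parser.parse_args(['--online_crop', 'True']).online_crop is True
```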
diff --git a/sh/evaluate_sat_pro.sh b/sh/evaluate_sat_pro.sh
new file mode 100755
index 0000000..f7871b4
--- /dev/null
+++ b/sh/evaluate_sat_pro.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+#SBATCH --job-name=eval_pro
+#SBATCH --quotatype=auto
+#SBATCH --partition=medai
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --gpus-per-task=1
+#SBATCH --cpus-per-task=16
+#SBATCH --mem-per-cpu=128G
+#SBATCH --chdir=/mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/log/sbatch
+#SBATCH --output=/mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/log/sbatch/%x-%j.out
+#SBATCH --error=/mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/log/sbatch/%x-%j.error
+###SBATCH -w SH-IDC1-10-140-0-[...], SH-IDC1-10-140-1-[...]
+###SBATCH -x SH-IDC1-10-140-0-[...], SH-IDC1-10-140-1-[...]
+
+export NCCL_DEBUG=INFO
+export NCCL_IBEXT_DISABLE=1
+export NCCL_IB_DISABLE=1
+export NCCL_SOCKET_IFNAME=eth0
+echo NODELIST=${SLURM_NODELIST}
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=$master_addr
+MASTER_PORT=$((RANDOM % 101 + 20000))
+echo "MASTER_ADDR="$MASTER_ADDR
+
+srun torchrun \
+--nnodes 1 \
+--nproc_per_node 1 \
+--rdzv_id 100 \
+--rdzv_backend c10d \
+--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT /mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/evaluate.py \
+--rcd_dir 'your_rcd_dir' \
+--rcd_file 'your_rcd_file_name' \
+--resume False \
+--visualization False \
+--deep_supervision False \
+--datasets_jsonl 'jsonl generated from SAT-DS Step 4' \
+--crop_size 288 288 96 \
+--online_crop True \
+--vision_backbone 'UNET-L' \
+--checkpoint 'your ckpt' \
+--partial_load True \
+--text_encoder 'ours' \
+--text_encoder_checkpoint 'your text encoder ckpt' \
+--batchsize_3d 2 \
+--max_queries 256 \
+--pin_memory False \
+--num_workers 4 \
+--dice True \
+--nsd True
\ No newline at end of file
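The script above hard-codes absolute cluster paths, so on a single-GPU machine without slurm you would presumably strip it down to its bare `torchrun` invocation from the repository root; the quoted values remain placeholders, exactly as in the script:

```bash
torchrun --nnodes 1 --nproc_per_node 1 evaluate.py \
  --rcd_dir 'your_rcd_dir' \
  --rcd_file 'your_rcd_file_name' \
  --datasets_jsonl 'jsonl generated from SAT-DS Step 4' \
  --crop_size 288 288 96 \
  --online_crop True \
  --vision_backbone 'UNET-L' \
  --checkpoint 'your ckpt' \
  --partial_load True \
  --text_encoder 'ours' \
  --text_encoder_checkpoint 'your text encoder ckpt' \
  --batchsize_3d 2 \
  --max_queries 256 \
  --dice True \
  --nsd True
```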