Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/zhaoziheng/SAT
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoziheng committed Dec 24, 2024
2 parents b632222 + 7a519f6 commit f77fc20
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 25 deletions.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,20 @@ The input image should be with shape `H,W,D` Our data process code will normaliz
## Train Guidance:
Some preparation before start the training:
1. you need to build your training data following this [repo](https://github.com/zhaoziheng/SAT-DS/tree/main), a jsonl containing all the training samples is required.
1. you need to build your training data following this [repo](https://github.com/zhaoziheng/SAT-DS/tree/main), specifically, from step 1 to step 5. A jsonl containing all the training samples is required.
2. you need to fetch the text encoder checkpoint from https://huggingface.co/zzh99/SAT to generate prompts.
Our recommendation for training SAT-Nano is 8 or more A100-80G, for SAT-Pro is 16 or more A100-80G. Please use the slurm script in `sh/` to start the training process. Take SAT-Pro for example:
```
sbatch sh/train_sat_pro.sh
```
## Evaluation Guidance:
This also requires to build test data following this [repo](https://github.com/zhaoziheng/SAT-DS/tree/main).
You may refer to the slurm script `sh/evaluate_sat_pro.sh` to start the evaluation process:
```
sbatch sh/evaluate_sat_pro.sh
```
## Baselines
We provide the detailed configurations of all the specialist models (nnU-Nets, U-Mambas, SwinUNETR) we have trained and evaluated [here](https://github.com/zhaoziheng/SAT-DS/blob/main/data/specialist_model_config).
Expand Down
177 changes: 176 additions & 1 deletion data/evaluate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
import json
import traceback
import math

from einops import rearrange, repeat, reduce
import numpy as np
Expand All @@ -22,7 +23,181 @@ def contains(text, key):
for k in key:
if k in text:
return True
return False
return False

def split_3d(image_tensor, crop_size=[288, 288, 96]):
# C H W D
interval_h, interval_w, interval_d = crop_size[0] // 2, crop_size[1] // 2, crop_size[2] // 2
split_idx = []
split_patch = []

c, h, w, d = image_tensor.shape
h_crop = max(math.ceil(h / interval_h) - 1, 1)
w_crop = max(math.ceil(w / interval_w) - 1, 1)
d_crop = max(math.ceil(d / interval_d) - 1, 1)

for i in range(h_crop):
h_s = i * interval_h
h_e = h_s + crop_size[0]
if h_e > h:
h_s = h - crop_size[0]
h_e = h
if h_s < 0:
h_s = 0
for j in range(w_crop):
w_s = j * interval_w
w_e = w_s + crop_size[1]
if w_e > w:
w_s = w - crop_size[1]
w_e = w
if w_s < 0:
w_s = 0
for k in range(d_crop):
d_s = k * interval_d
d_e = d_s + crop_size[2]
if d_e > d:
d_s = d - crop_size[2]
d_e = d
if d_s < 0:
d_s = 0
split_idx.append([h_s, h_e, w_s, w_e, d_s, d_e])
split_patch.append(image_tensor[:, h_s:h_e, w_s:w_e, d_s:d_e])

return split_patch, split_idx

class Evaluate_Dataset_OnlineCrop(Dataset):
def __init__(self, jsonl_file, max_queries=256, batch_size=2, patch_size=[288, 288, 96], evaluated_samples=set()):
"""
max_queries: num of queries in a batch. can be very large.
batch_size: num of image patch in a batch. be careful with this if you have limited gpu memory.
evaluated_samples: to resume from an interrupted evaluation
"""
# load data info
self.jsonl_file = jsonl_file
with open(self.jsonl_file, 'r') as f:
lines = f.readlines()
lines = [json.loads(line) for line in lines]

self.lines = []

for sample in lines:
# if resume and inherit medial results another evaluation
sample_id = sample['renorm_image'].split('/')[-1][:-4] # abcd/x.npy --> x
dataset_name = sample['dataset']
if f'{dataset_name}_{sample_id}' not in evaluated_samples:
self.lines.append(sample)

self.max_queries = max_queries
self.batch_size = batch_size
self.patch_size = patch_size

if is_master():
print(f'** Online Crop DATASET ** : Skip {len(lines)-len(self.lines)} samples, {len(self.lines)} to be evaluated')
print(f'** Online Crop DATASET ** : Maximum {self.max_queries} queries, patch size {self.patch_size}')

def __len__(self):
return len(self.lines)

def _split_labels(self, label_list):
# split the labels into sub-lists
if len(label_list) < self.max_queries:
return [label_list], [[0, len(label_list)]]
else:
split_idx = []
split_label = []
query_num = len(label_list)
n_crop = (query_num // self.max_queries + 1) if (query_num % self.max_queries != 0) else (query_num // self.max_queries)
for n in range(n_crop):
n_s = n*self.max_queries
n_f = min((n+1)*self.max_queries, query_num)
split_label.append(label_list[n_s:n_f])
split_idx.append([n_s, n_f])
return split_label, split_idx

def _merge_modality(self, mod):
if contains(mod, ['t1', 't2', 'mri', 'mr', 'flair', 'dwi']):
return 'mri'
if contains(mod, 'ct'):
return 'ct'
if contains(mod, 'pet'):
return 'pet'
else:
return mod

def _pad_if_necessary(self, patch):
# NOTE: depth must be pad to 96
b, c, h, w, d = patch.shape
t_h, t_w, t_d = self.patch_size
pad_in_h = 0 if h >= t_h else t_h - h
pad_in_w = 0 if w >= t_w else t_w - w
pad_in_d = 0 if d >= t_d else t_d - d
if pad_in_h + pad_in_w + pad_in_d > 0:
pad = (0, pad_in_d, 0, pad_in_w, 0, pad_in_h)
patch = F.pad(patch, pad, 'constant', 0) # chwd
return patch

def __getitem__(self, idx):
datum = self.lines[idx]
sample_id = datum['renorm_image'].split('/')[-1][:-4] # abcd/x.npy --> x

# image to patches
img = torch.tensor(np.load(datum['renorm_image']))

patches, y1y2_x1x2_z1z2_ls = split_3d(img, crop_size=self.patch_size)

# divide patches into batches
batch_num = len(patches) // self.batch_size if len(patches) % self.batch_size == 0 else len(patches) // self.batch_size + 1
batched_patches = []
batched_y1y2_x1x2_z1z2 = []
for i in range(batch_num):
srt = i*self.batch_size
end = min(i*self.batch_size+self.batch_size, len(patches))
patch = torch.stack([patches[j] for j in range(srt, end)], dim=0)
# NOTE: depth must be pad to 96
patch = self._pad_if_necessary(patch)
# for single-channel images, e.g. mri and ct, pad to 3
# repeat sc image to mc
if patch.shape[1] == 1:
patch = repeat(patch, 'b c h w d -> b (c r) h w d', r=3)
batched_patches.append(patch) # b, *patch_size
batched_y1y2_x1x2_z1z2.append([y1y2_x1x2_z1z2_ls[j] for j in range(srt, end)])

# split labels into batches
labels = datum['label']
split_labels, split_n1n2 = self._split_labels(labels) # [xxx, ...] [[n1, n2], ...]
modality = datum['modality']
modality = self._merge_modality(modality.lower())
for i in range(len(split_labels)):
split_labels[i] = [label.lower() for label in split_labels[i]]

# load gt segmentations
c,h,w,d = datum['chwd']
# labels = [datum['label'][i] for i in datum['visible_label_idx']] # laryngeal cancer or hypopharyngeal cancer
mask_paths = [f"{datum['renorm_segmentation_dir']}/{label}.npy" for label in labels]
y1x1z1_y2x2z2_ls = datum['renorm_y1x1z1_y2x2z2'] # [datum['renorm_y1x1z1_y2x2z2'][i] for i in datum['visible_label_idx']]

mc_mask = []
for mask_path, y1x1z1_y2x2z2 in zip(mask_paths, y1x1z1_y2x2z2_ls):
mask = torch.zeros((h, w, d))
# not empty, load and embed non-empty cropped_volume
if y1x1z1_y2x2z2 != False:
y1, x1, z1, y2, x2, z2 = y1x1z1_y2x2z2
mask[y1:y2, x1:x2, z1:z2] = torch.tensor(np.load(mask_path))
mc_mask.append(mask.float())
mc_mask = torch.stack(mc_mask, dim=0) # n h w d

return {
'dataset_name':datum['dataset'],
'sample_id':sample_id,
'batched_patches':batched_patches,
'batched_y1y2_x1x2_z1z2':batched_y1y2_x1x2_z1z2,
'split_labels':split_labels,
'modality':modality,
'split_n1n2':split_n1n2,
'gt_segmentation':mc_mask,
'labels':labels,
'image_path':datum['renorm_image']
}

class Evaluate_Dataset(Dataset):
def __init__(self, jsonl_file, max_queries=256, batch_size=2, patch_size=[288, 288, 96], evaluated_samples=set()):
Expand Down
11 changes: 6 additions & 5 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pathlib import Path
import torch.distributed as dist

from data.evaluate_dataset import Evaluate_Dataset, collate_fn
from data.evaluate_dataset import Evaluate_Dataset, Evaluate_Dataset_OnlineCrop, collate_fn
from model.build_model import build_maskformer, load_checkpoint
from model.text_encoder import Text_Encoder
from evaluate.evaluator import evaluate
Expand Down Expand Up @@ -67,7 +67,10 @@ def main(args):
evaluated_samples.add(f'{line[0]}_{line[2]}')

# dataset and loader
testset = Evaluate_Dataset(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
if args.online_crop:
testset = Evaluate_Dataset_OnlineCrop(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
else:
testset = Evaluate_Dataset(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
sampler = DistributedSampler(testset)
testloader = DataLoader(testset, sampler=sampler, batch_size=1, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn, shuffle=False)
sampler.set_epoch(0)
Expand Down Expand Up @@ -106,9 +109,7 @@ def main(args):
save_interval=args.save_interval,
dice_score=args.dice,
nsd_score=args.nsd,
visualization=args.visualization,
region_split_json=args.region_split_json,
label_statistic_json=args.label_statistic_json)
visualization=args.visualization)

if __name__ == '__main__':
# get configs
Expand Down
13 changes: 7 additions & 6 deletions evaluate/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def evaluate(model,
csv_path,
resume,
save_interval,
visualization,
region_split_json,
label_statistic_json):
visualization):

# if to store pred、gt、img (as nii.gz
if visualization:
Expand Down Expand Up @@ -376,13 +374,16 @@ def evaluate(model,
name = name[len(name)-31:]
df.to_excel(writer, sheet_name=name, index=True)

avg_dice_over_merged_labels, avg_nsd_over_merged_labels = merge(region_split_json, label_statistic_json, xlsx_path, xlsx_path)
# avg_dice_over_merged_labels, avg_nsd_over_merged_labels = merge(region_split_json, label_statistic_json, xlsx_path, xlsx_path)

os.remove(csv_path.replace('.csv', '.pkl'))

else:
avg_dice_over_merged_labels = avg_nsd_over_merged_labels = 0

return avg_dice_over_merged_labels, avg_nsd_over_merged_labels
pass

# avg_dice_over_merged_labels = avg_nsd_over_merged_labels = 0

return # avg_dice_over_merged_labels, avg_nsd_over_merged_labels


18 changes: 7 additions & 11 deletions evaluate/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,6 @@ def parse_args():
type=str2bool,
default=True,
)
parser.add_argument(
"--region_split_json",
type=str,
default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab(72).json',
)
parser.add_argument(
"--label_statistic_json",
type=str,
default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab_accum_statis(72).json',
)

# Med SAM Dataset

Expand All @@ -84,6 +74,12 @@ def parse_args():

# Sampler and Loader

parser.add_argument(
"--online_crop",
type=str2bool,
default='False',
help='load pre-cropped image patches directly, or crop online',
)
parser.add_argument(
"--crop_size",
type=int,
Expand Down Expand Up @@ -133,7 +129,7 @@ def parse_args():
parser.add_argument(
"--vision_backbone",
type=str,
help='UNETs UMamba or SwinUNETR'
help='UNET UNET-L UMamba or SwinUNETR'
)
parser.add_argument(
"--patch_size",
Expand Down
2 changes: 1 addition & 1 deletion inference.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def set_seed(config):
def main(args):
# set gpu
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.pin_memory
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
device=torch.device("cuda", int(os.environ["LOCAL_RANK"]))
Expand Down
50 changes: 50 additions & 0 deletions sh/evaluate_sat_pro.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/bin/bash
#SBATCH --job-name=eval_pro
#SBATCH --quotatype=auto
#SBATCH --partition=medai
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=16
#SBATCH --mem-per-cpu=128G
#SBATCH --chdir=/mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/log/sbatch
#SBATCH --output=/mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/log/sbatch/%x-%j.out
#SBATCH --error=/mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/log/sbatch/%x-%j.error
###SBATCH -w SH-IDC1-10-140-0-[...], SH-IDC1-10-140-1-[...]
###SBATCH -x SH-IDC1-10-140-0-[...], SH-IDC1-10-140-1-[...]

export NCCL_DEBUG=INFO
export NCCL_IBEXT_DISABLE=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=eth0
echo NODELIST=${SLURM_NODELIST}
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
MASTER_PORT=$((RANDOM % 101 + 20000))
echo "MASTER_ADDR="$MASTER_ADDR

srun torchrun \
--nnodes 1 \
--nproc_per_node 1 \
--rdzv_id 100 \
--rdzv_backend c10d \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT /mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/evaluate.py \
--rcd_dir 'your_rcd_dir' \
--rcd_file 'your_rcd_file_name' \
--resume False \
--visualization False \
--deep_supervision False \
--datasets_jsonl 'jsonl generated from SAT-DS Step 4' \
--crop_size 288 288 96 \
--online_crop True \
--vision_backbone 'UNET-L' \
--checkpoint 'your ckpt' \
--partial_load True \
--text_encoder 'ours' \
--text_encoder_checkpoint 'your text encoder ckpt' \
--batchsize_3d 2 \
--max_queries 256 \
--pin_memory False \
--num_workers 4 \
--dice True \
--nsd True

0 comments on commit f77fc20

Please sign in to comment.