Skip to content

Commit

Permalink
Merge branch 'maisi' of https://github.com/Can-Zhao/model-zoo into maisi
Browse files Browse the repository at this point in the history
  • Loading branch information
Can-Zhao committed Feb 12, 2025
2 parents dbe8612 + 7099ceb commit b0ca42c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 15 deletions.
8 changes: 6 additions & 2 deletions models/maisi_ct_generative/configs/inference.json
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,11 @@
"num_res_blocks": 2,
"use_flash_attention": true,
"conditioning_embedding_in_channels": 8,
"conditioning_embedding_num_channels": [8, 32, 64],
"conditioning_embedding_num_channels": [
8,
32,
64
],
"num_class_embeds": 128,
"resblock_updown": true,
"include_fc": true
Expand Down Expand Up @@ -247,7 +251,7 @@
"use_discrete_timesteps": false,
"use_timestep_transform": true,
"sample_method": "logit-normal",
"scale":1.2
"scale": 1.2
},
"mask_generation_noise_scheduler": {
"_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
Expand Down
14 changes: 7 additions & 7 deletions models/maisi_ct_generative/scripts/rectified_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,16 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
f" maximal {self.num_train_timesteps} timesteps."
)

self.num_inference_steps = num_inference_steps
self.num_inference_steps = num_inference_steps
# prepare timesteps
timesteps = [(1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)]
timesteps = [(1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)]
if self.use_discrete_timesteps:
timesteps = [int(round(t)) for t in timesteps]
if self.use_timestep_transform:
timesteps = [timestep_transform(t, input_img_size=input_img_size, base_img_size=base_img_size, num_train_timesteps=self.num_train_timesteps) for t in timesteps]
timesteps = np.array(timesteps).astype(np.float16)
timesteps = np.array(timesteps).astype(np.float16)
if self.use_discrete_timesteps:
timesteps = timesteps.astype(np.int64)
timesteps = timesteps.astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps += self.steps_offset
print(self.timesteps)
Expand All @@ -119,12 +119,12 @@ def sample_timesteps(self, x_start):
t = t.long()

if self.use_timestep_transform:
input_img_size = torch.prod(torch.tensor(x_start.shape[-3:]))
input_img_size = torch.prod(torch.tensor(x_start.shape[-3:]))
base_img_size = 32*32*32
t = timestep_transform(t, input_img_size=input_img_size, base_img_size=base_img_size, num_train_timesteps=self.num_train_timesteps)

return t

def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep = None) -> tuple[torch.Tensor, Any]:
"""
Predict the sample at the previous timestep. Core function to propagate the diffusion
Expand Down
6 changes: 3 additions & 3 deletions models/maisi_ct_generative/scripts/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"ct":1,
"ct_wo_contrast":2,
"ct_contrast":3,
"mri":8,
"mri":8,
"mri_t1":9,
"mri_t2":10,
"mri_flair":11,
Expand Down Expand Up @@ -248,7 +248,7 @@ def ldm_conditional_sample_one_image(
guidance_scale = 0 # API for classifier-free guidence, not used in this version
all_next_timesteps = torch.cat((noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype)))
for t, next_t in tqdm(zip(noise_scheduler.timesteps, all_next_timesteps), total=min(len(noise_scheduler.timesteps), len(all_next_timesteps))):
timesteps = torch.Tensor((t,)).to(device)
timesteps = torch.Tensor((t,)).to(device)
if guidance_scale == 0:
down_block_res_samples, mid_block_res_sample = controlnet(
x=latents, timesteps=timesteps, controlnet_cond=controlnet_cond_vis,
Expand All @@ -261,7 +261,7 @@ def ldm_conditional_sample_one_image(
class_labels = modality_tensor,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)
)
else:
down_block_res_samples, mid_block_res_sample = controlnet(
x=torch.cat([latents] * 2), timesteps=torch.cat([timesteps] * 2), controlnet_cond=torch.cat([controlnet_cond_vis] * 2),
Expand Down
5 changes: 2 additions & 3 deletions models/maisi_ct_generative/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,15 +682,14 @@ def dynamic_infer(inferer, model, images):
# Extract the spatial dimensions from the images tensor (H, W, D)
spatial_dims = images.shape[2:]
orig_roi = inferer.roi_size

# Check that roi has the same number of dimensions as spatial_dims
if len(orig_roi) != len(spatial_dims):
raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).")

# Iterate and adjust each ROI dimension
adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)]
inferer.roi_size = adjusted_roi
output = inferer(network=model, inputs=images)
inferer.roi_size = orig_roi
return output

0 comments on commit b0ca42c

Please sign in to comment.