Skip to content

Commit

Permalink
Refactor ou_viz.py: optimize EDT input processing and visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
dummyindex committed Jun 4, 2024
1 parent b525107 commit c10e01b
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions livecellx/segment/ou_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,24 @@ def viz_ou_outputs(
title=None,
input_type="raw_aug_duplicate",
edt_mask=None,
edt_transform=None,
) -> Tuple:
original_shape = augmented_ou_crop.shape
original_ou_input = augmented_ou_crop.copy()
ou_input = input_transforms(torch.tensor([augmented_ou_crop]))
augmented_ou_crop = torch.tensor([augmented_ou_crop])
augmented_ou_crop = input_transforms(augmented_ou_crop).squeeze()
ou_input = None
if input_type == "raw_aug_duplicate":
ou_input = torch.stack([ou_input, ou_input, ou_input], dim=1)
ou_input = torch.stack([augmented_ou_crop, augmented_ou_crop, augmented_ou_crop], dim=0)
elif input_type == "edt_v0":
# normalize_edt(augmented_scaled_seg_mask, edt_max=4)
assert edt_mask is not None
ou_input = torch.stack([ou_input, ou_input, edt_mask], dim=1)
assert edt_mask is not None and edt_transform is not None
# Transform edt_mask to tensor
edt_mask = torch.tensor([edt_mask]).squeeze().unsqueeze(0)
edt_mask = edt_transform(edt_mask).squeeze()
ou_input = torch.stack([augmented_ou_crop, augmented_ou_crop, edt_mask], dim=0)

ou_input = ou_input.unsqueeze(0) # For batch size 1
ou_input = ou_input.float().cuda()
if has_aux:
seg_output, aux_output = model(ou_input)
Expand Down Expand Up @@ -122,11 +129,12 @@ def viz_ou_outputs(

# visualize the input and all 3 output channels
if show or (save_path is not None):
total_figs = 7
if original_img is not None:
num_figures = 8
else:
num_figures = 7
fig, axes = plt.subplots(1, num_figures, figsize=(15, 5))
total_figs += 1
if input_type == "edt_v0":
total_figs += 1
fig, axes = plt.subplots(1, total_figs, figsize=(15, 5))
axes[0].imshow(original_ou_input)
axes[0].set_title("input")
axes[1].imshow(seg_output[0, 0].cpu().detach().numpy())
Expand All @@ -141,9 +149,16 @@ def viz_ou_outputs(
axes[5].set_title("output c0 > 1")
axes[6].imshow(watershed_mask)
axes[6].set_title("watershed mask")

pos = 6
if original_img is not None:
pos += 1
axes[7].imshow(enhance_contrast(normalize_img_to_uint8(original_img)))
axes[7].set_title("original img")
if input_type == "edt_v0":
pos += 1
axes[pos].imshow(edt_mask)
axes[pos].set_title("edt_mask")
if title:
plt.suptitle(title)
if show:
Expand Down

0 comments on commit c10e01b

Please sign in to comment.