Skip to content

Commit

Permalink
update eval_csn.py, sc_correction_dataset.py and train_csn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dummyindex committed Mar 4, 2025
1 parent 4ba6b57 commit a842cdd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 2 additions & 0 deletions livecellx/model_zoo/segmentation/eval_csn.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def assemble_dataset(
input_type=None,
use_gt_pixel_weight=False,
normalize_uint8=False,
normalize_gt_edt=False,
):
assert input_type is not None
raw_img_paths = list(df["raw"])
Expand Down Expand Up @@ -63,6 +64,7 @@ def assemble_dataset(
raw_df=df,
use_gt_pixel_weight=use_gt_pixel_weight,
normalize_uint8=normalize_uint8,
normalize_gt_edt=normalize_gt_edt,
)
return dataset

Expand Down
10 changes: 5 additions & 5 deletions livecellx/model_zoo/segmentation/sc_correction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
bg_val=0,
use_gt_pixel_weight=False,
force_no_edt_aug=False,
normalize_gt_mask=False,
normalize_gt_edt=False,
):
"""_summary_
Expand Down Expand Up @@ -155,7 +155,7 @@ def __init__(
]

self.force_no_edt_aug = force_no_edt_aug
self.normalize_gt_mask = normalize_gt_mask
self.normalize_gt_edt = normalize_gt_edt

def get_raw_seg(self, idx) -> np.ndarray:
return np.array(Image.open(self.raw_seg_paths[idx]))
Expand Down Expand Up @@ -269,7 +269,7 @@ def prepare_and_augment_data(
force_no_edt_aug: bool,
apply_gt_seg_edt: bool,
transform=None,
normalize_gt_mask=False,
normalize_gt_edt=False,
):
"""
Receive the loaded data (in PIL/np form), convert to tensors, and apply
Expand Down Expand Up @@ -422,7 +422,7 @@ def prepare_and_augment_data(
else:
gt_mask_edt = gt_label_edt

if normalize_gt_mask:
if normalize_gt_edt:
gt_mask_edt = normalize_edt(gt_mask_edt, edt_max=5)

aug_diff_overseg = aug_diff_img < 0
Expand Down Expand Up @@ -477,7 +477,7 @@ def __getitem__(self, idx: int):
self.force_no_edt_aug,
self.apply_gt_seg_edt,
self.transform,
normalize_gt_mask=self.normalize_gt_mask,
normalize_gt_edt=self.normalize_gt_edt,
)

# Attach the index
Expand Down
6 changes: 3 additions & 3 deletions livecellx/model_zoo/segmentation/train_csn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def parse_args():
parser.add_argument("--aux-loss-weight", default=0.5, type=float)
parser.add_argument("--normalize_uint8", default=False, action="store_true")
parser.add_argument("--torch_seed", default=237, type=int)
parser.add_argument("--normalize_gt_mask", default=False, action="store_true")
parser.add_argument("--normalize_gt_edt", default=False, action="store_true")

args = parser.parse_args()

# convert string to list
Expand Down Expand Up @@ -183,7 +183,7 @@ def df2dataset(df):
raw_df=df,
subdirs=subdirs,
use_gt_pixel_weight=args.use_gt_pixel_weight,
normalize_gt_mask=args.normalize_gt_mask,
normalize_gt_edt=args.normalize_gt_edt,
)
return dataset

Expand Down

0 comments on commit a842cdd

Please sign in to comment.