Skip to content

Commit

Permalink
add pixel-wise weights for boundaries; update augV5
Browse files Browse the repository at this point in the history
  • Loading branch information
dummyindex committed Mar 29, 2024
1 parent 90e6fdc commit e40274f
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 10 deletions.
4 changes: 2 additions & 2 deletions livecellx/model_zoo/segmentation/csn_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def gen_train_transform_v5(
transforms.RandomVerticalFlip(),
transforms.RandomAffine(degrees=degrees, translate=translation_range, scale=scale, shear=10),
transforms.GaussianBlur(kernel_size=3),
transforms.Resize((412, 412)),
transforms.Normalize([0.485], [0.229]),
transforms.Resize((256, 256)),
transforms.Normalize([127], [30]),
]
)
return train_transforms
5 changes: 4 additions & 1 deletion livecellx/model_zoo/segmentation/eval_csn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from livecellx.model_zoo.segmentation.sc_correction_dataset import CorrectSegNetDataset


def assemble_dataset(df: pd.DataFrame, apply_gt_seg_edt=False, exclude_raw_input_bg=False, input_type=None):
def assemble_dataset(
df: pd.DataFrame, apply_gt_seg_edt=False, exclude_raw_input_bg=False, input_type=None, use_gt_pixel_weight=False
):
assert input_type is not None
raw_img_paths = list(df["raw"])
scaled_seg_mask_paths = list(df["seg"])
Expand All @@ -47,6 +49,7 @@ def assemble_dataset(df: pd.DataFrame, apply_gt_seg_edt=False, exclude_raw_input
exclude_raw_input_bg=exclude_raw_input_bg,
input_type=input_type,
raw_df=df,
use_gt_pixel_weight=use_gt_pixel_weight,
)
return dataset

Expand Down
79 changes: 72 additions & 7 deletions livecellx/model_zoo/segmentation/sc_correction_aux.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,43 @@

LOG_PROGRESS_BAR = False

import torch
import torch.nn.functional as F


def weighted_mse_loss(predict, target, weights=None):
    """
    Compute the MSE loss with an optional pixel-wise weight map on the first channel.

    Parameters:
    - predict: Tensor of predicted values (batch_size, channels, height, width).
    - target: Tensor of target values with the same shape as predict.
    - weights: Optional. Weight map for the first channel, given either as
               (batch_size, 1, height, width) or, without the channel axis, as
               (batch_size, height, width). If a 4-D tensor has more than one
               channel, only its channel 0 is used.
               If None, no weights are applied and standard MSE loss is calculated.

    Returns:
    - loss: Scalar tensor. With weights, this is the mean of the weighted squared
            differences over ALL elements of predict (it is not divided by the
            sum of the weights).
    """
    if weights is None:
        # If no weights are provided, calculate standard MSE loss
        return F.mse_loss(predict, target, reduction="mean")

    # Accept a channel-less (batch, height, width) map as well as the
    # (batch, 1, height, width) layout, so callers holding a per-pixel map
    # without a channel axis do not have to unsqueeze it first.
    if weights.dim() == 3:
        first_channel_weights = weights
    else:
        first_channel_weights = weights[:, 0, :, :]

    # Weights for channels other than the first are fixed to 1
    expanded_weights = torch.ones_like(predict)
    expanded_weights[:, 0, :, :] = first_channel_weights

    # Weight the squared differences, then average over every element
    weighted_squared_diff = ((predict - target) ** 2) * expanded_weights
    return weighted_squared_diff.mean()


class CorrectSegNetAux(LightningModule):
def __init__(
Expand All @@ -32,7 +69,7 @@ def __init__(
batch_size=5,
class_weights=[1, 1, 1],
model_type=None,
num_workers=16,
num_workers=32,
train_input_paths=None,
train_transforms=None,
seed=99,
Expand Down Expand Up @@ -151,7 +188,9 @@ def forward(self, x: torch.Tensor):
else:
return x

def compute_loss(self, output: torch.tensor, target: torch.tensor, aux_out=None, aux_target=None):
def compute_loss(
self, output: torch.tensor, target: torch.tensor, aux_out=None, aux_target=None, gt_pixel_weight=None
):
"""Compute loss function
Parameters
Expand All @@ -178,27 +217,53 @@ def compute_loss(self, output: torch.tensor, target: torch.tensor, aux_out=None,
), "seg_output shape should be batch_size x num_classes x height x width, got %s" % str(seg_output.shape)

if self.loss_type == "CE":
return self.loss_func(seg_output, target), aux_loss
seg_loss = self.loss_func(seg_output, target)
elif self.loss_type == "MSE":
total_loss = 0
num_classes = seg_output.shape[1]
for cat_dim in range(0, num_classes):
temp_target = target[:, cat_dim, ...]
temp_output = seg_output[:, cat_dim, ...]
total_loss += self.loss_func(temp_output, temp_target) * self.class_weights[cat_dim]
return total_loss, aux_loss
total_loss += weighted_mse_loss(temp_output, temp_target) * self.class_weights[cat_dim]
seg_loss = total_loss
elif self.loss_type == "BCE":
# # Debugging
# print("*" * 40)
# print("Dimensions:")
# print("seg_output shape: ", seg_output.shape)
# print("target shape: ", target.shape)
# print("*" * 40)
# if gt_pixel_weight is not None:
# print("gt_pixel_weight shape: ", gt_pixel_weight.shape)
if gt_pixel_weight is not None:
# Repeat to match 3 channels of gt (seg and two OU masks): gt_pixel_weight shape: 2, 412, 412 -> 2, 3, 412, 412
gt_pixel_weight_repeated = gt_pixel_weight.unsqueeze(1).repeat(1, 3, 1, 1)
# assert len(gt_pixel_weight_repeated.shape) == 4
gt_pixel_weight_permuted = gt_pixel_weight_repeated.permute(0, 2, 3, 1)
else:
gt_pixel_weight_permuted = None
seg_output = seg_output.permute(0, 2, 3, 1)
target = target.permute(0, 2, 3, 1)
return self.loss_func(seg_output, target), aux_loss
self.loss_func = torch.nn.BCEWithLogitsLoss(
weight=gt_pixel_weight_permuted, pos_weight=torch.tensor(self.class_weights).cuda()
)

seg_loss = self.loss_func(seg_output, target)
else:
raise NotImplementedError("Loss:%s not implemented", self.loss_type)

return seg_loss, aux_loss

def training_step(self, batch, batch_idx):
# print("[train_step] x shape: ", batch["input"].shape)
# print("[train_step] y shape: ", batch["gt_mask"].shape)
x, y = batch["input"], batch["gt_mask"]
aux_target = batch["ou_aux"]
gt_pixel_weight = batch["gt_pixel_weight"]
output, aux_out = self(x)
seg_loss, aux_loss = self.compute_loss(output, y, aux_out=aux_out, aux_target=aux_target)
seg_loss, aux_loss = self.compute_loss(
output, y, aux_out=aux_out, aux_target=aux_target, gt_pixel_weight=gt_pixel_weight
)
loss = seg_loss + self.aux_loss_weight * aux_loss
predicted_labels = torch.argmax(output, dim=1)
self.log(
Expand Down
20 changes: 20 additions & 0 deletions livecellx/model_zoo/segmentation/sc_correction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
raw_df=None,
normalize_uint8=False,
bg_val=0,
use_gt_pixel_weight=False,
):
"""_summary_
Expand Down Expand Up @@ -114,6 +115,14 @@ def __init__(
print("whether to normalize_uint8:", self.normalize_uint8)
self.bg_val = bg_val

self.use_gt_pixel_weight = use_gt_pixel_weight
print("whether to use_gt_pixel_weight:", self.use_gt_pixel_weight)
if self.use_gt_pixel_weight:
self.gt_pixel_weight_paths = [
str(Path(path).parent.parent / "gt_pixel_weight" / (str(Path(path).stem) + "_weight.npy"))
for path in self.gt_mask_paths
]

def get_raw_seg(self, idx) -> np.array:
return np.array(Image.open(self.raw_seg_paths[idx]))

Expand Down Expand Up @@ -156,6 +165,14 @@ def __getitem__(self, idx):
gt_label_mask__np = np.array(Image.open(self.gt_label_mask_paths[idx]))
gt_label_mask = torch.tensor(gt_label_mask__np.copy()).long()

if self.use_gt_pixel_weight:
# Read the pixel weight map from the <gt_pixel_weight> subfolder. weights are in npy format
gt_pixel_weight = np.load(self.gt_pixel_weight_paths[idx])
else:
# Ones for all pixels
gt_pixel_weight = np.ones_like(gt_label_mask__np)
gt_pixel_weight = torch.tensor(gt_pixel_weight).float()

# transform to edt for inputs before augmentation
if self.input_type == "edt_v0":
scaled_seg_mask = self.label_mask_to_edt(scaled_seg_mask)
Expand All @@ -168,6 +185,7 @@ def __getitem__(self, idx):
gt_mask.float(),
aug_diff_img,
gt_label_mask,
gt_pixel_weight,
],
dim=0,
)
Expand All @@ -178,6 +196,7 @@ def __getitem__(self, idx):
augmented_raw_transformed_img = concat_img[1]
augmented_scaled_seg_mask = concat_img[2]
augmented_gt_label_mask = concat_img[5].long()
augmented_gt_pixel_weight = concat_img[6]

if self.input_type == "raw_aug_seg":
input_img = torch.stack(
Expand Down Expand Up @@ -257,6 +276,7 @@ def __getitem__(self, idx):
"idx": idx,
"gt_label_mask": augmented_gt_label_mask,
"ou_aux": ou_aux,
"gt_pixel_weight": augmented_gt_pixel_weight,
}
if self.apply_gt_seg_edt:
res["gt_mask_edt"] = gt_mask_edt
Expand Down
4 changes: 4 additions & 0 deletions livecellx/model_zoo/segmentation/train_csn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def parse_args():
)
parser.add_argument("--ou_aux", dest="ou_aux", default=False, action="store_true")
parser.add_argument("--aug-ver", default="v0", type=str, help="The version of the augmentation to use.")
parser.add_argument("--use-gt-pixel-weight", default=False, action="store_true")

args = parser.parse_args()

Expand Down Expand Up @@ -124,6 +125,8 @@ def main_train():
train_transforms = csn_configs.gen_train_transform_v3(degrees, translation_range, args.aug_scale)
elif args.aug_ver == "v4":
train_transforms = csn_configs.gen_train_transform_v4(degrees, translation_range, args.aug_scale)
elif args.aug_ver == "v5":
train_transforms = csn_configs.gen_train_transform_v5(degrees, translation_range, args.aug_scale)
else:
raise ValueError("Unknown augmentation version")

Expand Down Expand Up @@ -154,6 +157,7 @@ def df2dataset(df):
exclude_raw_input_bg=args.exclude_raw_input_bg,
raw_df=df,
subdirs=subdirs,
use_gt_pixel_weight=args.use_gt_pixel_weight,
)
return dataset

Expand Down

0 comments on commit e40274f

Please sign in to comment.