diff --git a/book/chapters/masking_distributed.ipynb b/book/chapters/masking_distributed.ipynb
index 23a37bc..7ef8b83 100644
--- a/book/chapters/masking_distributed.ipynb
+++ b/book/chapters/masking_distributed.ipynb
@@ -973,14 +973,22 @@
     "import monai\n",
     "import torchvision.ops as ops\n",
     "\n",
-    "def sam_lc_loss(pred_mask, ground_truth):\n",
-    "    dice_loss = dice_loss_fn(pred_mask, ground_truth)\n",
-    "    focal_loss = ops.sigmoid_focal_loss(pred_mask, ground_truth, reduction=\"mean\")\n",
-    "    combined_loss = (19 * focal_loss / 20) + (1 * dice_loss / 20) # Ratio based on SAM authors' findings.\n",
-    "    return combined_loss\n",
+    "def focal_loss_fn(pred_mask, ground_truth):\n",
+    "    return ops.sigmoid_focal_loss(pred_mask, ground_truth, reduction=\"mean\")\n",
     "\n",
     "dice_loss_fn = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction=\"mean\") # type: ignore\n",
     "\n",
+    "def get_linear_comb_loss(focal_loss_ratio):\n",
+    "    dice_loss_ratio = 1 - focal_loss_ratio\n",
+    "\n",
+    "    def lc_loss(pred_mask, ground_truth):\n",
+    "        dice_loss = dice_loss_fn(pred_mask, ground_truth)\n",
+    "        focal_loss = focal_loss_fn(pred_mask, ground_truth)\n",
+    "        combined_loss = (focal_loss_ratio * focal_loss) + (dice_loss_ratio * dice_loss)\n",
+    "        return combined_loss\n",
+    "\n",
+    "    return lc_loss\n",
+    "\n",
     "class LossFunctionsConfig():\n",
     "    def __init__(self, names: list[str], funcs: list):\n",
     "        self.names = names\n",
@@ -1348,18 +1356,34 @@
     "\n",
     "print(f\"Training started at {get_current_time()}\")\n",
     "\n",
-    "# All provided loss functions will be measured.\n",
-    "# model is trained on the first loss function.\n",
+    "# All provided loss functions are measured; the model is\n",
+    "# trained (via backprop) only on the first loss function.\n",
     "loss_config = LossFunctionsConfig(\n",
-    "    [\"dice\"], \n",
-    "    [dice_loss_fn],\n",
+    "    [\n",
+    "        \"dice\",\n",
+    "        \"1:20\",\n",
+    "        \"1:3\",\n",
+    "        \"1:1\",\n",
+    "        \"3:1\",\n",
+    "        \"19:1\",  # Closest to the 20:1 focal:dice ratio used by the SAM authors.\n",
+    "        \"focal\",\n",
+    "    ],\n",
+    "    [\n",
+    "        dice_loss_fn,\n",
+    "        get_linear_comb_loss(1 / 21),\n",
+    "        get_linear_comb_loss(1 / 4),\n",
+    "        get_linear_comb_loss(1 / 2),\n",
+    "        get_linear_comb_loss(3 / 4),\n",
+    "        get_linear_comb_loss(19 / 20),\n",
+    "        focal_loss_fn,\n",
+    "    ],\n",
     ")\n",
     "\n",
     "prompt_args = {\n",
     "    # Control points version:\n",
     "    \"prompt_type\": PromptType.CONTROL_POINTS,\n",
     "    \"num_positive\": 5,\n",
-    "    \"num_negative\": 0, \n",
+    "    \"num_negative\": 0,\n",
     "    \"erode\": True,\n",
     "\n",
     "    # Bounding boxes version:\n",
@@ -1368,18 +1392,18 @@
     "}\n",
     "\n",
     "args = (\n",
-    "    \"base\", # model_size\n",
-    "    \"AdamW\", # optimizer_name\n",
-    "    loss_config, # loss_config\n",
-    "    7e-6, # learning_rate\n",
-    "    2e-4, # weight_decay\n",
-    "    5, # batch_size (batch size per process is batch_size / num_processes)\n",
-    "    prompt_args, # prompt_args\n",
-    "    50, # num_epochs\n",
-    "    \"fp16\", # mixed_precision (\"no\" for full precision)\n",
-    "    42, # seed\n",
-    "    None, # load_checkpoint (string path to checkpoint or None)\n",
-    "    None, # model_path_name (override model path or None)\n",
+    "    \"base\",          # model_size\n",
+    "    \"AdamW\",         # optimizer_name\n",
+    "    loss_config,     # loss_config\n",
+    "    7e-6,            # learning_rate\n",
+    "    2e-4,            # weight_decay\n",
+    "    5,               # batch_size (batch size per process is batch_size / num_processes)\n",
+    "    prompt_args,     # prompt_args\n",
+    "    50,              # num_epochs\n",
+    "    \"fp16\",          # mixed_precision (\"no\" for full precision)\n",
+    "    42,              # seed\n",
+    "    None,            # load_checkpoint (string path to checkpoint or None)\n",
+    "    None,            # model_path_name (override model path or None)\n",
     ")\n",
     "notebook_launcher(training_loop, args, num_processes=1)\n",
     "print(f\"Training ended at {get_current_time()}\")"
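Note on the ratio convention in the new loss grid: each "a:b" name is focal:dice, so the value passed to `get_linear_comb_loss` is the focal share a / (a + b), and the two weights always sum to 1 (for example, "19:1" maps to 19/20 focal and 1/20 dice). A minimal standalone sketch of that mapping; `ratio_to_focal_weight` is a hypothetical helper for illustration, not part of the notebook:

```python
def ratio_to_focal_weight(name: str) -> float:
    # Parse an "a:b" focal:dice name and return the focal share a / (a + b).
    focal, dice = (int(part) for part in name.split(":"))
    return focal / (focal + dice)

for name in ["1:20", "1:3", "1:1", "3:1", "19:1"]:
    focal_w = ratio_to_focal_weight(name)
    print(f"{name}: focal={focal_w:.4f}, dice={1 - focal_w:.4f}")
# "19:1" -> focal=0.9500, dice=0.0500, i.e. get_linear_comb_loss(19 / 20).
```

Because the weights sum to 1, the combined loss stays on a comparable scale across the whole grid, which keeps the per-ratio measurements directly comparable.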