Skip to content

Commit

Permalink
various losses experiment setup
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanTodoran committed May 7, 2024
1 parent 11172b8 commit 88874f3
Showing 1 changed file with 46 additions and 22 deletions.
68 changes: 46 additions & 22 deletions book/chapters/masking_distributed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -973,14 +973,22 @@
"import monai\n",
"import torchvision.ops as ops\n",
"\n",
"def sam_lc_loss(pred_mask, ground_truth):\n",
" dice_loss = dice_loss_fn(pred_mask, ground_truth)\n",
" focal_loss = ops.sigmoid_focal_loss(pred_mask, ground_truth, reduction=\"mean\")\n",
" combined_loss = (19 * focal_loss / 20) + (1 * dice_loss / 20) # Ratio based on SAM authors' findings.\n",
" return combined_loss\n",
"def focal_loss_fn(pred_mask, ground_truth):\n",
" return ops.sigmoid_focal_loss(pred_mask, ground_truth, reduction=\"mean\")\n",
"\n",
"dice_loss_fn = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction=\"mean\") # type: ignore\n",
"\n",
"def get_linear_comb_loss(focal_loss_ratio):\n",
" dice_loss_ratio = 1 - focal_loss_ratio\n",
"\n",
" def lc_loss(pred_mask, ground_truth):\n",
" dice_loss = dice_loss_fn(pred_mask, ground_truth)\n",
" focal_loss = focal_loss_fn(pred_mask, ground_truth)\n",
" combined_loss = (focal_loss_ratio * focal_loss) + (dice_loss_ratio * dice_loss)\n",
" return combined_loss\n",
"\n",
" return lc_loss\n",
"\n",
"class LossFunctionsConfig():\n",
" def __init__(self, names: list[str], funcs: list):\n",
" self.names = names\n",
Expand Down Expand Up @@ -1348,18 +1356,34 @@
"\n",
"print(f\"Training started at {get_current_time()}\")\n",
"\n",
"# All provided loss functions will be measured.\n",
"# model is trained on the first loss function.\n",
"# All provided loss functions will be measured. Model is \n",
"# only trained with backprop of the first loss function.\n",
"loss_config = LossFunctionsConfig(\n",
" [\"dice\"], \n",
" [dice_loss_fn],\n",
" [\n",
" \"dice\",\n",
" \"1:20\",\n",
" \"1:3\",\n",
" \"1:1\",\n",
" \"3:1\",\n",
" \"19:1\", # Ratio found to be the best by the SAM authors\n",
" \"focal\",\n",
" ],\n",
" [\n",
" dice_loss_fn,\n",
" get_linear_comb_loss(1 / 20),\n",
" get_linear_comb_loss(1 / 3),\n",
" get_linear_comb_loss(1 / 1),\n",
" get_linear_comb_loss(3 / 1),\n",
" get_linear_comb_loss(19 / 20),\n",
" focal_loss_fn,\n",
" ],\n",
")\n",
"\n",
"prompt_args = {\n",
" # Control points version:\n",
" \"prompt_type\": PromptType.CONTROL_POINTS,\n",
" \"num_positive\": 5,\n",
" \"num_negative\": 0, \n",
" \"num_negative\": 0,\n",
" \"erode\": True,\n",
"\n",
" # Bounding boxes version:\n",
Expand All @@ -1368,18 +1392,18 @@
"}\n",
"\n",
"args = (\n",
" \"base\", # model_size\n",
" \"AdamW\", # optimizer_name\n",
" loss_config, # loss_config\n",
" 7e-6, # learning_rate\n",
" 2e-4, # weight_decay\n",
" 5, # batch_size (batch size per process is batch_size / num_processes)\n",
" prompt_args, # prompt_args\n",
" 50, # num_epochs\n",
" \"fp16\", # mixed_precision (\"no\" for full precision)\n",
" 42, # seed\n",
" None, # load_checkpoint (string path to checkpoint or None)\n",
" None, # model_path_name (override model path or None)\n",
" \"base\", # model_size\n",
" \"AdamW\", # optimizer_name\n",
" loss_config, # loss_config\n",
" 7e-6, # learning_rate\n",
" 2e-4, # weight_decay\n",
" 5, # batch_size (batch size per process is batch_size / num_processes)\n",
" prompt_args, # prompt_args\n",
" 50, # num_epochs\n",
" \"fp16\", # mixed_precision (\"no\" for full precision)\n",
" 42, # seed\n",
" None, # load_checkpoint (string path to checkpoint or None)\n",
" None, # model_path_name (override model path or None)\n",
")\n",
"notebook_launcher(training_loop, args, num_processes=1)\n",
"print(f\"Training ended at {get_current_time()}\")"
Expand Down

0 comments on commit 88874f3

Please sign in to comment.