
add: example notebooks to train and predict #38

Open
wants to merge 10 commits into main

Conversation

rabiaedayilmaz (Contributor)

  • Created examples folder
  • Added a quick-start Colab notebook that trains & validates on 2D REFUGE data

@ibinti commented Sep 10, 2024

  • Created examples folder
  • Added a quick-start Colab notebook that trains & validates on 2D REFUGE data

I tried the Colab notebook with A100, L4, and T4 GPUs; the notebook mysteriously crashed. Surprisingly, on Kaggle the notebook did run after a minor fix (/content -> /kaggle/working), but it hit this error:

/kaggle/working/Medical-SAM2/sam2_train/modeling/sam/transformer.py:22: UserWarning: Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.
  OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
INFO:root:Namespace(net='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=True, train_vis=False, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='REFUGE', sam_ckpt='/kaggle/working/checkpoints/sam2_hiera_tiny.pt', sam_config='sam2_hiera_t', video_length=2, b=2, lr=0.0001, weights=0, multimask_output=1, memory_bank_size=16, data_path='/kaggle/working/data/REFUGE', path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Samples'})
Namespace(net='sam2', encoder='vit_b', exp_name='REFUGE_MedSAM2', vis=True, train_vis=False, prompt='bbox', prompt_freq=2, pretrain=None, val_freq=1, gpu=True, gpu_device=0, image_size=1024, out_size=1024, distributed='none', dataset='REFUGE', sam_ckpt='/kaggle/working/checkpoints/sam2_hiera_tiny.pt', sam_config='sam2_hiera_t', video_length=2, b=2, lr=0.0001, weights=0, multimask_output=1, memory_bank_size=16, data_path='/kaggle/working/data/REFUGE', path_helper={'prefix': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32', 'ckpt_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Model', 'log_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Log', 'sample_path': 'logs/REFUGE_MedSAM2_2024_09_10_02_23_32/Samples'})
Traceback (most recent call last):                                              
  File "/kaggle/working/Medical-SAM2/train_2d.py", line 124, in <module>
    main()
  File "/kaggle/working/Medical-SAM2/train_2d.py", line 97, in main
    tol, (eiou, edice) = function.validation_sam(args, nice_test_loader, epoch, net, writer)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kaggle/working/Medical-SAM2/func_2d/function.py", line 335, in validation_sam
    vision_feats_temp = vision_feats[-1].permute(1, 0, 2).view(B, -1, 64, 64) 
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
ERROR conda.cli.main_run:execute(125): `conda run bash -c python train_2d.py -net sam2 -exp_name REFUGE_MedSAM2 -vis 1 -sam_ckpt /kaggle/working/checkpoints/sam2_hiera_tiny.pt -sam_config sam2_hiera_t -image_size 1024 -out_size 1024 -b 2 -val_freq 1 -dataset REFUGE -data_path /kaggle/working/data/REFUGE` failed. (See above for error)

Anyway, thanks. I will try the notebook approach without using conda.

@rabiaedayilmaz (Contributor, Author)

Actually, I removed the part where I added the reshape fix, since my PR was merged :/ The Kaggle one is interesting... I guess the problem is related to GPU architecture differences? The note from my local notebook:

# rewrite this file because it breaks:
# changed view to reshape in the vision_feats_temp var on line 104
# vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64)
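
For reference, a minimal reproducer of why .view() fails there while .reshape() works (shapes are illustrative, taken from the dimension hints in function.py, with batch size 2):

import torch

# permute() returns a non-contiguous view of the tensor, so .view() cannot
# reinterpret its memory layout and raises "view size is not compatible with
# input tensor's size and stride"; .reshape() copies the data when needed.
x = torch.randn(4096, 2, 256)      # (HW, B, C), like vision_feats[-1]
y = x.permute(1, 0, 2)             # (B, HW, C), non-contiguous
try:
    y.view(2, -1, 64, 64)
except RuntimeError as e:
    print("view failed:", e)
z = y.reshape(2, -1, 64, 64)       # works; shape (2, 256, 64, 64)
print(z.shape)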

So for Colab, can you add this cell right before the training command and try again? @ibinti

%%writefile /content/Medical-SAM2/func_2d/function.py


import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

import cfg
from conf import settings
from func_2d.utils import *
import pandas as pd


args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)
pos_weight = torch.ones([1]).cuda(device=GPUdevice)*2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
mask_type = torch.float32

torch.backends.cudnn.benchmark = True


def train_sam(args, net: nn.Module, optimizer, train_loader, epoch, writer):

    # use bfloat16 for the entire notebook
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


    # train mode
    net.train()
    optimizer.zero_grad()

    # init
    epoch_loss = 0
    memory_bank_list = []
    lossfunc = criterion_G
    feat_sizes = [(256, 256), (128, 128), (64, 64)]


    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for ind, pack in enumerate(train_loader):

            to_cat_memory = []
            to_cat_memory_pos = []
            to_cat_image_embed = []

            # input image and gt masks
            imgs = pack['image'].to(dtype = mask_type, device = GPUdevice)
            masks = pack['mask'].to(dtype = mask_type, device = GPUdevice)
            name = pack['image_meta_dict']['filename_or_obj']

            # click prompt: unsqueeze to indicate only one click; add more clicks along this dimension
            if 'pt' in pack:
                pt_temp = pack['pt'].to(device = GPUdevice)
                pt = pt_temp.unsqueeze(1)
                point_labels_temp = pack['p_label'].to(device = GPUdevice)
                point_labels = point_labels_temp.unsqueeze(1)
                coords_torch = torch.as_tensor(pt, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
            else:
                coords_torch = None
                labels_torch = None

            '''Train image encoder'''
            backbone_out = net.forward_image(imgs)
            _, vision_feats, vision_pos_embeds, _ = net._prepare_backbone_features(backbone_out)
            # dimension hint for your future use
            # vision_feats: list: length = 3
            # vision_feats[0]: torch.Size([65536, batch, 32])
            # vision_feats[1]: torch.Size([16384, batch, 64])
            # vision_feats[2]: torch.Size([4096, batch, 256])
            # vision_pos_embeds[0]: torch.Size([65536, batch, 256])
            # vision_pos_embeds[1]: torch.Size([16384, batch, 256])
            # vision_pos_embeds[2]: torch.Size([4096, batch, 256])



            '''Train memory attention to condition on memory bank'''
            B = vision_feats[-1].size(1)  # batch size

            if len(memory_bank_list) == 0:
                vision_feats[-1] = vision_feats[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")
                vision_pos_embeds[-1] = vision_pos_embeds[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")

            else:
                for element in memory_bank_list:
                    to_cat_memory.append((element[0]).cuda(non_blocking=True).flatten(2).permute(2, 0, 1)) # maskmem_features
                    to_cat_memory_pos.append((element[1]).cuda(non_blocking=True).flatten(2).permute(2, 0, 1)) # maskmem_pos_enc
                    to_cat_image_embed.append((element[3]).cuda(non_blocking=True)) # image_embed

                memory_stack_ori = torch.stack(to_cat_memory, dim=0)
                memory_pos_stack_ori = torch.stack(to_cat_memory_pos, dim=0)
                image_embed_stack_ori = torch.stack(to_cat_image_embed, dim=0)

                vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64)
                vision_feats_temp = vision_feats_temp.reshape(B, -1)

                image_embed_stack_ori = F.normalize(image_embed_stack_ori, p=2, dim=1)
                vision_feats_temp = F.normalize(vision_feats_temp, p=2, dim=1)
                similarity_scores = torch.mm(image_embed_stack_ori, vision_feats_temp.t()).t()

                similarity_scores = F.softmax(similarity_scores, dim=1)
                sampled_indices = torch.multinomial(similarity_scores, num_samples=B, replacement=True).squeeze(1)  # Shape [batch_size, 16]

                memory_stack_ori_new = (memory_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
                memory = memory_stack_ori_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))

                memory_pos_stack_new = (memory_pos_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
                memory_pos = memory_pos_stack_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))


                vision_feats[-1] = net.memory_attention(
                    curr=[vision_feats[-1]],
                    curr_pos=[vision_pos_embeds[-1]],
                    memory=memory,
                    memory_pos=memory_pos,
                    num_obj_ptr_tokens=0
                    )


            feats = [feat.permute(1, 2, 0).view(B, -1, *feat_size)
                     for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]

            image_embed = feats[-1]
            high_res_feats = feats[:-1]

            # feats[0]: torch.Size([batch, 32, 256, 256]) #high_res_feats part1
            # feats[1]: torch.Size([batch, 64, 128, 128]) #high_res_feats part2
            # feats[2]: torch.Size([batch, 256, 64, 64]) #image_embed


            '''prompt encoder'''
            with torch.no_grad():
                if (ind%5) == 0:
                    points=(coords_torch, labels_torch) # input shape: ((batch, n, 2), (batch, n))
                    flag = True
                else:
                    points=None
                    flag = False

                se, de = net.sam_prompt_encoder(
                    points=points, #(coords_torch, labels_torch)
                    boxes=None,
                    masks=None,
                    batch_size=B,
                )
            # dimension hint for your future use
            # se: torch.Size([batch, n+1, 256])
            # de: torch.Size([batch, 256, 64, 64])




            '''train mask decoder'''
            low_res_multimasks, iou_predictions, sam_output_tokens, object_score_logits = net.sam_mask_decoder(
                    image_embeddings=image_embed,
                    image_pe=net.sam_prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False, # args.multimask_output if you want multiple masks
                    repeat_image=False,  # the image is already batched
                    high_res_features = high_res_feats
                )
            # dimension hint for your future use
            # low_res_multimasks: torch.Size([batch, multimask_output, 256, 256])
            # iou_predictions.shape:torch.Size([batch, multimask_output])
            # sam_output_tokens.shape:torch.Size([batch, multimask_output, 256])
            # object_score_logits.shape:torch.Size([batch, 1])


            # resize prediction
            pred = F.interpolate(low_res_multimasks,size=(args.out_size,args.out_size))
            high_res_multimasks = F.interpolate(low_res_multimasks, size=(args.image_size, args.image_size),
                                                mode="bilinear", align_corners=False)


            '''memory encoder'''
            # newly calculated memory features
            maskmem_features, maskmem_pos_enc = net._encode_new_memory(
                current_vision_feats=vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_multimasks,
                is_mask_from_pts=flag)
            # dimension hint for your future use
            # maskmem_features: torch.Size([batch, 64, 64, 64])
            # maskmem_pos_enc: [torch.Size([batch, 64, 64, 64])]

            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(device=GPUdevice, non_blocking=True)
            maskmem_pos_enc = maskmem_pos_enc[0].to(torch.bfloat16)
            maskmem_pos_enc = maskmem_pos_enc.to(device=GPUdevice, non_blocking=True)


            # add single maskmem_features, maskmem_pos_enc, iou
            if len(memory_bank_list) < args.memory_bank_size:
                for batch in range(maskmem_features.size(0)):
                    memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)).detach(),
                                             (maskmem_pos_enc[batch].unsqueeze(0)).detach(),
                                             iou_predictions[batch, 0],
                                             image_embed[batch].reshape(-1).detach()])

            else:
                for batch in range(maskmem_features.size(0)):

                    # current similarity matrix in the existing memory bank
                    memory_bank_maskmem_features_flatten = [element[0].reshape(-1) for element in memory_bank_list]
                    memory_bank_maskmem_features_flatten = torch.stack(memory_bank_maskmem_features_flatten)

                    # normalise
                    memory_bank_maskmem_features_norm = F.normalize(memory_bank_maskmem_features_flatten, p=2, dim=1)
                    current_similarity_matrix = torch.mm(memory_bank_maskmem_features_norm,
                                                         memory_bank_maskmem_features_norm.t())

                    # replace the diagonal (diagonal similarity is always 1)
                    current_similarity_matrix_no_diag = current_similarity_matrix.clone()
                    diag_indices = torch.arange(current_similarity_matrix_no_diag.size(0))
                    current_similarity_matrix_no_diag[diag_indices, diag_indices] = float('-inf')

                    # first find the minimum similarity from memory feature and the maximum similarity from memory bank
                    single_key_norm = F.normalize(maskmem_features[batch].reshape(-1), p=2, dim=0).unsqueeze(1)
                    similarity_scores = torch.mm(memory_bank_maskmem_features_norm, single_key_norm).squeeze()
                    min_similarity_index = torch.argmin(similarity_scores)
                    max_similarity_index = torch.argmax(current_similarity_matrix_no_diag[min_similarity_index])

                    # replace with less similar object
                    if similarity_scores[min_similarity_index] < current_similarity_matrix_no_diag[min_similarity_index][max_similarity_index]:
                        # soft IoU: not strictly greater than the current IoU
                        if iou_predictions[batch, 0] > memory_bank_list[max_similarity_index][2] - 0.1:
                            memory_bank_list.pop(max_similarity_index)
                            memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)).detach(),
                                                     (maskmem_pos_enc[batch].unsqueeze(0)).detach(),
                                                     iou_predictions[batch, 0],
                                                     image_embed[batch].reshape(-1).detach()])

            # backpropagation
            loss = lossfunc(pred, masks)
            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

            optimizer.zero_grad()

            pbar.update()

    return epoch_loss/len(train_loader)




def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):

    # use bfloat16 for the entire notebook
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


    # eval mode
    net.eval()

    n_val = len(val_loader)
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))

    # init
    lossfunc = criterion_G
    memory_bank_list = []
    feat_sizes = [(256, 256), (128, 128), (64, 64)]
    total_loss = 0
    total_eiou = 0
    total_dice = 0


    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            to_cat_memory = []
            to_cat_memory_pos = []
            to_cat_image_embed = []

            name = pack['image_meta_dict']['filename_or_obj']
            imgs = pack['image'].to(dtype = torch.float32, device = GPUdevice)
            masks = pack['mask'].to(dtype = torch.float32, device = GPUdevice)


            if 'pt' in pack:
                pt_temp = pack['pt'].to(device = GPUdevice)
                pt = pt_temp.unsqueeze(1)
                point_labels_temp = pack['p_label'].to(device = GPUdevice)
                point_labels = point_labels_temp.unsqueeze(1)
                coords_torch = torch.as_tensor(pt, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
            else:
                coords_torch = None
                labels_torch = None



            '''test'''
            with torch.no_grad():

                """ image encoder """
                backbone_out = net.forward_image(imgs)
                _, vision_feats, vision_pos_embeds, _ = net._prepare_backbone_features(backbone_out)
                B = vision_feats[-1].size(1)

                """ memory condition """
                if len(memory_bank_list) == 0:
                    vision_feats[-1] = vision_feats[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")
                    vision_pos_embeds[-1] = vision_pos_embeds[-1] + torch.nn.Parameter(torch.zeros(1, B, net.hidden_dim)).to(device="cuda")

                else:
                    for element in memory_bank_list:
                        maskmem_features = element[0]
                        maskmem_pos_enc = element[1]
                        to_cat_memory.append(maskmem_features.cuda(non_blocking=True).flatten(2).permute(2, 0, 1))
                        to_cat_memory_pos.append(maskmem_pos_enc.cuda(non_blocking=True).flatten(2).permute(2, 0, 1))
                        to_cat_image_embed.append((element[3]).cuda(non_blocking=True)) # image_embed

                    memory_stack_ori = torch.stack(to_cat_memory, dim=0)
                    memory_pos_stack_ori = torch.stack(to_cat_memory_pos, dim=0)
                    image_embed_stack_ori = torch.stack(to_cat_image_embed, dim=0)

                    vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64)
                    vision_feats_temp = vision_feats_temp.reshape(B, -1)

                    image_embed_stack_ori = F.normalize(image_embed_stack_ori, p=2, dim=1)
                    vision_feats_temp = F.normalize(vision_feats_temp, p=2, dim=1)
                    similarity_scores = torch.mm(image_embed_stack_ori, vision_feats_temp.t()).t()

                    similarity_scores = F.softmax(similarity_scores, dim=1)
                    sampled_indices = torch.multinomial(similarity_scores, num_samples=B, replacement=True).squeeze(1)  # Shape [batch_size, 16]

                    memory_stack_ori_new = (memory_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
                    memory = memory_stack_ori_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))

                    memory_pos_stack_new = (memory_pos_stack_ori[sampled_indices].squeeze(3).permute(1, 2, 0, 3))
                    memory_pos = memory_pos_stack_new.reshape(-1, memory_stack_ori_new.size(2), memory_stack_ori_new.size(3))



                    vision_feats[-1] = net.memory_attention(
                        curr=[vision_feats[-1]],
                        curr_pos=[vision_pos_embeds[-1]],
                        memory=memory,
                        memory_pos=memory_pos,
                        num_obj_ptr_tokens=0
                        )

                feats = [feat.permute(1, 2, 0).view(B, -1, *feat_size)
                        for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]

                image_embed = feats[-1]
                high_res_feats = feats[:-1]

                """ prompt encoder """
                if (ind%5) == 0:
                    flag = True
                    points = (coords_torch, labels_torch)

                else:
                    flag = False
                    points = None

                se, de = net.sam_prompt_encoder(
                    points=points,
                    boxes=None,
                    masks=None,
                    batch_size=B,
                )

                low_res_multimasks, iou_predictions, sam_output_tokens, object_score_logits = net.sam_mask_decoder(
                    image_embeddings=image_embed,
                    image_pe=net.sam_prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                    repeat_image=False,
                    high_res_features = high_res_feats
                )

                # prediction
                pred = F.interpolate(low_res_multimasks,size=(args.out_size,args.out_size))
                high_res_multimasks = F.interpolate(low_res_multimasks, size=(args.image_size, args.image_size),
                                                mode="bilinear", align_corners=False)

                """ memory encoder """
                maskmem_features, maskmem_pos_enc = net._encode_new_memory(
                    current_vision_feats=vision_feats,
                    feat_sizes=feat_sizes,
                    pred_masks_high_res=high_res_multimasks,
                    is_mask_from_pts=flag)

                maskmem_features = maskmem_features.to(torch.bfloat16)
                maskmem_features = maskmem_features.to(device=GPUdevice, non_blocking=True)
                maskmem_pos_enc = maskmem_pos_enc[0].to(torch.bfloat16)
                maskmem_pos_enc = maskmem_pos_enc.to(device=GPUdevice, non_blocking=True)


                """ memory bank """
                if len(memory_bank_list) < 16:
                    for batch in range(maskmem_features.size(0)):
                        memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)),
                                                 (maskmem_pos_enc[batch].unsqueeze(0)),
                                                 iou_predictions[batch, 0],
                                                 image_embed[batch].reshape(-1).detach()])

                else:
                    for batch in range(maskmem_features.size(0)):

                        memory_bank_maskmem_features_flatten = [element[0].reshape(-1) for element in memory_bank_list]
                        memory_bank_maskmem_features_flatten = torch.stack(memory_bank_maskmem_features_flatten)

                        memory_bank_maskmem_features_norm = F.normalize(memory_bank_maskmem_features_flatten, p=2, dim=1)
                        current_similarity_matrix = torch.mm(memory_bank_maskmem_features_norm,
                                                             memory_bank_maskmem_features_norm.t())

                        current_similarity_matrix_no_diag = current_similarity_matrix.clone()
                        diag_indices = torch.arange(current_similarity_matrix_no_diag.size(0))
                        current_similarity_matrix_no_diag[diag_indices, diag_indices] = float('-inf')

                        single_key_norm = F.normalize(maskmem_features[batch].reshape(-1), p=2, dim=0).unsqueeze(1)
                        similarity_scores = torch.mm(memory_bank_maskmem_features_norm, single_key_norm).squeeze()
                        min_similarity_index = torch.argmin(similarity_scores)
                        max_similarity_index = torch.argmax(current_similarity_matrix_no_diag[min_similarity_index])

                        if similarity_scores[min_similarity_index] < current_similarity_matrix_no_diag[min_similarity_index][max_similarity_index]:
                            if iou_predictions[batch, 0] > memory_bank_list[max_similarity_index][2] - 0.1:
                                memory_bank_list.pop(max_similarity_index)
                                memory_bank_list.append([(maskmem_features[batch].unsqueeze(0)),
                                                         (maskmem_pos_enc[batch].unsqueeze(0)),
                                                         iou_predictions[batch, 0],
                                                         image_embed[batch].reshape(-1).detach()])

                # binary mask and calculate loss, iou, dice
                total_loss += lossfunc(pred, masks)
                pred = (pred> 0.5).float()
                temp = eval_seg(pred, masks, threshold)
                total_eiou += temp[0]
                total_dice += temp[1]

                '''vis images'''
                if ind % args.vis == 0:
                    namecat = 'Test'
                    for na in name:
                        img_name = na
                        namecat = namecat + img_name + '+'
                    vis_image(imgs,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=None)

            pbar.update()

    return total_loss/ n_val , tuple([total_eiou/n_val, total_dice/n_val])
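
After overwriting function.py with the cell above, the training command can then be re-run, e.g. something like (the Kaggle command from the log above, with /kaggle/working swapped back to /content):

!python train_2d.py -net sam2 -exp_name REFUGE_MedSAM2 -vis 1 -sam_ckpt /content/checkpoints/sam2_hiera_tiny.pt -sam_config sam2_hiera_t -image_size 1024 -out_size 1024 -b 2 -val_freq 1 -dataset REFUGE -data_path /content/data/REFUGE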

@ibinti commented Sep 10, 2024

So for Colab, can you add this cell right before the training command and try again?

I did not add the cell you suggested, because I have already forked the repo and so can modify my fork directly. I commented out line 335 of function.py and added .reshape() as you showed:

# vision_feats_temp = vision_feats[-1].permute(1, 0, 2).view(B, -1, 64, 64) 
vision_feats_temp = vision_feats[-1].permute(1, 0, 2).reshape(B, -1, 64, 64)

This fix made both the Kaggle P100 and the Colab T4 happy. Here are the training log outputs from both.

Kaggle P100

INFO:root:Total score: 0.73396235704422, IOU: 0.062239742812431074, DICE: 0.09917049956518037 || @ epoch 0.
Total score: 0.73396235704422, IOU: 0.062239742812431074, DICE: 0.09917049956518037 || @ epoch 0.
Epoch 0: 100%|███████████| 200/200 [03:23<00:00,  1.02s/img, loss (batch)=0.119]
INFO:root:Train loss: 0.2960984502360225 || @ epoch 0.
Train loss: 0.2960984502360225 || @ epoch 0.
time_for_training  203.05405259132385
INFO:root:Total score: 0.18758922815322876, IOU: 0.6464821135604341, DICE: 0.7748540666699409 || @ epoch 0.
Total score: 0.18758922815322876, IOU: 0.6464821135604341, DICE: 0.7748540666699409 || @ epoch 0.

Colab T4

INFO:root:Total score: 0.6537907123565674, IOU: 0.03657788351116989, DICE: 0.05533080002906395 || @ epoch 0.
Total score: 0.6537907123565674, IOU: 0.03657788351116989, DICE: 0.05533080002906395 || @ epoch 0.
Epoch 0: 100% 200/200 [04:01<00:00,  1.21s/img, loss (batch)=0.105]
INFO:root:Train loss: 0.2971741591766477 || @ epoch 0.
Train loss: 0.2971741591766477 || @ epoch 0.
time_for_training  241.99381804466248
INFO:root:Total score: 0.19345636665821075, IOU: 0.6315837719022945, DICE: 0.7623434653878212 || @ epoch 0.
Total score: 0.19345636665821075, IOU: 0.6315837719022945, DICE: 0.7623434653878212 || @ epoch 0.

thanks!

@ibinti commented Sep 11, 2024

To make the story complete, here is the 3D training side.
If the CUDA extension is built on Colab or Kaggle, train_3d.py also runs without a problem. Use an L4 on Colab to get enough GPU memory; a T4 or P100 will run out of memory after a couple of training steps.

One builds the CUDA extension like this:

!python setup.py build_ext --inplace

setup.py is not in the Medical-SAM2 repo, so I copied one from the upstream Meta segment-anything-2 repo. Add this cell right before running train_3d.py. The only minor modification is the path to _C.so: "sam2" => "sam2_train" in the two places on the srcs and ext_modules lines (shown just below), due to the path difference between the Medical-SAM2 repo and the segment-anything-2 repo.
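
Concretely, the two changed lines (the full file below already includes them; the upstream file uses the original "sam2" paths):

srcs = ["sam2_train/csrc/connected_components.cu"]   # upstream: "sam2/csrc/connected_components.cu"
ext_modules = [CUDAExtension("sam2_train._C", srcs, extra_compile_args=compile_args)]   # upstream: "sam2._C"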

%%writefile /content/Medical-SAM2/setup.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os

from setuptools import find_packages, setup

# Package metadata
NAME = "SAM 2"
VERSION = "1.0"
DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
URL = "https://github.com/facebookresearch/segment-anything-2"
AUTHOR = "Meta AI"
AUTHOR_EMAIL = "[email protected]"
LICENSE = "Apache 2.0"

# Read the contents of README file
with open("README.md", "r", encoding="utf-8") as f:
    LONG_DESCRIPTION = f.read()

# Required dependencies
REQUIRED_PACKAGES = [
    "torch>=2.3.1",
    "torchvision>=0.18.1",
    "numpy>=1.24.4",
    "tqdm>=4.66.1",
    "hydra-core>=1.3.2",
    "iopath>=0.1.10",
    "pillow>=9.4.0",
]

EXTRA_PACKAGES = {
    "demo": ["matplotlib>=3.9.1", "jupyter>=1.0.0", "opencv-python>=4.7.0"],
    "dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
}

# By default, we also build the SAM 2 CUDA extension.
# You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
# By default, we allow SAM 2 installation to proceed even with build errors.
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"

# Catch and skip errors during extension building and print a warning message
# (note that this message only shows up under verbose build mode
# "pip install -v -e ." or "python setup.py build_ext -v")
CUDA_ERROR_MSG = (
    "{}\n\n"
    "Failed to build the SAM 2 CUDA extension due to the error above. "
    "You can still use SAM 2 and it's OK to ignore the error above, although some "
    "post-processing functionality may be limited (which doesn't affect the results in most cases; "
    "see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
)


def get_extensions():
    if not BUILD_CUDA:
        return []

    try:
        from torch.utils.cpp_extension import CUDAExtension

        srcs = ["sam2_train/csrc/connected_components.cu"]
        compile_args = {
            "cxx": [],
            "nvcc": [
                "-DCUDA_HAS_FP16=1",
                "-D__CUDA_NO_HALF_OPERATORS__",
                "-D__CUDA_NO_HALF_CONVERSIONS__",
                "-D__CUDA_NO_HALF2_OPERATORS__",
            ],
        }
        ext_modules = [CUDAExtension("sam2_train._C", srcs, extra_compile_args=compile_args)]
    except Exception as e:
        if BUILD_ALLOW_ERRORS:
            print(CUDA_ERROR_MSG.format(e))
            ext_modules = []
        else:
            raise e

    return ext_modules


try:
    from torch.utils.cpp_extension import BuildExtension

    class BuildExtensionIgnoreErrors(BuildExtension):

        def finalize_options(self):
            try:
                super().finalize_options()
            except Exception as e:
                print(CUDA_ERROR_MSG.format(e))
                self.extensions = []

        def build_extensions(self):
            try:
                super().build_extensions()
            except Exception as e:
                print(CUDA_ERROR_MSG.format(e))
                self.extensions = []

        def get_ext_filename(self, ext_name):
            try:
                return super().get_ext_filename(ext_name)
            except Exception as e:
                print(CUDA_ERROR_MSG.format(e))
                self.extensions = []
                return "_C.so"

    cmdclass = {
        "build_ext": (
            BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
            if BUILD_ALLOW_ERRORS
            else BuildExtension.with_options(no_python_abi_suffix=True)
        )
    }
except Exception as e:
    cmdclass = {}
    if BUILD_ALLOW_ERRORS:
        print(CUDA_ERROR_MSG.format(e))
    else:
        raise e


# Setup configuration
setup(
    name=NAME,
    version=VERSION,
    description=DESCRIPTION,
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    url=URL,
    author=AUTHOR,
    author_email=AUTHOR_EMAIL,
    license=LICENSE,
    packages=find_packages(exclude="notebooks"),
    package_data={"": ["*.yaml"]},  # SAM 2 configuration files
    include_package_data=True,
    install_requires=REQUIRED_PACKAGES,
    extras_require=EXTRA_PACKAGES,
    python_requires=">=3.10.0",
    ext_modules=get_extensions(),
    cmdclass=cmdclass,
)
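
After writing the file, the extension builds with the command above; a quick sanity check that an importable _C was produced (assuming the notebook's working directory is the Medical-SAM2 root):

import importlib.util

# False means the extension is missing; per CUDA_ERROR_MSG above, some
# post-processing functionality may then be limited
print("sam2_train._C importable:", importlib.util.find_spec("sam2_train._C") is not None)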

This is the log output from train_3d.py on a Colab L4:

Epoch 0:   0% 0/24 [00:00<?, ?img/s]/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
Epoch 0: 100% 24/24 [00:40<00:00,  1.68s/img, loss (batch)=0.00301]
INFO:root:Train loss: 0.051431525489487206, 0.007973176463565324, 0.09488987206714228 || @ epoch 0.
Train loss: 0.051431525489487206, 0.007973176463565324, 0.09488987206714228 || @ epoch 0.
time_for_training  40.224876165390015
INFO:root:Total score: 0.16452749073505402, IOU: 0.8724716256719525, DICE: 0.9101958117103393 || @ epoch 0.
Total score: 0.16452749073505402, IOU: 0.8724716256719525, DICE: 0.9101958117103393 || @ epoch 0.

It would be nice if setup.py were included in Medical-SAM2, along with instructions for anyone having a conflicting CUDA extension issue on their system.

thanks!

@rabiaedayilmaz (Contributor, Author)

Hi @ibinti, thank you so much! I was stuck on this problem but couldn't find time to focus on it, so this helped me! However, I have a problem with 3D training on Colab/Kaggle. After resolving a bunch of errors, I got this one:

x = F.scaled_dot_product_attention(
AttributeError: module 'torch.nn.functional' has no attribute 'scaled_dot_product_attention'. Did you mean: '_scaled_dot_product_attention'?

Then I changed it to F._scaled_dot_product_attention in sam2_train/modeling/backbones/hieradet.py, and got this error:

    x = F._scaled_dot_product_attention(
  File "/opt/conda/envs/medsam2/lib/python3.10/site-packages/torch/nn/functional.py", line 4848, in _scaled_dot_product_attention
    B, Nt, E = q.shape
ValueError: too many values to unpack (expected 3)

Do you have any idea?

@ibinti commented Sep 13, 2024

Hello @rabiaedayilmaz,
I did not use conda, as I've found it less beneficial on platforms like Colab and Kaggle than on local machines. Instead, I removed the conda dependencies and installed the required packages directly using pip. This simplified the setup and made the code more straightforward. If you're interested, I can share a working Colab notebook that demonstrates this approach, though I'd need to make some modifications to the dataset path, since I used my own dataset on Kaggle. Let me know.
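
If it helps with the hieradet.py error: F.scaled_dot_product_attention only exists in PyTorch 2.0+ (and the setup.py above asks for torch>=2.3.1), so that AttributeError usually just means the environment resolved to an older torch. A quick check in a notebook cell, just a sketch:

import torch
import torch.nn.functional as F

# torch < 2.0 lacks scaled_dot_product_attention, which matches the
# AttributeError above; installing a recent torch (e.g. via pip, skipping
# the conda env as described) resolves it
print(torch.__version__)
print(hasattr(F, "scaled_dot_product_attention"))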

@rabiaedayilmaz (Contributor, Author)

Hi @ibinti, I see. After removing conda, it works properly. Thanks!

rabiaedayilmaz changed the title add(notebook): example colab notebook to train&validate 2d data → add(notebook): example colab notebook to train&validate 2d&3d data Sep 14, 2024
rabiaedayilmaz changed the title add(notebook): example colab notebook to train&validate 2d&3d data → add(notebook): example notebooks to train and predict Sep 14, 2024
rabiaedayilmaz changed the title add(notebook): example notebooks to train and predict → add: example notebooks to train and predict Sep 14, 2024
rabiaedayilmaz mentioned this pull request Sep 14, 2024