def sam(img: np.ndarray) -> List[dict]:
    """
    Use Segment Anything Model (SAM) to segment objects.

    The "vit_h" model is loaded from `./sam_vit_h_4b8939.pth`. When that file
    is missing, the user is prompted for a checkpoint path; an empty answer
    falls back to the default path shown in the prompt, and `~` is expanded.

    Args:
        img: Image array, passed unchanged to the SAM automatic mask
            generator.

    Returns:
        The list of mask records returned by
        `SamAutomaticMaskGenerator.generate`. Callers in this module read the
        "bbox", "area" and "segmentation" keys of each record.

    Exits the process with status 1 when `segment_anything` is not installed
    or when no checkpoint file can be found.
    """
    try:
        from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
    except ImportError:
        logger.fatal("segment_anything not installed. Please install it first.")
        sys.exit(1)

    model_type = "vit_h"
    default_checkpoint = "./sam_vit_h_4b8939.pth"
    checkpoint = default_checkpoint
    if not op.exists(checkpoint):
        answer = input(
            "Enter the path to the SAM model checkpoint [./sam_vit_h_4b8939.pth]: "
        ).strip()
        # The prompt advertises a default in brackets, so an empty answer must
        # resolve to that default rather than to the empty string.
        checkpoint = op.expanduser(answer or default_checkpoint)
        if not op.exists(checkpoint):
            # Fail early with a clear message instead of deep inside the
            # model loader.
            logger.fatal("SAM model checkpoint not found: %s", checkpoint)
            sys.exit(1)

    # Named `model` (not `sam`) to avoid shadowing this function's own name.
    model = sam_model_registry[model_type](checkpoint=checkpoint)
    mask_generator = SamAutomaticMaskGenerator(model)
    return mask_generator.generate(img)


def is_overlapping(mask1: dict, mask2: dict, threshold=0.5):
    """
    Check if bounding boxes of mask1 and mask2 overlap more than the given
    threshold.

    The overlap ratio is the intersection area divided by the area of the
    SMALLER of the two boxes, so a small box that lies mostly inside a large
    one is reported as overlapping.

    Args:
        mask1: Mask record with a "bbox" entry unpacked as (x, y, w, h).
        mask2: Mask record with a "bbox" entry unpacked as (x, y, w, h).
        threshold: Minimum overlap ratio (exclusive) to count as overlapping.

    Returns:
        True when the overlap ratio exceeds `threshold`, False otherwise.
        A box with zero area never overlaps anything.
    """
    x1, y1, w1, h1 = mask1["bbox"]
    x2, y2, w2, h2 = mask2["bbox"]
    x_overlap = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
    y_overlap = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
    intersection = x_overlap * y_overlap
    smaller_area = min(w1 * h1, w2 * h2)
    if smaller_area <= 0:
        # Degenerate box (zero width or height): there is no area to overlap,
        # and dividing by it would raise ZeroDivisionError.
        return False
    return intersection / smaller_area > threshold


def deduplicate_masks(masks: List[dict], threshold=0.5) -> List[dict]:
    """
    Deduplicate masks to retain only the foreground objects.

    Masks are visited from smallest to largest "area". A mask is kept only if
    its bounding box does not overlap (see `is_overlapping`) any mask already
    kept, so larger masks that enclose an already-kept smaller mask are
    dropped.

    Args:
        masks: Mask records, each with "area" and "bbox" entries.
        threshold: Overlap ratio passed through to `is_overlapping`.

    Returns:
        The retained mask records, ordered by ascending "area".
    """
    masks_sorted = sorted(masks, key=lambda x: x["area"])
    retained_masks = []

    for mask in masks_sorted:
        if not any(
            is_overlapping(mask, retained_mask, threshold)
            for retained_mask in retained_masks
        ):
            retained_masks.append(mask)
    logger.info("Retained %d masks out of %d", len(retained_masks), len(masks))
    return retained_masks
@@ -313,7 +364,7 @@ def add_seeds_options(p, args): ) g3 = p.add_argument_group("De-noise") - valid_filters = ("canny", "roberts", "sobel", "otsu") + valid_filters = ("canny", "otsu", "roberts", "sam", "sobel") g3.add_argument( "--filter", default="canny", @@ -647,37 +698,48 @@ def seeds(args): _, (ax1, ax2, ax3, ax4) = plt.subplots(ncols=4, nrows=1, figsize=(iopts.w, iopts.h)) # Edge detection img_gray = rgb2gray(img) + w, h = img_gray.shape + canvas_size = w * h + min_size = int(round(canvas_size * opts.minsize / 100)) + max_size = int(round(canvas_size * opts.maxsize / 100)) + logger.debug("Running %s edge detection ...", ff) if ff == "canny": edges = canny(img_gray, sigma=opts.sigma) + elif ff == "otsu": + thresh = threshold_otsu(img_gray) + edges = img_gray > thresh elif ff == "roberts": edges = roberts(img_gray) elif ff == "sobel": edges = sobel(img_gray) - elif ff == "otsu": - thresh = threshold_otsu(img_gray) - edges = img_gray > thresh - edges = clear_border(edges, buffer_size=opts.border) - selem = disk(kernel) - closed = closing(edges, selem) if kernel else edges - filled = binary_fill_holes(closed) - - # Watershed algorithm - if opts.watershed: - distance = distance_transform_edt(filled) - local_maxi = peak_local_max(distance, threshold_rel=0.05, indices=False) - coordinates = peak_local_max(distance, threshold_rel=0.05) - markers, nmarkers = label(local_maxi, return_num=True) - logger.debug("Identified %d watershed markers", nmarkers) - labels = watershed(closed, markers, mask=filled) + if ff == "sam": + masks = sam(img) + filtered_masks = [ + mask for mask in masks if min_size <= mask["area"] <= max_size + ] + deduplicated_masks = deduplicate_masks(filtered_masks) + labels = np.zeros(img_gray.shape, dtype=int) + for i, mask in enumerate(deduplicated_masks): + labels[mask["segmentation"]] = i + 1 else: - labels = label(filled) + edges = clear_border(edges, buffer_size=opts.border) + selem = disk(kernel) + closed = closing(edges, selem) if kernel else 
edges + filled = binary_fill_holes(closed) + + # Watershed algorithm + if opts.watershed: + distance = distance_transform_edt(filled) + local_maxi = peak_local_max(distance, threshold_rel=0.05, indices=False) + coordinates = peak_local_max(distance, threshold_rel=0.05) + markers, nmarkers = label(local_maxi, return_num=True) + logger.debug("Identified %d watershed markers", nmarkers) + labels = watershed(closed, markers, mask=filled) + else: + labels = label(filled) # Object size filtering - w, h = img_gray.shape - canvas_size = w * h - min_size = int(round(canvas_size * opts.minsize / 100)) - max_size = int(round(canvas_size * opts.maxsize / 100)) logger.debug( "Find objects with pixels between %d (%d%%) and %d (%d%%)", min_size,