Skip to content

Commit

Permalink
Add sam
Browse files Browse the repository at this point in the history
  • Loading branch information
tanghaibao committed May 20, 2024
1 parent 6f6d556 commit 6832589
Showing 1 changed file with 84 additions and 22 deletions.
106 changes: 84 additions & 22 deletions jcvi/graphics/grabseeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,57 @@ def calibrate(self, pixel_cm_ratio: float, tr: np.ndarray):
self.calibrated = True


def sam(img: np.ndarray) -> List[dict]:
"""
Use Segment Anything Model (SAM) to segment objects.
"""
try:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except ImportError:
logger.fatal("segment_anything not installed. Please install it first.")
sys.exit(1)

model_type = "vit_h"
checkpoint = "./sam_vit_h_4b8939.pth"
if not op.exists(checkpoint):
checkpoint = input(
"Enter the path to the SAM model checkpoint [./sam_vit_h_4b8939.pth]: "
)
sam = sam_model_registry[model_type](checkpoint=checkpoint)
mask_generator = SamAutomaticMaskGenerator(sam)
return mask_generator.generate(img)


def is_overlapping(mask1: dict, mask2: dict, threshold=0.5):
"""
Check if bounding boxes of mask1 and mask2 overlap more than the given
threshold.
"""
x1, y1, w1, h1 = mask1["bbox"]
x2, y2, w2, h2 = mask2["bbox"]
x_overlap = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
y_overlap = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
intersection = x_overlap * y_overlap
return intersection / min(w1 * h1, w2 * h2) > threshold


def deduplicate_masks(masks: List[dict], threshold=0.5):
"""
Deduplicate masks to retain only the foreground objects.
"""
masks_sorted = sorted(masks, key=lambda x: x["area"])
retained_masks = []

for mask in masks_sorted:
if not any(
is_overlapping(mask, retained_mask, threshold)
for retained_mask in retained_masks
):
retained_masks.append(mask)
logger.info("Retained %d masks out of %d", len(retained_masks), len(masks))
return retained_masks


def rgb_to_triplet(rgb: str) -> RGBTuple:
"""
Convert RGB string to triplet.
Expand Down Expand Up @@ -313,7 +364,7 @@ def add_seeds_options(p, args):
)

g3 = p.add_argument_group("De-noise")
valid_filters = ("canny", "roberts", "sobel", "otsu")
valid_filters = ("canny", "otsu", "roberts", "sam", "sobel")
g3.add_argument(
"--filter",
default="canny",
Expand Down Expand Up @@ -647,37 +698,48 @@ def seeds(args):
_, (ax1, ax2, ax3, ax4) = plt.subplots(ncols=4, nrows=1, figsize=(iopts.w, iopts.h))
# Edge detection
img_gray = rgb2gray(img)
w, h = img_gray.shape
canvas_size = w * h
min_size = int(round(canvas_size * opts.minsize / 100))
max_size = int(round(canvas_size * opts.maxsize / 100))

logger.debug("Running %s edge detection ...", ff)
if ff == "canny":
edges = canny(img_gray, sigma=opts.sigma)
elif ff == "otsu":
thresh = threshold_otsu(img_gray)
edges = img_gray > thresh
elif ff == "roberts":
edges = roberts(img_gray)
elif ff == "sobel":
edges = sobel(img_gray)
elif ff == "otsu":
thresh = threshold_otsu(img_gray)
edges = img_gray > thresh
edges = clear_border(edges, buffer_size=opts.border)
selem = disk(kernel)
closed = closing(edges, selem) if kernel else edges
filled = binary_fill_holes(closed)

# Watershed algorithm
if opts.watershed:
distance = distance_transform_edt(filled)
local_maxi = peak_local_max(distance, threshold_rel=0.05, indices=False)
coordinates = peak_local_max(distance, threshold_rel=0.05)
markers, nmarkers = label(local_maxi, return_num=True)
logger.debug("Identified %d watershed markers", nmarkers)
labels = watershed(closed, markers, mask=filled)
if ff == "sam":
masks = sam(img)
filtered_masks = [
mask for mask in masks if min_size <= mask["area"] <= max_size
]
deduplicated_masks = deduplicate_masks(filtered_masks)
labels = np.zeros(img_gray.shape, dtype=int)
for i, mask in enumerate(deduplicated_masks):
labels[mask["segmentation"]] = i + 1
else:
labels = label(filled)
edges = clear_border(edges, buffer_size=opts.border)
selem = disk(kernel)
closed = closing(edges, selem) if kernel else edges
filled = binary_fill_holes(closed)

# Watershed algorithm
if opts.watershed:
distance = distance_transform_edt(filled)
local_maxi = peak_local_max(distance, threshold_rel=0.05, indices=False)
coordinates = peak_local_max(distance, threshold_rel=0.05)
markers, nmarkers = label(local_maxi, return_num=True)
logger.debug("Identified %d watershed markers", nmarkers)
labels = watershed(closed, markers, mask=filled)
else:
labels = label(filled)

# Object size filtering
w, h = img_gray.shape
canvas_size = w * h
min_size = int(round(canvas_size * opts.minsize / 100))
max_size = int(round(canvas_size * opts.maxsize / 100))
logger.debug(
"Find objects with pixels between %d (%d%%) and %d (%d%%)",
min_size,
Expand Down

0 comments on commit 6832589

Please sign in to comment.