Skip to content

Commit

Permalink
Ask where the sam model is
Browse files Browse the repository at this point in the history
  • Loading branch information
tanghaibao committed May 20, 2024
1 parent 6832589 commit 1c389a0
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions jcvi/graphics/grabseeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,13 @@ def sam(img: np.ndarray) -> List[dict]:
sys.exit(1)

model_type = "vit_h"
checkpoint = "./sam_vit_h_4b8939.pth"
checkpoint = "sam_vit_h_4b8939.pth"
if not op.exists(checkpoint):
checkpoint = input(
"Enter the path to the SAM model checkpoint [./sam_vit_h_4b8939.pth]: "
)
checkpoint_dir = input("Enter the path to `sam_vit_h_4b8939.pth`: ")
checkpoint = op.join(checkpoint_dir, "sam_vit_h_4b8939.pth")
assert op.exists(checkpoint), f"File `{checkpoint}` not found"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
logger.info("Using SAM model `%s` (%s)", model_type, checkpoint)
mask_generator = SamAutomaticMaskGenerator(sam)
return mask_generator.generate(img)

Expand Down Expand Up @@ -722,6 +723,7 @@ def seeds(args):
labels = np.zeros(img_gray.shape, dtype=int)
for i, mask in enumerate(deduplicated_masks):
labels[mask["segmentation"]] = i + 1
closed = None
else:
edges = clear_border(edges, buffer_size=opts.border)
selem = disk(kernel)
Expand Down Expand Up @@ -756,7 +758,8 @@ def seeds(args):
if opts.watershed:
params += ", watershed"
ax2.set_title(f"Edge detection\n({params})")
closed = gray2rgb(closed)
if closed:
closed = gray2rgb(closed)
ax2_img = labels
if opts.edges:
ax2_img = closed
Expand Down

0 comments on commit 1c389a0

Please sign in to comment.