From 1c389a059ec9bd33bfb9eaf38430899e4a5e67bb Mon Sep 17 00:00:00 2001 From: Haibao Tang Date: Sun, 19 May 2024 23:32:57 -0700 Subject: [PATCH] Ask where the sam model is --- jcvi/graphics/grabseeds.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/jcvi/graphics/grabseeds.py b/jcvi/graphics/grabseeds.py index 2d15a93b..800ef001 100644 --- a/jcvi/graphics/grabseeds.py +++ b/jcvi/graphics/grabseeds.py @@ -166,12 +166,13 @@ def sam(img: np.ndarray) -> List[dict]: sys.exit(1) model_type = "vit_h" - checkpoint = "./sam_vit_h_4b8939.pth" + checkpoint = "sam_vit_h_4b8939.pth" if not op.exists(checkpoint): - checkpoint = input( - "Enter the path to the SAM model checkpoint [./sam_vit_h_4b8939.pth]: " - ) + checkpoint_dir = input("Enter the path to `sam_vit_h_4b8939.pth`: ") + checkpoint = op.join(checkpoint_dir, "sam_vit_h_4b8939.pth") + assert op.exists(checkpoint), f"File `{checkpoint}` not found" sam = sam_model_registry[model_type](checkpoint=checkpoint) + logger.info("Using SAM model `%s` (%s)", model_type, checkpoint) mask_generator = SamAutomaticMaskGenerator(sam) return mask_generator.generate(img) @@ -722,6 +723,7 @@ def seeds(args): labels = np.zeros(img_gray.shape, dtype=int) for i, mask in enumerate(deduplicated_masks): labels[mask["segmentation"]] = i + 1 + closed = None else: edges = clear_border(edges, buffer_size=opts.border) selem = disk(kernel) @@ -756,7 +758,8 @@ def seeds(args): if opts.watershed: params += ", watershed" ax2.set_title(f"Edge detection\n({params})") - closed = gray2rgb(closed) + if closed: + closed = gray2rgb(closed) ax2_img = labels if opts.edges: ax2_img = closed