From d05ebb8dd30f314adc0a7124888e572265f0e4f7 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Wed, 3 Apr 2024 10:19:15 +0000 Subject: [PATCH] SAM/HQSAMAdapter: docstring examples --- .../segment_anything/__init__.py | 3 +- .../foundationals/segment_anything/hq_sam.py | 12 ++++++++ .../foundationals/segment_anything/model.py | 28 +++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/refiners/foundationals/segment_anything/__init__.py b/src/refiners/foundationals/segment_anything/__init__.py index 5f1d54d17..8ce3045cf 100644 --- a/src/refiners/foundationals/segment_anything/__init__.py +++ b/src/refiners/foundationals/segment_anything/__init__.py @@ -1,3 +1,4 @@ +from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH -__all__ = ["SegmentAnything", "SegmentAnythingH"] +__all__ = ["SegmentAnything", "SegmentAnythingH", "HQSAMAdapter"] diff --git a/src/refiners/foundationals/segment_anything/hq_sam.py b/src/refiners/foundationals/segment_anything/hq_sam.py index f87e1e01b..87129f5a5 100644 --- a/src/refiners/foundationals/segment_anything/hq_sam.py +++ b/src/refiners/foundationals/segment_anything/hq_sam.py @@ -291,6 +291,18 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]): """Adapter for SAM introducing HQ features. See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details. 
+ + Example: + ```py + from refiners.fluxion.utils import load_from_safetensors + + # Tip: run scripts/prepare_test_weights.py to download the weights + tensor_path = "./tests/weights/refiners-sam-hq-vit-h.safetensors" + weights = load_from_safetensors(tensor_path) + + hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights) + hq_sam_adapter.inject() # then use SAM as usual + ``` """ _adapter_modules: dict[str, fl.Module] = {} diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index 48126f6e2..1f94e7021 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -25,6 +25,8 @@ class SegmentAnything(fl.Chain): See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643) + E.g. see [`SegmentAnythingH`][refiners.foundationals.segment_anything.model.SegmentAnythingH] for usage. + Attributes: mask_threshold (float): 0.0 """ @@ -262,6 +264,32 @@ def __init__( multimask_output: Whether to use multimask output. device: The PyTorch device to use. dtype: The PyTorch data type to use. + + Example: + ```py + device = "cuda" if torch.cuda.is_available() else "cpu" + + # multimask_output=True is recommended for ambiguous prompts such as a single point. 
+ # Below, a box prompt is passed, so just use multimask_output=False, which will return a single mask + sam_h = SegmentAnythingH(multimask_output=False, device=device) + + # Tip: run scripts/prepare_test_weights.py to download the weights + tensors_path = "./tests/weights/segment-anything-h.safetensors" + sam_h.load_from_safetensors(tensors_path=tensors_path) + + from PIL import Image + image = Image.open("image.png") + + masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]]) + + assert masks.shape == (1, 1, image.height, image.width) + assert masks.dtype == torch.bool + + # Convert the mask to a [0, 255] uint8 ndarray of shape (H, W) + mask = masks[0, 0].cpu().numpy().astype("uint8") * 255 + + Image.fromarray(mask).save("mask_image.png") + ``` """ image_encoder = image_encoder or SAMViTH() point_encoder = point_encoder or PointEncoder()