Skip to content

Commit

Permalink
SAM/HQSAMAdapter: docstring examples
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Apr 8, 2024
1 parent e033306 commit d05ebb8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/refiners/foundationals/segment_anything/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter
from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH

__all__ = ["SegmentAnything", "SegmentAnythingH"]
__all__ = ["SegmentAnything", "SegmentAnythingH", "HQSAMAdapter"]
12 changes: 12 additions & 0 deletions src/refiners/foundationals/segment_anything/hq_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,18 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
"""Adapter for SAM introducing HQ features.
See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
Example:
```py
from refiners.fluxion.utils import load_from_safetensors
# Tip: run scripts/prepare_test_weights.py to download the weights
tensor_path = "./tests/weights/refiners-sam-hq-vit-h.safetensors"
weights = load_from_safetensors(tensor_path)
hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
hq_sam_adapter.inject() # then use SAM as usual
```
"""

_adapter_modules: dict[str, fl.Module] = {}
Expand Down
28 changes: 28 additions & 0 deletions src/refiners/foundationals/segment_anything/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class SegmentAnything(fl.Chain):
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
E.g. see [`SegmentAnythingH`][refiners.foundationals.segment_anything.model.SegmentAnythingH] for usage.
Attributes:
mask_threshold (float): 0.0
"""
Expand Down Expand Up @@ -262,6 +264,32 @@ def __init__(
multimask_output: Whether to use multimask output.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
Example:
```py
device="cuda" if torch.cuda.is_available() else "cpu"
# multimask_output=True is recommended for ambiguous prompts such as a single point.
# Below, a box prompt is passed, so just use multimask_output=False which will return a single mask
sam_h = SegmentAnythingH(multimask_output=False, device=device)
# Tip: run scripts/prepare_test_weights.py to download the weights
tensors_path = "./tests/weights/segment-anything-h.safetensors"
sam_h.load_from_safetensors(tensors_path=tensors_path)
from PIL import Image
image = Image.open("image.png")
masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])
assert masks.shape == (1, 1, image.height, image.width)
assert masks.dtype == torch.bool
# convert it to [0,255] uint8 ndarray of shape (H, W)
mask = masks[0, 0].cpu().numpy().astype("uint8") * 255
Image.fromarray(mask).save("mask_image.png")
```
"""
image_encoder = image_encoder or SAMViTH()
point_encoder = point_encoder or PointEncoder()
Expand Down

0 comments on commit d05ebb8

Please sign in to comment.