
Integrate with Hugging Face #128

Merged (18 commits) on Aug 7, 2024
36 changes: 36 additions & 0 deletions README.md
@@ -101,6 +101,42 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):

Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.

## Load from 🤗 Hugging Face

Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).

For image prediction:

```python
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
```
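
As a concrete illustration of the placeholders above, here is a minimal sketch; the image file and the click location are hypothetical, and the prompt follows the predictor's point-based format (pixel coordinates plus a foreground/background label):

```python
import numpy as np
import torch
from PIL import Image

from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # Hypothetical input image; set_image accepts an RGB array (or PIL image).
    image = np.array(Image.open("truck.jpg").convert("RGB"))
    predictor.set_image(image)

    # A single foreground click at pixel (x=500, y=375); label 1 = foreground.
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
    )
```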

For video prediction:

```python
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```
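
And a concrete sketch of the video loop; the frame directory, object id, and click coordinates are hypothetical, and the keyword names for `add_new_points` (`frame_idx`, `obj_id`, `points`, `labels`) are assumed from this codebase's API:

```python
import numpy as np
import torch

from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # Hypothetical input: a directory of JPEG frames extracted from the video.
    state = predictor.init_state("./videos/bedroom")

    # One foreground click on frame 0 for object id 1 (hypothetical coordinates).
    frame_idx, object_ids, masks = predictor.add_new_points(
        state,
        frame_idx=0,
        obj_id=1,
        points=np.array([[210, 350]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32),
    )

    # Collect per-object boolean masks for every frame; mask logits are
    # thresholded at 0, the convention used in the repo's notebooks.
    video_segments = {}
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        video_segments[frame_idx] = {
            obj_id: (masks[i] > 0.0).cpu().numpy()
            for i, obj_id in enumerate(object_ids)
        }
```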

## Model Description

| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
38 changes: 38 additions & 0 deletions sam2/build_sam.py
@@ -76,6 +76,44 @@ def build_sam2_video_predictor(
    return model


def build_sam2_hf(model_id, **kwargs):
    # Lazy import so that huggingface_hub stays an optional dependency.
    from huggingface_hub import hf_hub_download

    model_id_to_filenames = {
        "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
        "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
        "facebook/sam2-hiera-base-plus": (
            "sam2_hiera_b+.yaml",
            "sam2_hiera_base_plus.pt",
        ),
        "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
    }
    config_name, checkpoint_name = model_id_to_filenames[model_id]
    # Download the checkpoint from the Hub (cached locally after the first call).
    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
    return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)


def build_sam2_video_predictor_hf(model_id, **kwargs):
    # Lazy import so that huggingface_hub stays an optional dependency.
    from huggingface_hub import hf_hub_download

    model_id_to_filenames = {
        "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
        "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
        "facebook/sam2-hiera-base-plus": (
            "sam2_hiera_b+.yaml",
            "sam2_hiera_base_plus.pt",
        ),
        "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
    }
    config_name, checkpoint_name = model_id_to_filenames[model_id]
    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
    return build_sam2_video_predictor(
        config_file=config_name, ckpt_path=ckpt_path, **kwargs
    )


def _load_checkpoint(model, ckpt_path):
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["model"]
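
For reference (not part of the diff), a minimal sketch of calling these builders directly; note that an unknown `model_id` falls through `model_id_to_filenames[model_id]` and raises a `KeyError`:

```python
from sam2.build_sam import build_sam2_hf, build_sam2_video_predictor_hf

# Resolve the config name and download the checkpoint from the Hub, then build.
sam2_model = build_sam2_hf("facebook/sam2-hiera-small")

# Same resolution logic, but returns the video predictor variant.
video_predictor = build_sam2_video_predictor_hf("facebook/sam2-hiera-small")
```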
17 changes: 17 additions & 0 deletions sam2/sam2_image_predictor.py
@@ -62,6 +62,23 @@ def __init__(
            (64, 64),
        ]

    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
        """
        Load a pretrained model from the Hugging Face hub.

        Arguments:
          model_id (str): The Hugging Face repository ID.
          **kwargs: Additional arguments to pass to the model constructor.

        Returns:
          (SAM2ImagePredictor): The loaded model.
        """
        from sam2.build_sam import build_sam2_hf

        sam_model = build_sam2_hf(model_id, **kwargs)
        return cls(sam_model)

    @torch.no_grad()
    def set_image(
        self,
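
One note on `**kwargs`: the diff shows they are forwarded through `build_sam2_hf` into `build_sam2`, so construction-time options can be passed from `from_pretrained`. A hedged sketch, assuming `apply_postprocessing` is a `build_sam2` keyword (it is not shown in this diff):

```python
from sam2.sam2_image_predictor import SAM2ImagePredictor

# kwargs flow: from_pretrained -> build_sam2_hf -> build_sam2.
# `apply_postprocessing` is assumed here; check build_sam2's signature.
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-large", apply_postprocessing=False
)
```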
17 changes: 17 additions & 0 deletions sam2/sam2_video_predictor.py
@@ -103,6 +103,23 @@ def init_state(
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state

    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
        """
        Load a pretrained model from the Hugging Face hub.

        Arguments:
          model_id (str): The Hugging Face repository ID.
          **kwargs: Additional arguments to pass to the model constructor.

        Returns:
          (SAM2VideoPredictor): The loaded model.
        """
        from sam2.build_sam import build_sam2_video_predictor_hf

        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
        return cls(sam_model)

    def _obj_id_to_idx(self, inference_state, obj_id):
        """Map client-side object id to model-side object index."""
        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)