diff --git a/README.md b/README.md
index be937400..cee15318 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,42 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
 
 Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
 
+## Load from 🤗 Hugging Face
+
+Alternatively, models can be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).
+
+For image prediction:
+
+```python
+import torch
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
+
+with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+    predictor.set_image(<your_image>)
+    masks, _, _ = predictor.predict(<input_prompts>)
+```
+
+For video prediction:
+
+```python
+import torch
+from sam2.sam2_video_predictor import SAM2VideoPredictor
+
+predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
+
+with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+    state = predictor.init_state(<your_video>)
+
+    # add new prompts and instantly get the output on the same frame
+    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)
+
+    # propagate the prompts to get masklets throughout the video
+    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
+        ...
+```
+
 ## Model Description
 
 | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
diff --git a/sam2/build_sam.py b/sam2/build_sam.py
index 39defc46..e5911d49 100644
--- a/sam2/build_sam.py
+++ b/sam2/build_sam.py
@@ -76,6 +76,44 @@ def build_sam2_video_predictor(
     return model
 
 
+def build_sam2_hf(model_id, **kwargs):
+
+    from huggingface_hub import hf_hub_download
+
+    model_id_to_filenames = {
+        "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
+        "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
+        "facebook/sam2-hiera-base-plus": (
+            "sam2_hiera_b+.yaml",
+            "sam2_hiera_base_plus.pt",
+        ),
+        "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
+    }
+    config_name, checkpoint_name = model_id_to_filenames[model_id]
+    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
+    return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
+
+
+def build_sam2_video_predictor_hf(model_id, **kwargs):
+
+    from huggingface_hub import hf_hub_download
+
+    model_id_to_filenames = {
+        "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
+        "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
+        "facebook/sam2-hiera-base-plus": (
+            "sam2_hiera_b+.yaml",
+            "sam2_hiera_base_plus.pt",
+        ),
+        "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
+    }
+    config_name, checkpoint_name = model_id_to_filenames[model_id]
+    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
+    return build_sam2_video_predictor(
+        config_file=config_name, ckpt_path=ckpt_path, **kwargs
+    )
+
+
 def _load_checkpoint(model, ckpt_path):
     if ckpt_path is not None:
         sd = torch.load(ckpt_path, map_location="cpu")["model"]
diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py
index 94111316..f6f9a5a1 100644
--- a/sam2/sam2_image_predictor.py
+++ b/sam2/sam2_image_predictor.py
@@ -62,6 +62,23 @@ def __init__(
             (64, 64),
         ]
 
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2ImagePredictor): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_hf
+
+        sam_model = build_sam2_hf(model_id, **kwargs)
+        return cls(sam_model)
+
     @torch.no_grad()
     def set_image(
         self,
diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py
index 0defcecb..fe8702b6 100644
--- a/sam2/sam2_video_predictor.py
+++ b/sam2/sam2_video_predictor.py
@@ -103,6 +103,23 @@ def init_state(
         self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
         return inference_state
 
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2VideoPredictor): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_video_predictor_hf
+
+        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
+        return cls(sam_model)
+
     def _obj_id_to_idx(self, inference_state, obj_id):
         """Map client-side object id to model-side object index."""
        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
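
As a supplement to the README snippets above, which use `<your_image>`-style placeholders, here is a minimal self-contained sketch of the new image-prediction path with concrete prompts. The file name and click coordinates are hypothetical; `SAM2ImagePredictor.predict` accepts `point_coords` and `point_labels` arrays as shown.

```python
import numpy as np
import torch
from PIL import Image

from sam2.sam2_image_predictor import SAM2ImagePredictor

# Downloads the checkpoint from the Hugging Face Hub on first use
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

# Hypothetical input image; any HWC RGB uint8 array works
image = np.array(Image.open("truck.jpg").convert("RGB"))

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[500, 375]]),  # one (x, y) click, hypothetical
        point_labels=np.array([1]),           # 1 = foreground point
    )
```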