From b72a8a97f062eb918736925eb8942fb7a179c60e Mon Sep 17 00:00:00 2001 From: Niels Date: Sat, 3 Aug 2024 12:57:05 +0200 Subject: [PATCH 01/18] First draft --- sam2/build_sam.py | 16 +++++++++++++++- sam2/sam2_image_predictor.py | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/sam2/build_sam.py b/sam2/build_sam.py index 39defc46..e55f85e9 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -11,6 +11,8 @@ from hydra.utils import instantiate from omegaconf import OmegaConf +from huggingface_hub import hf_hub_download + def build_sam2( config_file, @@ -76,6 +78,18 @@ def build_sam2_video_predictor( return model +def build_sam2_hf(model_id, **kwargs): + config_file = hf_hub_download(repo_id=model_id, filename=f"{model_id}.yaml") + ckpt_path = hf_hub_download(repo_id=model_id, filename=f"{model_id}.pt") + return build_sam2_video_predictor(config_file=config_file, ckpt_path=ckpt_path, **kwargs) + + +def build_sam2_video_predictor_hf(model_id, **kwargs): + config_file = hf_hub_download(repo_id=model_id, filename=f"{model_id}.yaml") + ckpt_path = hf_hub_download(repo_id=model_id, filename=f"{model_id}.pt") + return build_sam2_video_predictor(config_file=config_file, ckpt_path=ckpt_path, **kwargs) + + def _load_checkpoint(model, ckpt_path): if ckpt_path is not None: sd = torch.load(ckpt_path, map_location="cpu")["model"] @@ -86,4 +100,4 @@ def _load_checkpoint(model, ckpt_path): if unexpected_keys: logging.error(unexpected_keys) raise RuntimeError() - logging.info("Loaded checkpoint sucessfully") + logging.info("Loaded checkpoint sucessfully") \ No newline at end of file diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index 94111316..e7eebbef 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -13,7 +13,7 @@ from PIL.Image import Image from sam2.modeling.sam2_base import SAM2Base - +from sam2.build_sam import build_sam2_hf from sam2.utils.transforms import SAM2Transforms @@ -62,6 
+62,20 @@ def __init__( (64, 64), ] + def from_pretrained(model_id: str, **kwargs) -> "SAM2ImagePredictor": + """ + Load a pretrained model from the Hugging Face model hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2ImagePredictor): The loaded model. + """ + sam_model = build_sam2_hf(model_id, **kwargs) + return SAM2ImagePredictor(sam_model) + @torch.no_grad() def set_image( self, From 17b74501fb41b5939ee4e7ecab1b71b2f9456ddb Mon Sep 17 00:00:00 2001 From: Niels Date: Sat, 3 Aug 2024 14:14:12 +0200 Subject: [PATCH 02/18] Use classmethod --- sam2/sam2_image_predictor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index e7eebbef..5d2980cb 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -62,7 +62,8 @@ def __init__( (64, 64), ] - def from_pretrained(model_id: str, **kwargs) -> "SAM2ImagePredictor": + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": """ Load a pretrained model from the Hugging Face model hub. @@ -74,7 +75,7 @@ def from_pretrained(model_id: str, **kwargs) -> "SAM2ImagePredictor": (SAM2ImagePredictor): The loaded model. 
""" sam_model = build_sam2_hf(model_id, **kwargs) - return SAM2ImagePredictor(sam_model) + return cls(sam_model) @torch.no_grad() def set_image( From 3af4e8226303abb2b865424a8d2b41a1c6dc3f78 Mon Sep 17 00:00:00 2001 From: Niels Date: Sat, 3 Aug 2024 14:18:23 +0200 Subject: [PATCH 03/18] Add model_id_to_filenames --- sam2/build_sam.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sam2/build_sam.py b/sam2/build_sam.py index e55f85e9..eb07ca1b 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -79,8 +79,16 @@ def build_sam2_video_predictor( def build_sam2_hf(model_id, **kwargs): - config_file = hf_hub_download(repo_id=model_id, filename=f"{model_id}.yaml") - ckpt_path = hf_hub_download(repo_id=model_id, filename=f"{model_id}.pt") + + model_id_to_filenames = { + "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), + "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), + "facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"), + "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), + } + config_name, checkpoint_name = model_id_to_filenames[model_id] + config_file = hf_hub_download(repo_id=model_id, filename=config_name) + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) return build_sam2_video_predictor(config_file=config_file, ckpt_path=ckpt_path, **kwargs) From 0c28c630c20edff1fda3b4082378823bcd5720aa Mon Sep 17 00:00:00 2001 From: Niels Date: Sat, 3 Aug 2024 14:45:20 +0200 Subject: [PATCH 04/18] Do not load config from the hub --- sam2/build_sam.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sam2/build_sam.py b/sam2/build_sam.py index eb07ca1b..00f9dcf4 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -87,9 +87,9 @@ def build_sam2_hf(model_id, **kwargs): "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), } config_name, checkpoint_name = 
model_id_to_filenames[model_id] - config_file = hf_hub_download(repo_id=model_id, filename=config_name) + # config_file = hf_hub_download(repo_id=model_id, filename=config_name) ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) - return build_sam2_video_predictor(config_file=config_file, ckpt_path=ckpt_path, **kwargs) + return build_sam2_video_predictor(config_file=config_name, ckpt_path=ckpt_path, **kwargs) def build_sam2_video_predictor_hf(model_id, **kwargs): From 6aeee347759c091e5718c4a03fe0032165564c50 Mon Sep 17 00:00:00 2001 From: Niels Date: Mon, 5 Aug 2024 09:37:53 +0200 Subject: [PATCH 05/18] Make huggingface_hub soft dependency --- sam2/build_sam.py | 8 +++++--- sam2/sam2_image_predictor.py | 3 ++- sam2/sam2_video_predictor.py | 17 +++++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sam2/build_sam.py b/sam2/build_sam.py index 00f9dcf4..9bb5279b 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -11,8 +11,6 @@ from hydra.utils import instantiate from omegaconf import OmegaConf -from huggingface_hub import hf_hub_download - def build_sam2( config_file, @@ -80,6 +78,8 @@ def build_sam2_video_predictor( def build_sam2_hf(model_id, **kwargs): + from huggingface_hub import hf_hub_download + model_id_to_filenames = { "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), @@ -87,12 +87,14 @@ def build_sam2_hf(model_id, **kwargs): "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), } config_name, checkpoint_name = model_id_to_filenames[model_id] - # config_file = hf_hub_download(repo_id=model_id, filename=config_name) ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) return build_sam2_video_predictor(config_file=config_name, ckpt_path=ckpt_path, **kwargs) def build_sam2_video_predictor_hf(model_id, **kwargs): + + from huggingface_hub import hf_hub_download + config_file 
= hf_hub_download(repo_id=model_id, filename=f"{model_id}.yaml") ckpt_path = hf_hub_download(repo_id=model_id, filename=f"{model_id}.pt") return build_sam2_video_predictor(config_file=config_file, ckpt_path=ckpt_path, **kwargs) diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index 5d2980cb..9bee70db 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -13,7 +13,6 @@ from PIL.Image import Image from sam2.modeling.sam2_base import SAM2Base -from sam2.build_sam import build_sam2_hf from sam2.utils.transforms import SAM2Transforms @@ -74,6 +73,8 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": Returns: (SAM2ImagePredictor): The loaded model. """ + from sam2.build_sam import build_sam2_hf + sam_model = build_sam2_hf(model_id, **kwargs) return cls(sam_model) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 0defcecb..d687bc1d 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -103,6 +103,23 @@ def init_state( self._get_image_feature(inference_state, frame_idx=0, batch_size=1) return inference_state + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": + """ + Load a pretrained model from the Hugging Face model hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2ImagePredictor): The loaded model. 
+ """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return cls(sam_model) + def _obj_id_to_idx(self, inference_state, obj_id): """Map client-side object id to model-side object index.""" obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) From cb48213066bc3b23fea3711e177d54c5a8ae51b8 Mon Sep 17 00:00:00 2001 From: Niels Date: Mon, 5 Aug 2024 09:41:40 +0200 Subject: [PATCH 06/18] Update links --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index be937400..7e227a7d 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ cd checkpoints or individually from: -- [sam2_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt) -- [sam2_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt) -- [sam2_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt) -- [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt) +- [sam2_hiera_tiny.pt](https://huggingface.co/facebook/sam2-hiera-tiny) +- [sam2_hiera_small.pt](https://huggingface.co/facebook/sam2-hiera-small) +- [sam2_hiera_base_plus.pt](https://huggingface.co/facebook/sam2-hiera-base-plus) +- [sam2_hiera_large.pt](https://huggingface.co/facebook/sam2-hiera-large) Then SAM 2 can be used in a few lines as follows for image and video prediction. 
From e93be7f6aa82353935aafc7c8db84df6dabd2945 Mon Sep 17 00:00:00 2001 From: Niels Date: Mon, 5 Aug 2024 09:43:04 +0200 Subject: [PATCH 07/18] Update README --- README.md | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 7e227a7d..21104c77 100644 --- a/README.md +++ b/README.md @@ -60,12 +60,9 @@ SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segm ```python import torch -from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor -checkpoint = "./checkpoints/sam2_hiera_large.pt" -model_cfg = "sam2_hiera_l.yaml" -predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint)) +predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): predictor.set_image() @@ -82,11 +79,9 @@ For promptable segmentation and tracking in videos, we provide a video predictor ```python import torch -from sam2.build_sam import build_sam2_video_predictor +from sam2.sam2_video_predictor import SAM2VideoPredictor -checkpoint = "./checkpoints/sam2_hiera_large.pt" -model_cfg = "sam2_hiera_l.yaml" -predictor = build_sam2_video_predictor(model_cfg, checkpoint) +predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): state = predictor.init_state() From 841cc1f0154adf4fd10e0bf0fbbdede62c6698e7 Mon Sep 17 00:00:00 2001 From: Niels Date: Mon, 5 Aug 2024 09:44:03 +0200 Subject: [PATCH 08/18] Update docstring --- sam2/sam2_video_predictor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index d687bc1d..3a751c57 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -104,7 +104,7 @@ def init_state( return inference_state @classmethod - def from_pretrained(cls, model_id: str, **kwargs) -> 
"SAM2ImagePredictor": + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": """ Load a pretrained model from the Hugging Face model hub. @@ -113,7 +113,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": **kwargs: Additional arguments to pass to the model constructor. Returns: - (SAM2ImagePredictor): The loaded model. + (SAM2VideoPredictor): The loaded model. """ from sam2.build_sam import build_sam2_video_predictor_hf From c3393d8b5f9006eb658ef88bcf68bb68bf9776b5 Mon Sep 17 00:00:00 2001 From: Niels Date: Mon, 5 Aug 2024 22:08:54 +0200 Subject: [PATCH 09/18] Include original code snippet --- README.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/README.md b/README.md index 21104c77..bdc88e24 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,22 @@ Then SAM 2 can be used in a few lines as follows for image and video prediction. SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting. +```python +import torch +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor + +checkpoint = "./checkpoints/sam2_hiera_large.pt" +model_cfg = "sam2_hiera_l.yaml" +predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint)) + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +or from Hugging Face, as follows: + ```python import torch from sam2.sam2_image_predictor import SAM2ImagePredictor @@ -94,6 +110,19 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): ... 
``` +or from Hugging Face, as follows: + +```python +import torch +from sam2.sam2_video_predictor import SAM2VideoPredictor + +predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos. ## Model Description From e9503c96fe7c60c529ffab3d005af753466230ef Mon Sep 17 00:00:00 2001 From: Niels Date: Mon, 5 Aug 2024 22:10:57 +0200 Subject: [PATCH 10/18] Move HF to separate section --- README.md | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index bdc88e24..6328e1fd 100644 --- a/README.md +++ b/README.md @@ -72,19 +72,6 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): masks, _, _ = predictor.predict() ``` -or from Hugging Face, as follows: - -```python -import torch -from sam2.sam2_image_predictor import SAM2ImagePredictor - -predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") - -with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): - predictor.set_image() - masks, _, _ = predictor.predict() -``` - Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) for static image use cases. SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) for automatic mask generation in images. @@ -110,7 +97,26 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): ... 
``` -or from Hugging Face, as follows: +Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos. + +## Load from Hugging Face + +Alternatively, models can also be loaded from Hugging Face using the `from_pretrained` method: + +For image prediction: + +```python +import torch +from sam2.sam2_image_predictor import SAM2ImagePredictor + +predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") + +with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): + predictor.set_image() + masks, _, _ = predictor.predict() +``` + +For video prediction: ```python import torch @@ -123,8 +129,6 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): masks, _, _ = predictor.predict() ``` -Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos. - ## Model Description | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | From fbf7e3a664eb4b158715b666792793914d64134f Mon Sep 17 00:00:00 2001 From: Niels Date: Mon, 5 Aug 2024 22:12:15 +0200 Subject: [PATCH 11/18] Add link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6328e1fd..5aed1d0d 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ Please refer to the examples in [video_predictor_example.ipynb](./notebooks/vide ## Load from Hugging Face -Alternatively, models can also be loaded from Hugging Face using the `from_pretrained` method: +Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`). 
For image prediction: From e815f70a3805ba9e7f0198b0ab1c193515d209ef Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 6 Aug 2024 08:32:36 +0200 Subject: [PATCH 12/18] Address comment --- sam2/build_sam.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sam2/build_sam.py b/sam2/build_sam.py index 9bb5279b..a0617311 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -88,16 +88,22 @@ def build_sam2_hf(model_id, **kwargs): } config_name, checkpoint_name = model_id_to_filenames[model_id] ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) - return build_sam2_video_predictor(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) def build_sam2_video_predictor_hf(model_id, **kwargs): from huggingface_hub import hf_hub_download - config_file = hf_hub_download(repo_id=model_id, filename=f"{model_id}.yaml") - ckpt_path = hf_hub_download(repo_id=model_id, filename=f"{model_id}.pt") - return build_sam2_video_predictor(config_file=config_file, ckpt_path=ckpt_path, **kwargs) + model_id_to_filenames = { + "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), + "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), + "facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"), + "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), + } + config_name, checkpoint_name = model_id_to_filenames[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return build_sam2_video_predictor(config_file=config_name, ckpt_path=ckpt_path, **kwargs) def _load_checkpoint(model, ckpt_path): From a36edf1e019913cf4e87924cf92b82698087b045 Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 6 Aug 2024 08:34:42 +0200 Subject: [PATCH 13/18] Clean up --- sam2/sam2_image_predictor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sam2/sam2_image_predictor.py 
b/sam2/sam2_image_predictor.py index 9bee70db..6fde8831 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -13,6 +13,7 @@ from PIL.Image import Image from sam2.modeling.sam2_base import SAM2Base + from sam2.utils.transforms import SAM2Transforms From 27a167c00424022bc4af272db49879ba726b2765 Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 6 Aug 2024 22:41:32 +0200 Subject: [PATCH 14/18] Update README --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 5aed1d0d..26e8ee59 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ cd checkpoints or individually from: -- [sam2_hiera_tiny.pt](https://huggingface.co/facebook/sam2-hiera-tiny) -- [sam2_hiera_small.pt](https://huggingface.co/facebook/sam2-hiera-small) -- [sam2_hiera_base_plus.pt](https://huggingface.co/facebook/sam2-hiera-base-plus) -- [sam2_hiera_large.pt](https://huggingface.co/facebook/sam2-hiera-large) +- [sam2_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt) +- [sam2_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt) +- [sam2_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt) +- [sam2_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt) Then SAM 2 can be used in a few lines as follows for image and video prediction. @@ -99,7 +99,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos. -## Load from Hugging Face +## Load from 🤗 Hugging Face Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`). 
From 8f15c6255a51b687f4cbaac5836c41d881c1299d Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 6 Aug 2024 22:43:35 +0200 Subject: [PATCH 15/18] Format using ufmt --- sam2/build_sam.py | 16 ++++++++++++---- sam2/sam2_video_predictor.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sam2/build_sam.py b/sam2/build_sam.py index a0617311..e5911d49 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -83,7 +83,10 @@ def build_sam2_hf(model_id, **kwargs): model_id_to_filenames = { "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), - "facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"), + "facebook/sam2-hiera-base-plus": ( + "sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), } config_name, checkpoint_name = model_id_to_filenames[model_id] @@ -98,12 +101,17 @@ def build_sam2_video_predictor_hf(model_id, **kwargs): model_id_to_filenames = { "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), - "facebook/sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt"), + "facebook/sam2-hiera-base-plus": ( + "sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), } config_name, checkpoint_name = model_id_to_filenames[model_id] ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) - return build_sam2_video_predictor(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + return build_sam2_video_predictor( + config_file=config_name, ckpt_path=ckpt_path, **kwargs + ) def _load_checkpoint(model, ckpt_path): @@ -116,4 +124,4 @@ def _load_checkpoint(model, ckpt_path): if unexpected_keys: logging.error(unexpected_keys) raise RuntimeError() - 
logging.info("Loaded checkpoint sucessfully") \ No newline at end of file + logging.info("Loaded checkpoint sucessfully") diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 3a751c57..6e27efbd 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -119,7 +119,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) return cls(sam_model) - + def _obj_id_to_idx(self, inference_state, obj_id): """Map client-side object id to model-side object index.""" obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) From 322aa3e7e55958b86161239035c773364575a387 Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 6 Aug 2024 22:57:07 +0200 Subject: [PATCH 16/18] Revert code snippet --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 26e8ee59..caf1024c 100644 --- a/README.md +++ b/README.md @@ -82,9 +82,11 @@ For promptable segmentation and tracking in videos, we provide a video predictor ```python import torch -from sam2.sam2_video_predictor import SAM2VideoPredictor +from sam2.build_sam import build_sam2_video_predictor -predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") +checkpoint = "./checkpoints/sam2_hiera_large.pt" +model_cfg = "sam2_hiera_l.yaml" +predictor = build_sam2_video_predictor(model_cfg, checkpoint) with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): state = predictor.init_state() From 43c385c26327d1d29429750430720776aa1fcbee Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 6 Aug 2024 23:00:26 +0200 Subject: [PATCH 17/18] Update docstrings --- sam2/sam2_image_predictor.py | 2 +- sam2/sam2_video_predictor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index 6fde8831..f6f9a5a1 100644 --- a/sam2/sam2_image_predictor.py +++ 
b/sam2/sam2_image_predictor.py @@ -65,7 +65,7 @@ def __init__( @classmethod def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": """ - Load a pretrained model from the Hugging Face model hub. + Load a pretrained model from the Hugging Face hub. Arguments: model_id (str): The Hugging Face repository ID. diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 6e27efbd..fe8702b6 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -106,7 +106,7 @@ def init_state( @classmethod def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": """ - Load a pretrained model from the Hugging Face model hub. + Load a pretrained model from the Hugging Face hub. Arguments: model_id (str): The Hugging Face repository ID. From 9b58611e24543966167f17a2003c60f45be97121 Mon Sep 17 00:00:00 2001 From: Niels Date: Wed, 7 Aug 2024 17:48:12 +0200 Subject: [PATCH 18/18] Address comment --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index caf1024c..cee15318 100644 --- a/README.md +++ b/README.md @@ -127,8 +127,14 @@ from sam2.sam2_video_predictor import SAM2VideoPredictor predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): - predictor.set_image() - masks, _, _ = predictor.predict() + state = predictor.init_state() + + # add new prompts and instantly get the output on the same frame + frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>): + + # propagate the prompts to get masklets throughout the video + for frame_idx, object_ids, masks in predictor.propagate_in_video(state): + ... ``` ## Model Description