Skip to content

Commit

Permalink
HQ-SAM logit equal test, following #331
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Mar 23, 2024
1 parent 2763db9 commit 5c937b1
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ exclude_also = [

[tool.typos.default]
extend-words = { adaptee = "adaptee" }
extend-ignore-identifiers-re = ["NDArray*"]
extend-ignore-identifiers-re = ["NDArray*", "interm"]

[tool.pytest.ini_options]
filterwarnings = [
Expand Down
58 changes: 55 additions & 3 deletions tests/foundationals/segment_anything/test_hq_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
MaskDecoderTokensExtender,
PredictionsPostProc,
)
from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -248,8 +248,8 @@ def test_predictor(
reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore

# NOTE: Diff on logits is relatively high, but on the same scale / even lower than base SAM logits diff (6e-3)
# See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426
# NOTE: Diff on logits is relatively high,
# see test_predictor_equal for a stricter version
assert torch.allclose(
reference_low_res_mask_hq,
refiners_low_res_mask_hq,
Expand All @@ -265,6 +265,58 @@ def test_predictor(
)


@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor_equal(
    sam_h: SegmentAnythingH,
    hq_adapter_weights: Path,
    hq_mask_only: bool,
    reference_sam_h_predictor: FacebookSAMPredictorHQ,
    tennis: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    """Check bit-exact agreement between refiners HQ-SAM and the reference predictor.

    The refiners ViT is bypassed: the reference predictor's image features and
    intermediate ViT features are fed straight into the refiners decoder, so the
    two pipelines share identical inputs and `torch.equal` can be used instead of
    an `allclose` tolerance (see `test_predictor` for the looser variant).
    """
    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    adapter.hq_mask_only = hq_mask_only
    post_proc = sam_h.ensure_find(PredictionsPostProc)
    assert post_proc.hq_mask_only == hq_mask_only

    # As in test_sam.py::test_predictor_resized_single_output, exact equality
    # requires resizing the image up front and passing an ImageEmbedding directly.
    size = (1024, 1024)
    image_1024 = tennis.resize(size)

    # --- Reference implementation ---
    reference_sam_h_predictor.set_image(np.array(image_1024))

    box_prompt = one_prompt.__dict__["box_points"]
    ref_masks, _, ref_low_res_masks = reference_sam_h_predictor.predict(
        box=np.array(box_prompt).flatten(),
        multimask_output=False,
        hq_token_only=hq_mask_only,
    )

    ref_high_res = torch.from_numpy(ref_masks[0, ...]).to(dtype=torch.float32)  # type: ignore
    ref_low_res = torch.from_numpy(ref_low_res_masks[0, ...]).to(dtype=torch.float32)  # type: ignore

    # --- Refiners implementation ---
    # Reuse the reference image features and interm_features so both decoders
    # start from exactly the same embeddings.
    shared_embedding = ImageEmbedding(features=reference_sam_h_predictor.features, original_image_size=size)
    adapter.set_context("hq_sam", {"early_vit_embedding": reference_sam_h_predictor.interm_features[0]})

    high_res_masks, _, low_res_masks = sam_h.predict(shared_embedding, **one_prompt.__dict__)
    refiners_high_res = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
    refiners_low_res = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()

    assert torch.equal(ref_low_res, refiners_low_res)
    assert torch.abs(ref_high_res - refiners_high_res).flatten().sum() == 0


@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
Expand Down
2 changes: 2 additions & 0 deletions tests/foundationals/segment_anything/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def predict(

class FacebookSAMPredictorHQ:
model: FacebookSAM
features: Tensor
interm_features: Tensor

def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...

Expand Down

0 comments on commit 5c937b1

Please sign in to comment.