Skip to content

Commit

Permalink
Add logits comparison for base SAM in single mask output prediction mode
Browse files Browse the repository at this point in the history
  • Loading branch information
hugojarkoff committed Mar 21, 2024
1 parent 38c86f5 commit c6b5eb2
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/foundationals/segment_anything/test_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,23 +407,29 @@ def test_predictor_single_output(
predictor = facebook_sam_h_predictor
predictor.set_image(np.array(truck))

facebook_masks, facebook_scores, _ = predictor.predict( # type: ignore
facebook_masks, facebook_scores, facebook_low_res_masks = predictor.predict( # type: ignore
**one_prompt.facebook_predict_kwargs(), # type: ignore
multimask_output=False,
)

assert len(facebook_masks) == 1

masks, scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__)
masks, scores, low_res_masks = sam_h_single_output.predict(truck, **one_prompt.__dict__)
masks = masks.squeeze(0)
scores = scores.squeeze(0)

assert len(masks) == 1

assert torch.allclose(
low_res_masks[0, 0, ...],
torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
atol=6e-3, # TODO: This diff on logits is high, and requires deeper investigation
)
assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)

mask_prediction = masks[0].cpu()
facebook_mask = torch.as_tensor(facebook_masks[0])
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
assert isclose(scores[0].item(), facebook_scores[0].item(), rel_tol=1e-05)


def test_mask_encoder(
Expand Down

0 comments on commit c6b5eb2

Please sign in to comment.