Skip to content

Commit

Permalink
Add multimask_output flag to SAM
Browse files: browse the repository at this point in the history
  • Loading branch information
piercus committed Mar 19, 2024
1 parent 6a72943 commit 68fe725
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 19 deletions.
48 changes: 30 additions & 18 deletions src/refiners/foundationals/segment_anything/mask_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@


class EmbeddingsAggregator(fl.ContextModule):
def __init__(self, num_output_mask: int = 3) -> None:
super().__init__()
self.num_mask_tokens = num_output_mask

def forward(self, iou_mask_tokens: Tensor) -> Tensor:
mask_decoder = self.ensure_parent
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
Expand Down Expand Up @@ -48,7 +44,7 @@ def __init__(
self,
embedding_dim: int = 256,
num_layers: int = 3,
num_mask_tokens: int = 3,
num_mask_tokens: int = 4,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
Expand All @@ -70,7 +66,7 @@ def __init__(
dtype=dtype,
),
)
for i in range(num_mask_tokens + 1)
for i in range(num_mask_tokens)
],
dim=1,
)
Expand Down Expand Up @@ -138,13 +134,18 @@ def __init__(
self,
embedding_dim: int,
num_mask_tokens: int,
multimask_output: bool,
num_layers: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_mask_tokens
self.num_layers = num_layers
self.multimask_output = multimask_output

start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1)

super().__init__(
fl.Matmul(
input=Hypernetworks(
Expand All @@ -156,8 +157,8 @@ def __init__(
),
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
),
fl.Slicing(dim=1, start=1),
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
fl.Slicing(dim=1, start=start_mask, end=start_mask + num_masks),
fl.Reshape(num_masks, embedding_dim, embedding_dim),
)


Expand All @@ -167,47 +168,53 @@ def __init__(
embedding_dim: int,
num_layers: int,
num_mask_tokens: int,
multimask_output: bool,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_layers = num_layers
self.multimask_output = multimask_output

super().__init__(
fl.Slicing(dim=1, start=0, end=1),
fl.Squeeze(dim=0),
fl.MultiLinear(
input_dim=embedding_dim,
output_dim=num_mask_tokens + 1,
output_dim=num_mask_tokens,
inner_dim=embedding_dim,
num_layers=num_layers,
device=device,
dtype=dtype,
),
fl.Slicing(dim=-1, start=1),
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
)


class MaskDecoder(fl.Chain):
def __init__(
self,
multimask_output: bool = True,
embedding_dim: int = 256,
feed_forward_dim: int = 2048,
num_layers: int = 2,
num_output_mask: int = 3,
num_multimask_outputs: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.multimask_output = multimask_output
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_output_mask
self.feed_forward_dim = feed_forward_dim
self.num_layers = num_layers
self.num_multimask_outputs = num_multimask_outputs

# The 1 additional token is for single-output mask prediction
num_mask_tokens = self.num_multimask_outputs + 1

super().__init__(
IOUMaskEncoder(
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype
),
EmbeddingsAggregator(num_output_mask=num_output_mask),
IOUMaskEncoder(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
EmbeddingsAggregator(),
Transformer(
*(
TwoWayTransformerLayer(
Expand All @@ -225,12 +232,17 @@ def __init__(
),
fl.Parallel(
MaskPrediction(
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype
embedding_dim=embedding_dim,
num_mask_tokens=num_mask_tokens,
multimask_output=multimask_output,
device=device,
dtype=dtype,
),
IOUPrediction(
embedding_dim=embedding_dim,
num_layers=3,
num_mask_tokens=num_output_mask,
num_mask_tokens=num_mask_tokens,
multimask_output=multimask_output,
device=device,
dtype=dtype,
),
Expand Down
10 changes: 9 additions & 1 deletion src/refiners/foundationals/segment_anything/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __init__(
point_encoder: PointEncoder | None = None,
mask_encoder: MaskEncoder | None = None,
mask_decoder: MaskDecoder | None = None,
multimask_output: bool | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
Expand All @@ -243,13 +244,20 @@ def __init__(
point_encoder: The point encoder to use.
mask_encoder: The mask encoder to use.
mask_decoder: The mask decoder to use.
multimask_output: Whether to use multimask output.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
image_encoder = image_encoder or SAMViTH()
point_encoder = point_encoder or PointEncoder()
mask_encoder = mask_encoder or MaskEncoder()
mask_decoder = mask_decoder or MaskDecoder()

if mask_decoder:
assert (
mask_decoder.multimask_output == multimask_output
), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output (${multimask_output})"
else:
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()

super().__init__(
image_encoder=image_encoder,
Expand Down
35 changes: 35 additions & 0 deletions tests/foundationals/segment_anything/test_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
return sam_h


@pytest.fixture(scope="module")
def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
    """Provide a module-scoped ``SegmentAnythingH`` built with ``multimask_output=False``.

    The model is created on ``test_device`` and then loaded from the
    safetensors file located at ``sam_h_weights``.
    """
    single_mask_model = SegmentAnythingH(device=test_device, multimask_output=False)
    single_mask_model.load_from_safetensors(tensors_path=sam_h_weights)
    return single_mask_model


@pytest.fixture(scope="module")
def ref_path(test_sam_path: Path) -> Path:
return test_sam_path / "test_sam_ref"
Expand Down Expand Up @@ -391,6 +398,34 @@ def test_predictor_dense_mask(
assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)


def test_predictor_single_output(
    facebook_sam_h_predictor: FacebookSAMPredictor,
    sam_h_single_output: SegmentAnythingH,
    truck: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    """Check single-mask prediction against the Facebook reference predictor.

    Both predictors are run on the ``truck`` image with the same prompt and
    with multimask output disabled. Each must return exactly one mask; the two
    masks must have an intersection-over-union of ~1.0 and the two scores must
    agree within a relative tolerance of 1e-05.
    """
    # Reference side: the Facebook predictor, explicitly asked for one mask.
    facebook_sam_h_predictor.set_image(np.array(truck))
    reference_masks, reference_scores, _ = facebook_sam_h_predictor.predict(  # type: ignore
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=False,
    )
    assert len(reference_masks) == 1

    # Model under test: built with multimask_output=False by the fixture.
    predicted_masks, predicted_scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__)
    # Drop the leading dimension so that index 0 addresses the single mask/score.
    predicted_masks = predicted_masks.squeeze(0)
    predicted_scores = predicted_scores.squeeze(0)
    assert len(predicted_masks) == 1

    iou = intersection_over_union(predicted_masks[0].cpu(), torch.as_tensor(reference_masks[0]))
    assert isclose(iou, 1.0, rel_tol=5e-05)
    assert isclose(predicted_scores[0].item(), reference_scores[0].item(), rel_tol=1e-05)


def test_mask_encoder(
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
) -> None:
Expand Down

0 comments on commit 68fe725

Please sign in to comment.