Skip to content

Commit

Permalink
Merge pull request #6 from membraneframework-labs/instance_segmentation
Browse files Browse the repository at this point in the history
Instance segmentation
  • Loading branch information
mkopcins authored Jul 3, 2024
2 parents cdcb796 + 377762e commit eb072c9
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
models/**/*.onnx filter=lfs diff=lfs merge=lfs -text
models/deeplab_v3_mobilenetv3_segmentation.onnx filter=lfs diff=lfs merge=lfs -text
models/maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx filter=lfs diff=lfs merge=lfs -text
models/fasterrcnn_resnet50_fpn_object_detector.onnx filter=lfs diff=lfs merge=lfs -text
models/mobilenetv3small-classifier.onnx filter=lfs diff=lfs merge=lfs -text
models/efficientnet_v2_s_classifier.onnx filter=lfs diff=lfs merge=lfs -text
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ If the model that you would like to use is missing, feel free to open the issue,
- [x] FasterRCNN ResNet50 FPN
- [x] Semantic segmentation
- [x] DeepLabV3 - MobileNetV3
- [ ] Instance segmentation
- [ ] Mask R-CNN
- [x] Instance segmentation
- [x] Mask R-CNN
- [ ] Keypoint Detection
- [ ] Keypoint R-CNN

Expand Down
42 changes: 42 additions & 0 deletions examples/1-basic-tutorial.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ The main objective of ExVision is ease of use. This sacrifices some control over
alias ExVision.Classification.MobileNetV3Small, as: Classifier
alias ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN, as: ObjectDetector
alias ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3, as: SemanticSegmentation
alias ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2, as: InstanceSegmentation

{:ok, classifier} = Classifier.load()
{:ok, object_detector} = ObjectDetector.load()
{:ok, semantic_segmentation} = SemanticSegmentation.load()
{:ok, instance_segmentation} = InstanceSegmentation.load()

Kino.nothing()
```
Expand Down Expand Up @@ -221,6 +223,46 @@ end)
|> Kino.Layout.grid(columns: 2)
```

## Instance segmentation

The objective of instance segmentation is to not only identify objects within an image on a per-pixel basis but also differentiate each specific object of the same class.

In ExVision, the output of instance segmentation models includes a bounding box with a label and a score (similar to object detection), and a binary mask for every instance detected in the image.

### Code example

In the following example, we will pass an image through the instance segmentation model and examine the individual instance masks recognized by the model.

```elixir
alias ExVision.Types.BBoxWithMask

nx_image = Image.to_nx!(image)
uniform_black = 0 |> Nx.broadcast(Nx.shape(nx_image)) |> Nx.as_type(Nx.type(nx_image))

predictions =
image
|> then(&InstanceSegmentation.run(instance_segmentation, &1))
# Get most likely predictions from the output
|> Enum.filter(fn %BBoxWithMask{score: score} -> score > 0.8 end)
|> dbg()

predictions
|> Enum.map(fn %BBoxWithMask{label: label, mask: mask} ->
# expand the mask to cover all channels
mask = Nx.broadcast(mask, Nx.shape(nx_image), axes: [0, 1])

# Cut out the mask from the original image
image = Nx.select(mask, nx_image, uniform_black)
image = Nx.as_type(image, :u8)

Kino.Layout.grid([
label |> Atom.to_string() |> Kino.Text.new(),
Kino.Image.new(image)
])
end)
|> Kino.Layout.grid(columns: 2)
```

## Next steps

After completing this tutorial you can also check out our next tutorial focusing on using models in production in process workflow [here](2-usage-as-nx-serving.livemd)
84 changes: 84 additions & 0 deletions lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
@moduledoc """
An instance segmentation model with a ResNet-50-FPN backbone. Exported from torchvision.
"""
use ExVision.Model.Definition.Ortex,
model: "maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx",
categories: "priv/categories/coco_categories.json"

require Logger

alias ExVision.Types.BBoxWithMask

@type output_t() :: [BBoxWithMask.t()]

@impl true
def load(options \\ []) do
if Keyword.has_key?(options, :batch_size) do
Logger.warning(
"`:max_batch_size` was given, but this model can only process batch of size 1. Overriding"
)
end

options
|> Keyword.put(:batch_size, 1)
|> default_model_load()
end

@impl true
def preprocessing(img, _metdata) do
ExVision.Utils.resize(img, {224, 224})
end

@impl true
def postprocessing(
%{
"boxes_unsqueezed" => bboxes,
"labels_unsqueezed" => labels,
"masks_unsqueezed" => masks,
"scores_unsqueezed" => scores
},
metadata
) do
categories = categories()

{h, w} = metadata.original_size
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()

masks =
masks
|> Nx.backend_transfer()
|> Nx.squeeze(axes: [0, 2])
|> NxImage.resize(metadata.original_size, channels: :first)
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

[bboxes, labels, scores, masks]
|> Enum.zip()
|> Enum.filter(fn {_bbox, _label, score, _mask} -> score > 0.1 end)
|> Enum.map(fn {[x1, y1, x2, y2], label, score, mask} ->
%BBoxWithMask{
x1: x1,
y1: y1,
x2: x2,
y2: y2,
label: Enum.at(categories, label),
score: score,
mask: Nx.tensor(mask)
}
end)
end
end
2 changes: 1 addition & 1 deletion lib/ex_vision/types/bbox.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defmodule ExVision.Types.BBox do
@moduledoc """
A struct describing the bounding box returned by the detection model.
A struct describing the bounding box returned by the object detection model.
"""

@enforce_keys [:x1, :y1, :x2, :y2, :label, :score]
Expand Down
58 changes: 58 additions & 0 deletions lib/ex_vision/types/bboxwithmask.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
defmodule ExVision.Types.BBoxWithMask do
@moduledoc """
A struct describing the bounding box with mask returned by the instance segmentation model.
"""

@enforce_keys [
:x1,
:y1,
:x2,
:y2,
:label,
:score,
:mask
]
defstruct @enforce_keys

@typedoc """
A type describing the Bounding Box with Mask object.
Bounding box is a rectangle encompassing the region.
When used in instance segmentation, this box will describe the location of the object in the image.
Additionally, a binary mask represents the instance segmentation of the object.
- `x1` - x componenet of the upper left corner
- `y1` - y componenet of the upper left corner
- `x2` - x componenet of the lower right
- `y2` - y componenet of the lower right
- `score` - confidence of the predition
- `label` - label assigned to this bounding box.
- `mask` - binary mask
"""
@type t(label_t) :: %__MODULE__{
x1: number(),
y1: number(),
y2: number(),
x2: number(),
label: label_t,
score: number(),
mask: Nx.tensor()
}

@typedoc """
Exactly like `t:t/1`, but doesn't put any constraints on the `label` field:
"""
@type t() :: t(term())

@doc """
Return the width of the bounding box
"""
@spec width(t()) :: number()
def width(%__MODULE__{x1: x1, x2: x2}), do: abs(x2 - x1)

@doc """
Return the height of the bounding box
"""
@spec height(t()) :: number()
def height(%__MODULE__{y1: y1, y2: y2}), do: abs(y2 - y1)
end
2 changes: 2 additions & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ defmodule ExVision.Mixfile do
ExVision.Classification.EfficientNet_V2_L,
ExVision.Classification.SqueezeNet1_1,
ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3,
ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2,
ExVision.ObjectDetection.Ssdlite320_MobileNetv3,
ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN
],
Expand All @@ -119,6 +120,7 @@ defmodule ExVision.Mixfile do
ExVision.Types,
ExVision.Classification,
ExVision.SemanticSegmentation,
ExVision.InstanceSegmentation,
ExVision.ObjectDetection
],
formatters: ["html"],
Expand Down
106 changes: 106 additions & 0 deletions python/exports/instance_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import argparse
from torchvision.transforms.functional import to_tensor, resize
import torch
import json
from pathlib import Path
import onnx
from onnx import helper, TensorProto
from PIL import Image


def export(model_builder, Model_Weights):
base_dir = Path(f"models/instance_segmentation/{model_builder.__name__}")
base_dir.mkdir(parents=True, exist_ok=True)

model_file = base_dir / "model.onnx"
categories_file = base_dir / "categories.json"

weights = Model_Weights.DEFAULT
model = model_builder(weights=weights)
model.eval()

categories = weights.meta["categories"]
transforms = weights.transforms()

with open(categories_file, "w") as f:
json.dump(categories, f)

onnx_input = to_tensor(Image.open("test/assets/cat.jpg")).unsqueeze(0)
onnx_input = resize(onnx_input, [224, 224])
onnx_input = transforms(onnx_input)

torch.onnx.export(
model,
onnx_input,
str(model_file),
verbose=False,
input_names=["input"],
output_names=["boxes", "labels", "scores", "masks"],
dynamic_axes={
"boxes": {0: "detections"},
"labels": {0: "detections"},
"scores": {0: "detections"},
"masks": {0: "detections"},
},
export_params=True,
)

model = onnx.load(str(model_file))

prev_names = ["boxes", "labels", "scores", "masks"]

nodes = []
for data in prev_names:
axes_init = helper.make_tensor(
name=data+"_axes",
data_type=TensorProto.INT64,
dims=[1],
vals=[0]
)
model.graph.initializer.append(axes_init)

node = helper.make_node(
op_type="Unsqueeze",
inputs=[data, data+"_axes"],
outputs=[data+"_unsqueezed"]
)
nodes.append(node)

model.graph.node.extend(nodes)

new_outputs = []
for data in prev_names:
match data:
case "boxes":
shape = [1, None, 4]
case "masks":
shape = [1, None, 1, 224, 224]
case _:
shape = [1, None]

new_output = helper.make_tensor_value_info(
name=data+"_unsqueezed",
elem_type=TensorProto.INT64 if data == "labels" else TensorProto.FLOAT,
shape=shape
)
new_outputs.append(new_output)

model.graph.output.extend(new_outputs)

for data in prev_names:
old_output = next(i for i in model.graph.output if i.name == data)
model.graph.output.remove(old_output)

onnx.save(model, str(model_file))


parser = argparse.ArgumentParser()
parser.add_argument("model")
args = parser.parse_args()

match(args.model):
case "maskrcnn_resnet50_fpn_v2":
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
export(maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights)
case _:
print("Model not found")
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2_Test do
use ExVision.Model.Case, module: ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2
use ExVision.TestUtils
alias ExVision.Types.BBoxWithMask

@impl true
def test_inference_result(result) do
assert [%BBoxWithMask{x1: 129, y1: 15, label: :cat, score: score, mask: mask}] = result
assert_floats_equal(score, 1.0)

assert_floats_equal(nx_mean(mask), 0.37)
end

defp nx_mean(t), do: t |> Nx.mean() |> Nx.to_number()
end

0 comments on commit eb072c9

Please sign in to comment.