-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from membraneframework-labs/instance_segmentation
Instance segmentation
- Loading branch information
Showing
9 changed files
with
311 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do | ||
@moduledoc """ | ||
An instance segmentation model with a ResNet-50-FPN backbone. Exported from torchvision. | ||
""" | ||
use ExVision.Model.Definition.Ortex, | ||
model: "maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx", | ||
categories: "priv/categories/coco_categories.json" | ||
|
||
require Logger | ||
|
||
alias ExVision.Types.BBoxWithMask | ||
|
||
@type output_t() :: [BBoxWithMask.t()] | ||
|
||
@impl true | ||
def load(options \\ []) do | ||
if Keyword.has_key?(options, :batch_size) do | ||
Logger.warning( | ||
"`:max_batch_size` was given, but this model can only process batch of size 1. Overriding" | ||
) | ||
end | ||
|
||
options | ||
|> Keyword.put(:batch_size, 1) | ||
|> default_model_load() | ||
end | ||
|
||
@impl true | ||
def preprocessing(img, _metdata) do | ||
ExVision.Utils.resize(img, {224, 224}) | ||
end | ||
|
||
@impl true | ||
def postprocessing( | ||
%{ | ||
"boxes_unsqueezed" => bboxes, | ||
"labels_unsqueezed" => labels, | ||
"masks_unsqueezed" => masks, | ||
"scores_unsqueezed" => scores | ||
}, | ||
metadata | ||
) do | ||
categories = categories() | ||
|
||
{h, w} = metadata.original_size | ||
scale_x = w / 224 | ||
scale_y = h / 224 | ||
|
||
bboxes = | ||
bboxes | ||
|> Nx.squeeze(axes: [0]) | ||
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y])) | ||
|> Nx.round() | ||
|> Nx.as_type(:s64) | ||
|> Nx.to_list() | ||
|
||
scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list() | ||
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list() | ||
|
||
masks = | ||
masks | ||
|> Nx.backend_transfer() | ||
|> Nx.squeeze(axes: [0, 2]) | ||
|> NxImage.resize(metadata.original_size, channels: :first) | ||
|> Nx.round() | ||
|> Nx.as_type(:s64) | ||
|> Nx.to_list() | ||
|
||
[bboxes, labels, scores, masks] | ||
|> Enum.zip() | ||
|> Enum.filter(fn {_bbox, _label, score, _mask} -> score > 0.1 end) | ||
|> Enum.map(fn {[x1, y1, x2, y2], label, score, mask} -> | ||
%BBoxWithMask{ | ||
x1: x1, | ||
y1: y1, | ||
x2: x2, | ||
y2: y2, | ||
label: Enum.at(categories, label), | ||
score: score, | ||
mask: Nx.tensor(mask) | ||
} | ||
end) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
defmodule ExVision.Types.BBoxWithMask do | ||
@moduledoc """ | ||
A struct describing the bounding box with mask returned by the instance segmentation model. | ||
""" | ||
|
||
@enforce_keys [ | ||
:x1, | ||
:y1, | ||
:x2, | ||
:y2, | ||
:label, | ||
:score, | ||
:mask | ||
] | ||
defstruct @enforce_keys | ||
|
||
@typedoc """ | ||
A type describing the Bounding Box with Mask object. | ||
Bounding box is a rectangle encompassing the region. | ||
When used in instance segmentation, this box will describe the location of the object in the image. | ||
Additionally, a binary mask represents the instance segmentation of the object. | ||
- `x1` - x componenet of the upper left corner | ||
- `y1` - y componenet of the upper left corner | ||
- `x2` - x componenet of the lower right | ||
- `y2` - y componenet of the lower right | ||
- `score` - confidence of the predition | ||
- `label` - label assigned to this bounding box. | ||
- `mask` - binary mask | ||
""" | ||
@type t(label_t) :: %__MODULE__{ | ||
x1: number(), | ||
y1: number(), | ||
y2: number(), | ||
x2: number(), | ||
label: label_t, | ||
score: number(), | ||
mask: Nx.tensor() | ||
} | ||
|
||
@typedoc """ | ||
Exactly like `t:t/1`, but doesn't put any constraints on the `label` field: | ||
""" | ||
@type t() :: t(term()) | ||
|
||
@doc """ | ||
Return the width of the bounding box | ||
""" | ||
@spec width(t()) :: number() | ||
def width(%__MODULE__{x1: x1, x2: x2}), do: abs(x2 - x1) | ||
|
||
@doc """ | ||
Return the height of the bounding box | ||
""" | ||
@spec height(t()) :: number() | ||
def height(%__MODULE__{y1: y1, y2: y2}), do: abs(y2 - y1) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import argparse | ||
from torchvision.transforms.functional import to_tensor, resize | ||
import torch | ||
import json | ||
from pathlib import Path | ||
import onnx | ||
from onnx import helper, TensorProto | ||
from PIL import Image | ||
|
||
|
||
def export(model_builder, Model_Weights): | ||
base_dir = Path(f"models/instance_segmentation/{model_builder.__name__}") | ||
base_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
model_file = base_dir / "model.onnx" | ||
categories_file = base_dir / "categories.json" | ||
|
||
weights = Model_Weights.DEFAULT | ||
model = model_builder(weights=weights) | ||
model.eval() | ||
|
||
categories = weights.meta["categories"] | ||
transforms = weights.transforms() | ||
|
||
with open(categories_file, "w") as f: | ||
json.dump(categories, f) | ||
|
||
onnx_input = to_tensor(Image.open("test/assets/cat.jpg")).unsqueeze(0) | ||
onnx_input = resize(onnx_input, [224, 224]) | ||
onnx_input = transforms(onnx_input) | ||
|
||
torch.onnx.export( | ||
model, | ||
onnx_input, | ||
str(model_file), | ||
verbose=False, | ||
input_names=["input"], | ||
output_names=["boxes", "labels", "scores", "masks"], | ||
dynamic_axes={ | ||
"boxes": {0: "detections"}, | ||
"labels": {0: "detections"}, | ||
"scores": {0: "detections"}, | ||
"masks": {0: "detections"}, | ||
}, | ||
export_params=True, | ||
) | ||
|
||
model = onnx.load(str(model_file)) | ||
|
||
prev_names = ["boxes", "labels", "scores", "masks"] | ||
|
||
nodes = [] | ||
for data in prev_names: | ||
axes_init = helper.make_tensor( | ||
name=data+"_axes", | ||
data_type=TensorProto.INT64, | ||
dims=[1], | ||
vals=[0] | ||
) | ||
model.graph.initializer.append(axes_init) | ||
|
||
node = helper.make_node( | ||
op_type="Unsqueeze", | ||
inputs=[data, data+"_axes"], | ||
outputs=[data+"_unsqueezed"] | ||
) | ||
nodes.append(node) | ||
|
||
model.graph.node.extend(nodes) | ||
|
||
new_outputs = [] | ||
for data in prev_names: | ||
match data: | ||
case "boxes": | ||
shape = [1, None, 4] | ||
case "masks": | ||
shape = [1, None, 1, 224, 224] | ||
case _: | ||
shape = [1, None] | ||
|
||
new_output = helper.make_tensor_value_info( | ||
name=data+"_unsqueezed", | ||
elem_type=TensorProto.INT64 if data == "labels" else TensorProto.FLOAT, | ||
shape=shape | ||
) | ||
new_outputs.append(new_output) | ||
|
||
model.graph.output.extend(new_outputs) | ||
|
||
for data in prev_names: | ||
old_output = next(i for i in model.graph.output if i.name == data) | ||
model.graph.output.remove(old_output) | ||
|
||
onnx.save(model, str(model_file)) | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("model") | ||
args = parser.parse_args() | ||
|
||
match(args.model): | ||
case "maskrcnn_resnet50_fpn_v2": | ||
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights | ||
export(maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights) | ||
case _: | ||
print("Model not found") |
15 changes: 15 additions & 0 deletions
15
test/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2_test.exs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2_Test do | ||
use ExVision.Model.Case, module: ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 | ||
use ExVision.TestUtils | ||
alias ExVision.Types.BBoxWithMask | ||
|
||
@impl true | ||
def test_inference_result(result) do | ||
assert [%BBoxWithMask{x1: 129, y1: 15, label: :cat, score: score, mask: mask}] = result | ||
assert_floats_equal(score, 1.0) | ||
|
||
assert_floats_equal(nx_mean(mask), 0.37) | ||
end | ||
|
||
defp nx_mean(t), do: t |> Nx.mean() |> Nx.to_number() | ||
end |