diff --git a/.gitattributes b/.gitattributes
index dacbbb3..926dd12 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,5 +1,6 @@
 models/**/*.onnx filter=lfs diff=lfs merge=lfs -text
 models/deeplab_v3_mobilenetv3_segmentation.onnx filter=lfs diff=lfs merge=lfs -text
+models/maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx filter=lfs diff=lfs merge=lfs -text
 models/fasterrcnn_resnet50_fpn_object_detector.onnx filter=lfs diff=lfs merge=lfs -text
 models/mobilenetv3small-classifier.onnx filter=lfs diff=lfs merge=lfs -text
 models/efficientnet_v2_s_classifier.onnx filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 5f238e0..1fc07b1 100644
--- a/README.md
+++ b/README.md
@@ -87,8 +87,8 @@ If the model that you would like to use is missing, feel free to open the issue,
   - [x] FasterRCNN ResNet50 FPN
 - [x] Semantic segmentation
   - [x] DeepLabV3 - MobileNetV3
-- [ ] Instance segmentation
-  - [ ] Mask R-CNN
+- [x] Instance segmentation
+  - [x] Mask R-CNN
 - [ ] Keypoint Detection
   - [ ] Keypoint R-CNN
diff --git a/examples/1-basic-tutorial.livemd b/examples/1-basic-tutorial.livemd
index 6d4e8ad..82a56bc 100644
--- a/examples/1-basic-tutorial.livemd
+++ b/examples/1-basic-tutorial.livemd
@@ -35,10 +35,12 @@ The main objective of ExVision is ease of use. This sacrifices some control over
 alias ExVision.Classification.MobileNetV3Small, as: Classifier
 alias ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN, as: ObjectDetector
 alias ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3, as: SemanticSegmentation
+alias ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2, as: InstanceSegmentation
 
 {:ok, classifier} = Classifier.load()
 {:ok, object_detector} = ObjectDetector.load()
 {:ok, semantic_segmentation} = SemanticSegmentation.load()
+{:ok, instance_segmentation} = InstanceSegmentation.load()
 Kino.nothing()
 ```
@@ -221,6 +223,46 @@
 end)
 |> Kino.Layout.grid(columns: 2)
 ```
 
+## Instance segmentation
+
+The objective of instance segmentation is not only to identify objects within an image on a per-pixel basis, but also to differentiate between individual objects of the same class.
+
+In ExVision, the output of an instance segmentation model includes a bounding box with a label and a score (similar to object detection), and a binary mask for every instance detected in the image.
+
+### Code example
+
+In the following example, we will pass an image through the instance segmentation model and examine the individual instance masks it recognizes.
+
+```elixir
+alias ExVision.Types.BBoxWithMask
+
+nx_image = Image.to_nx!(image)
+uniform_black = 0 |> Nx.broadcast(Nx.shape(nx_image)) |> Nx.as_type(Nx.type(nx_image))
+
+predictions =
+  image
+  |> then(&InstanceSegmentation.run(instance_segmentation, &1))
+  # Keep only the most likely predictions from the output
+  |> Enum.filter(fn %BBoxWithMask{score: score} -> score > 0.8 end)
+  |> dbg()
+
+predictions
+|> Enum.map(fn %BBoxWithMask{label: label, mask: mask} ->
+  # Expand the mask to cover all channels
+  mask = Nx.broadcast(mask, Nx.shape(nx_image), axes: [0, 1])
+
+  # Cut the masked region out of the original image
+  image = Nx.select(mask, nx_image, uniform_black)
+  image = Nx.as_type(image, :u8)
+
+  Kino.Layout.grid([
+    label |> Atom.to_string() |> Kino.Text.new(),
+    Kino.Image.new(image)
+  ])
+end)
+|> Kino.Layout.grid(columns: 2)
+```
+
 ## Next steps
 
 After completing this tutorial you can also check out our next tutorial focusing on using models in production in process workflow [here](2-usage-as-nx-serving.livemd)
diff --git a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
new file mode 100644
index 0000000..d3ebb33
--- /dev/null
+++ b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
@@ -0,0 +1,84 @@
+defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
+  @moduledoc """
+  An instance segmentation model with a ResNet-50-FPN backbone. Exported from torchvision.
+  """
+
+  use ExVision.Model.Definition.Ortex,
+    model: "maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx",
+    categories: "priv/categories/coco_categories.json"
+
+  require Logger
+
+  alias ExVision.Types.BBoxWithMask
+
+  @type output_t() :: [BBoxWithMask.t()]
+
+  @impl true
+  def load(options \\ []) do
+    if Keyword.has_key?(options, :batch_size) do
+      Logger.warning(
+        "`:batch_size` was given, but this model can only process batches of size 1. Overriding"
+      )
+    end
+
+    options
+    |> Keyword.put(:batch_size, 1)
+    |> default_model_load()
+  end
+
+  @impl true
+  def preprocessing(img, _metadata) do
+    ExVision.Utils.resize(img, {224, 224})
+  end
+
+  @impl true
+  def postprocessing(
+        %{
+          "boxes_unsqueezed" => bboxes,
+          "labels_unsqueezed" => labels,
+          "masks_unsqueezed" => masks,
+          "scores_unsqueezed" => scores
+        },
+        metadata
+      ) do
+    categories = categories()
+
+    # The model operates on a fixed 224x224 input, so boxes and masks
+    # are scaled back to the original image size.
+    {h, w} = metadata.original_size
+    scale_x = w / 224
+    scale_y = h / 224
+
+    bboxes =
+      bboxes
+      |> Nx.squeeze(axes: [0])
+      |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
+      |> Nx.round()
+      |> Nx.as_type(:s64)
+      |> Nx.to_list()
+
+    scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
+    labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
+
+    masks =
+      masks
+      |> Nx.backend_transfer()
+      |> Nx.squeeze(axes: [0, 2])
+      |> NxImage.resize(metadata.original_size, channels: :first)
+      |> Nx.round()
+      |> Nx.as_type(:s64)
+      |> Nx.to_list()
+
+    [bboxes, labels, scores, masks]
+    |> Enum.zip()
+    |> Enum.filter(fn {_bbox, _label, score, _mask} -> score > 0.1 end)
+    |> Enum.map(fn {[x1, y1, x2, y2], label, score, mask} ->
+      %BBoxWithMask{
+        x1: x1,
+        y1: y1,
+        x2: x2,
+        y2: y2,
+        label: Enum.at(categories, label),
+        score: score,
+        mask: Nx.tensor(mask)
+      }
+    end)
+  end
+end
diff --git a/lib/ex_vision/types/bbox.ex b/lib/ex_vision/types/bbox.ex
index 498a835..2406305 100644
--- a/lib/ex_vision/types/bbox.ex
+++ b/lib/ex_vision/types/bbox.ex
@@ -1,6 +1,6 @@
 defmodule ExVision.Types.BBox do
   @moduledoc """
-  A struct describing the bounding box returned by the detection model.
+  A struct describing the bounding box returned by the object detection model.
   """
 
   @enforce_keys [:x1, :y1, :x2, :y2, :label, :score]
diff --git a/lib/ex_vision/types/bboxwithmask.ex b/lib/ex_vision/types/bboxwithmask.ex
new file mode 100644
index 0000000..a2e2f4c
--- /dev/null
+++ b/lib/ex_vision/types/bboxwithmask.ex
@@ -0,0 +1,58 @@
+defmodule ExVision.Types.BBoxWithMask do
+  @moduledoc """
+  A struct describing the bounding box with mask returned by the instance segmentation model.
+  """
+
+  @enforce_keys [
+    :x1,
+    :y1,
+    :x2,
+    :y2,
+    :label,
+    :score,
+    :mask
+  ]
+  defstruct @enforce_keys
+
+  @typedoc """
+  A type describing the Bounding Box with Mask object.
+
+  The bounding box is a rectangle encompassing the region.
+  When used in instance segmentation, this box describes the location of the object in the image.
+  Additionally, a binary mask represents the instance segmentation of the object.
+
+  - `x1` - x component of the upper left corner
+  - `y1` - y component of the upper left corner
+  - `x2` - x component of the lower right corner
+  - `y2` - y component of the lower right corner
+  - `score` - confidence of the prediction
+  - `label` - label assigned to this bounding box
+  - `mask` - binary mask of the detected instance
+  """
+  @type t(label_t) :: %__MODULE__{
+          x1: number(),
+          y1: number(),
+          x2: number(),
+          y2: number(),
+          label: label_t,
+          score: number(),
+          mask: Nx.Tensor.t()
+        }
+
+  @typedoc """
+  Exactly like `t:t/1`, but doesn't put any constraints on the `label` field.
+  """
+  @type t() :: t(term())
+
+  @doc """
+  Returns the width of the bounding box.
+  """
+  @spec width(t()) :: number()
+  def width(%__MODULE__{x1: x1, x2: x2}), do: abs(x2 - x1)
+
+  @doc """
+  Returns the height of the bounding box.
+  """
+  @spec height(t()) :: number()
+  def height(%__MODULE__{y1: y1, y2: y2}), do: abs(y2 - y1)
+end
diff --git a/mix.exs b/mix.exs
index d0b220f..d5895ed 100644
--- a/mix.exs
+++ b/mix.exs
@@ -99,6 +99,7 @@ defmodule ExVision.Mixfile do
         ExVision.Classification.EfficientNet_V2_L,
         ExVision.Classification.SqueezeNet1_1,
         ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3,
+        ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2,
         ExVision.ObjectDetection.Ssdlite320_MobileNetv3,
         ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN
       ],
@@ -119,6 +120,7 @@
         ExVision.Types,
         ExVision.Classification,
         ExVision.SemanticSegmentation,
+        ExVision.InstanceSegmentation,
         ExVision.ObjectDetection
       ],
       formatters: ["html"],
diff --git a/python/exports/instance_segmentation.py b/python/exports/instance_segmentation.py
new file mode 100644
index 0000000..20d59ce
--- /dev/null
+++ b/python/exports/instance_segmentation.py
@@ -0,0 +1,106 @@
+import argparse
+from torchvision.transforms.functional import to_tensor, resize
+import torch
+import json
+from pathlib import Path
+import onnx
+from onnx import helper, TensorProto
+from PIL import Image
+
+
+def export(model_builder, Model_Weights):
+    base_dir = Path(f"models/instance_segmentation/{model_builder.__name__}")
+    base_dir.mkdir(parents=True, exist_ok=True)
+
+    model_file = base_dir / "model.onnx"
+    categories_file = base_dir / "categories.json"
+
+    weights = Model_Weights.DEFAULT
+    model = model_builder(weights=weights)
+    model.eval()
+
+    categories = weights.meta["categories"]
+    transforms = weights.transforms()
+
+    with open(categories_file, "w") as f:
+        json.dump(categories, f)
+
+    onnx_input = to_tensor(Image.open("test/assets/cat.jpg")).unsqueeze(0)
+    onnx_input = resize(onnx_input, [224, 224])
+    onnx_input = transforms(onnx_input)
+
+    torch.onnx.export(
+        model,
+        onnx_input,
+        str(model_file),
+        verbose=False,
+        input_names=["input"],
+        output_names=["boxes", "labels", "scores", "masks"],
+        dynamic_axes={
+            "boxes": {0: "detections"},
+            "labels": {0: "detections"},
+            "scores": {0: "detections"},
+            "masks": {0: "detections"},
+        },
+        export_params=True,
+    )
+
+    model = onnx.load(str(model_file))
+
+    prev_names = ["boxes", "labels", "scores", "masks"]
+
+    # Wrap every output in an extra leading (batch) dimension
+    # by appending an Unsqueeze node per output.
+    nodes = []
+    for data in prev_names:
+        axes_init = helper.make_tensor(
+            name=data + "_axes",
+            data_type=TensorProto.INT64,
+            dims=[1],
+            vals=[0],
+        )
+        model.graph.initializer.append(axes_init)
+
+        node = helper.make_node(
+            op_type="Unsqueeze",
+            inputs=[data, data + "_axes"],
+            outputs=[data + "_unsqueezed"],
+        )
+        nodes.append(node)
+
+    model.graph.node.extend(nodes)
+
+    new_outputs = []
+    for data in prev_names:
+        match data:
+            case "boxes":
+                shape = [1, None, 4]
+            case "masks":
+                shape = [1, None, 1, 224, 224]
+            case _:
+                shape = [1, None]
+
+        new_output = helper.make_tensor_value_info(
+            name=data + "_unsqueezed",
+            elem_type=TensorProto.INT64 if data == "labels" else TensorProto.FLOAT,
+            shape=shape,
+        )
+        new_outputs.append(new_output)
+
+    model.graph.output.extend(new_outputs)
+
+    for data in prev_names:
+        old_output = next(i for i in model.graph.output if i.name == data)
+        model.graph.output.remove(old_output)
+
+    onnx.save(model, str(model_file))
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("model")
+args = parser.parse_args()
+
+match args.model:
+    case "maskrcnn_resnet50_fpn_v2":
+        from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
+
+        export(maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights)
+    case _:
+        print("Model not found")
diff --git a/test/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2_test.exs b/test/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2_test.exs
new file mode 100644
index 0000000..c49a889
--- /dev/null
+++ b/test/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2_test.exs
@@ -0,0 +1,15 @@
+defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2_Test do
+  use ExVision.Model.Case, module: ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2
+  use ExVision.TestUtils
+  alias ExVision.Types.BBoxWithMask
+
+  @impl true
+  def test_inference_result(result) do
+    assert [%BBoxWithMask{x1: 129, y1: 15, label: :cat, score: score, mask: mask}] = result
+    assert_floats_equal(score, 1.0)
+
+    assert_floats_equal(nx_mean(mask), 0.37)
+  end
+
+  defp nx_mean(t), do: t |> Nx.mean() |> Nx.to_number()
+end
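
A minimal usage sketch for the model introduced by this patch (illustration only, not part of the diff). It assumes that `run/2` accepts an image file path and that the repository's `test/assets/cat.jpg` asset is present, as referenced by the export script and test above; the `0.8` score cutoff is an arbitrary value mirroring the livemd example.

```elixir
alias ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2, as: Model
alias ExVision.Types.BBoxWithMask

# load/1 forces batch_size: 1, as implemented in the module above
{:ok, model} = Model.load()

# Assumption: run/2 also accepts a path to an image file
model
|> Model.run("test/assets/cat.jpg")
|> Enum.filter(fn %BBoxWithMask{score: score} -> score > 0.8 end)
|> Enum.each(fn %BBoxWithMask{label: label, score: score} = bbox ->
  # width/1 and height/1 come from the BBoxWithMask module added above
  IO.puts(
    "#{label}: score=#{Float.round(score, 2)}, " <>
      "#{BBoxWithMask.width(bbox)}x#{BBoxWithMask.height(bbox)} px"
  )
end)
```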