diff --git a/.gitattributes b/.gitattributes index 926dd12..5fbe73e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,6 +1,7 @@ models/**/*.onnx filter=lfs diff=lfs merge=lfs -text models/deeplab_v3_mobilenetv3_segmentation.onnx filter=lfs diff=lfs merge=lfs -text models/maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx filter=lfs diff=lfs merge=lfs -text +models/keypointrcnn_resnet50_fpn_keypoint_detector.onnx filter=lfs diff=lfs merge=lfs -text models/fasterrcnn_resnet50_fpn_object_detector.onnx filter=lfs diff=lfs merge=lfs -text models/mobilenetv3small-classifier.onnx filter=lfs diff=lfs merge=lfs -text models/efficientnet_v2_s_classifier.onnx filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 1fc07b1..00de0e9 100644 --- a/README.md +++ b/README.md @@ -89,8 +89,8 @@ If the model that you would like to use is missing, feel free to open the issue, - [x] DeepLabV3 - MobileNetV3 - [x] Instance segmentation - [x] Mask R-CNN -- [ ] Keypoint Detection - - [ ] Keypoint R-CNN +- [x] Keypoint Detection + - [x] Keypoint R-CNN ## Copyright and License diff --git a/examples/1-basic-tutorial.livemd b/examples/1-basic-tutorial.livemd index 82a56bc..62836f9 100644 --- a/examples/1-basic-tutorial.livemd +++ b/examples/1-basic-tutorial.livemd @@ -36,11 +36,13 @@ alias ExVision.Classification.MobileNetV3Small, as: Classifier alias ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN, as: ObjectDetector alias ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3, as: SemanticSegmentation alias ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2, as: InstanceSegmentation +alias ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN, as: KeypointDetector {:ok, classifier} = Classifier.load() {:ok, object_detector} = ObjectDetector.load() {:ok, semantic_segmentation} = SemanticSegmentation.load() {:ok, instance_segmentation} = InstanceSegmentation.load() +{:ok, keypoint_detector} = KeypointDetector.load() Kino.nothing() ``` @@ -61,9 +63,11 @@ Under the hood, all of these formats will be converted to Nx's Tensors and norma A big point of ExVision over using the models directly has to be documentation and intuitive outputs. 
Hence, models return the following types:
-* Classifier - a mapping the category into the probability: [`%{category_t() => number()}`](http://localhost:55556/ExVision.Classification.MobileNetV3.html#t:output_t/0)
-* Detector - a list of bounding boxes: [`list(BBox.t())`](http://localhost:55556/ExVision.Detection.Ssdlite320_MobileNetv3.BBox.html)
-* Segmentation - a mapping of category to boolean tensor determining if the pixel is part of the mask for the given class: [`%{category_t() => Nx.Tensor.t()}`](http://localhost:55556/ExVision.Segmentation.DeepLabV3_MobileNetV3.html#t:output_t/0)
+* Classifier - a mapping from category to probability: [`%{category_t() => number()}`](http://localhost:55556/ExVision.Classification.MobileNetV3Small.html#t:output_t/0)
+* Object Detector - a list of bounding boxes: [`list(BBox.t())`](http://localhost:55556/ExVision.ObjectDetection.Ssdlite320_MobileNetv3.BBox.html)
+* Semantic Segmentation - a mapping from category to a boolean tensor indicating whether each pixel belongs to the mask for the given class: [`%{category_t() => Nx.Tensor.t()}`](http://localhost:55556/ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3.html#t:output_t/0)
+* Instance Segmentation - a list of bounding boxes with masks: [`list(BBoxWithMask.t())`](http://localhost:55556/ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2.html#t:output_t/0)
+* Keypoint Detector - a list of bounding boxes with keypoints: [`list(BBoxWithKeypoints.t())`](http://localhost:55556/ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN.html#t:output_t/0)
@@ -72,7 +76,29 @@ A big point of ExVision over using the models directly has to be documentation a
 Let's put it into practice and run some predictions on a sample image of the cat. This code is intentionally using some calls to `dbg/1` macro in order to aid with the understanding of these formats.
-However, let's start with loading our test suspect. In the next cell, you can provide your own image that will be used as an example in this notebook. If you don't have anything handy, we're also providing a default image of a cat.
+
+
+Let's start with loading our test suspect. For this purpose, we have defined a helper function that automatically loads a default image if you don't provide one.
+
+```elixir
+defmodule ImageHandler do
+  def get(input, default_image) do
+    img_path =
+      case Kino.Input.read(input) do
+        nil ->
+          {:ok, file} = ExVision.Cache.lazy_get(ExVision.Cache, default_image)
+          file
+
+        %{file_ref: image} ->
+          Kino.Input.file_path(image)
+      end
+
+    Image.open!(img_path)
+  end
+end
+```
+
+In the next cell, you can provide your own image that will be used as an example in this notebook. If you don't have anything handy, we're also providing a default image of a cat.
@@ -83,17 +109,7 @@ input = Kino.Input.image("Image to evaluate", format: :jpeg)
 ```elixir
-img_path =
-  case Kino.Input.read(input) do
-    nil ->
-      {:ok, file} = ExVision.Cache.lazy_get(ExVision.Cache, "cat.jpg")
-      file
-
-    %{file_ref: image} ->
-      Kino.Input.file_path(image)
-  end
-
-image = Image.open!(img_path)
+image = ImageHandler.get(input, "cat.jpg")
 ```

 ### Image classification
@@ -229,6 +245,8 @@ The objective of instance segmentation is to not only identify objects within an
 In ExVision, the output of instance segmentation models includes a bounding box with a label and a score (similar to object detection), and a binary mask for every instance detected in the image.
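+
+To make this output format concrete, here is a minimal sketch of how each returned struct could be inspected. It reuses the `instance_segmentation` model and `image` loaded earlier; the struct values shown in the comment are illustrative, not real model output:
+
+```elixir
+alias ExVision.Types.BBoxWithMask
+
+# Each detected instance is a BBoxWithMask struct, shaped roughly like:
+# %BBoxWithMask{x1: 120, y1: 50, x2: 380, y2: 410, label: :cat, score: 0.98, mask: mask_tensor}
+instance_segmentation
+|> InstanceSegmentation.run(image)
+|> Enum.each(fn %BBoxWithMask{label: label, score: score, mask: mask} ->
+  # Summing the binary mask counts the pixels covered by this instance
+  pixels = mask |> Nx.sum() |> Nx.to_number()
+  IO.puts("#{label}: score #{Float.round(score, 2)}, #{pixels} mask pixels")
+end)
+```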
+Extremely low probability detections (with scores lower than 0.1) will be discarded by ExVision, as they are just noise.
+
 ### Code example
 In the following example, we will pass an image through the instance segmentation model and examine the individual instance masks recognized by the model.
@@ -263,6 +281,102 @@ end)
 |> Kino.Layout.grid(columns: 2)
 ```
+## Keypoint detection
+
+In keypoint detection, we're trying to detect specific keypoints in the image. ExVision returns the output as a list of bounding boxes (similar to object detection) with named keypoints. Each keypoint consists of x and y coordinates and a score reflecting the model's confidence in that keypoint.
+
+ExVision will discard extremely low probability detections (with scores lower than 0.1), as they are just noise.
+
+
+
+The KeypointRCNN_ResNet50_FPN model is commonly used for detecting human body parts in images. To illustrate this, let's begin by importing an image that features people.
+
+```elixir
+image = ImageHandler.get(input, "people.jpg")
+```
+
+### Code example
+
+In this example, we will draw the keypoints for every detection with a high enough score, and additionally draw a bounding box around each detection.
+
+```elixir
+alias ExVision.Types.BBoxWithKeypoints
+
+# define skeleton pose
+connections = [
+  # face
+  {:nose, :left_eye},
+  {:nose, :right_eye},
+  {:left_eye, :right_eye},
+  {:left_eye, :left_ear},
+  {:right_eye, :right_ear},
+
+  # left arm
+  {:left_wrist, :left_elbow},
+  {:left_elbow, :left_shoulder},
+
+  # right arm
+  {:right_wrist, :right_elbow},
+  {:right_elbow, :right_shoulder},
+
+  # torso
+  {:left_shoulder, :right_shoulder},
+  {:left_shoulder, :left_hip},
+  {:right_shoulder, :right_hip},
+  {:left_hip, :right_hip},
+  {:left_shoulder, :left_ear},
+  {:right_shoulder, :right_ear},
+
+  # left leg
+  {:left_ankle, :left_knee},
+  {:left_knee, :left_hip},
+
+  # right leg
+  {:right_ankle, :right_knee},
+  {:right_knee, :right_hip}
+]
+
+# apply the model
+predictions =
+  image
+  |> then(&KeypointDetector.run(keypoint_detector, &1))
+  # Get most likely predictions from the output
+  |> Enum.filter(fn %BBoxWithKeypoints{score: score} -> score > 0.8 end)
+  |> dbg()
+
+predictions
+|> Enum.reduce(image, fn prediction, image_acc ->
+  # draw keypoints
+  image_acc =
+    prediction.keypoints
+    |> Enum.reduce(image_acc, fn {_key, %{x: x, y: y}}, acc ->
+      Image.Draw.circle!(acc, x, y, 2, color: :red)
+    end)
+
+  # draw skeleton pose
+  image_acc =
+    connections
+    |> Enum.reduce(image_acc, fn {from, to}, acc ->
+      %{x: x1, y: y1} = prediction.keypoints[from]
+      %{x: x2, y: y2} = prediction.keypoints[to]
+
+      Image.Draw.line!(acc, x1, y1, x2, y2, color: :red)
+    end)
+
+  # draw bounding box
+  Image.Draw.rect!(
+    image_acc,
+    prediction.x1,
+    prediction.y1,
+    BBoxWithKeypoints.width(prediction),
+    BBoxWithKeypoints.height(prediction),
+    fill: false,
+    color: :red,
+    stroke_width: 2
+  )
+end)
+```
+
 ## Next steps
 After completing this tutorial you can also check out our next tutorial focusing on using models in production in process workflow [here](2-usage-as-nx-serving.livemd)
diff --git a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
new file mode 100644
index 0000000..e8e7dd9
--- /dev/null
+++ b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
@@ -0,0 +1,111 @@
+defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
+  @moduledoc """
+  Keypoint R-CNN model with a ResNet-50-FPN backbone, exported from torchvision.
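+
+  ## Example
+
+  A minimal usage sketch (here `image` stands for any input accepted by ExVision,
+  e.g. a path to an image file; the result in the comment is only illustrative):
+
+      {:ok, model} = ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN.load()
+      ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN.run(model, image)
+      # => [%ExVision.Types.BBoxWithKeypoints{label: :person, score: 0.97,
+      #       keypoints: %{nose: %{x: 123, y: 45, score: 0.9}, ...}, ...}]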
+  """
+  use ExVision.Model.Definition.Ortex,
+    model: "keypointrcnn_resnet50_fpn_keypoint_detector.onnx",
+    categories: "priv/categories/no_person_or_person.json"
+
+  require Logger
+
+  alias ExVision.Types.BBoxWithKeypoints
+
+  @typep output_t() :: [BBoxWithKeypoints.t()]
+
+  @keypoints_names [
+    :nose,
+    :left_eye,
+    :right_eye,
+    :left_ear,
+    :right_ear,
+    :left_shoulder,
+    :right_shoulder,
+    :left_elbow,
+    :right_elbow,
+    :left_wrist,
+    :right_wrist,
+    :left_hip,
+    :right_hip,
+    :left_knee,
+    :right_knee,
+    :left_ankle,
+    :right_ankle
+  ]
+
+  @impl true
+  def load(options \\ []) do
+    if Keyword.has_key?(options, :batch_size) do
+      Logger.warning(
+        "`:batch_size` was given, but this model can only process batches of size 1. Overriding"
+      )
+    end
+
+    options
+    |> Keyword.put(:batch_size, 1)
+    |> default_model_load()
+  end
+
+  @impl true
+  def preprocessing(img, _metadata) do
+    ExVision.Utils.resize(img, {224, 224})
+  end
+
+  @impl true
+  def postprocessing(
+        %{
+          "boxes_unsqueezed" => bboxes,
+          "scores_unsqueezed" => scores,
+          "labels_unsqueezed" => labels,
+          "keypoints_unsqueezed" => keypoints_list,
+          "keypoints_scores_unsqueezed" => keypoints_scores_list
+        },
+        metadata
+      ) do
+    categories = categories()
+
+    # the model operates on a fixed 224x224 input, so scale boxes and keypoints
+    # back to the original image size
+    {h, w} = metadata.original_size
+    scale_x = w / 224
+    scale_y = h / 224
+
+    bboxes =
+      bboxes
+      |> Nx.squeeze(axes: [0])
+      |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
+      |> Nx.round()
+      |> Nx.as_type(:s64)
+      |> Nx.to_list()
+
+    scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
+    labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
+
+    keypoints_list =
+      keypoints_list
+      |> Nx.squeeze(axes: [0])
+      |> Nx.multiply(Nx.tensor([scale_x, scale_y, 1]))
+      |> Nx.round()
+      |> Nx.as_type(:s64)
+      |> Nx.to_list()
+
+    keypoints_scores_list = keypoints_scores_list |> Nx.squeeze(axes: [0]) |> Nx.to_list()
+
+    # discard extremely low score detections (noise) and build the result structs
+    [bboxes, scores, labels, keypoints_list, keypoints_scores_list]
+    |> Enum.zip()
+    |> Enum.filter(fn {_bbox, score, _label, _keypoints, _keypoints_scores} -> score > 0.1 end)
+    |> Enum.map(fn {[x1, y1, x2, y2], score, label, keypoints, keypoints_scores} ->
+      keypoints =
+        [keypoints, keypoints_scores]
+        |> Enum.zip()
+        |> Enum.map(fn {[x, y, _w], keypoint_score} -> %{x: x, y: y, score: keypoint_score} end)
+
+      %BBoxWithKeypoints{
+        x1: x1,
+        x2: x2,
+        y1: y1,
+        y2: y2,
+        score: score,
+        label: Enum.at(categories, label),
+        keypoints: [@keypoints_names, keypoints] |> Enum.zip() |> Map.new()
+      }
+    end)
+  end
+end
diff --git a/lib/ex_vision/types/bbox.ex b/lib/ex_vision/types/bbox.ex
index 2406305..4cc6a45 100644
--- a/lib/ex_vision/types/bbox.ex
+++ b/lib/ex_vision/types/bbox.ex
@@ -17,7 +17,7 @@ defmodule ExVision.Types.BBox do
   - `x2` - x componenet of the lower right
   - `y2` - y componenet of the lower right
   - `score` - confidence of the predition
-  - `label` - label assigned to this bounding box.
+  - `label` - label assigned to this bounding box
   """
   @type t(label_t) :: %__MODULE__{
           x1: number(),
diff --git a/lib/ex_vision/types/bboxwithkeypoints.ex b/lib/ex_vision/types/bboxwithkeypoints.ex
new file mode 100644
index 0000000..89c5148
--- /dev/null
+++ b/lib/ex_vision/types/bboxwithkeypoints.ex
@@ -0,0 +1,86 @@
+defmodule ExVision.Types.BBoxWithKeypoints do
+  @moduledoc """
+  A struct describing the bounding box with keypoints returned by the keypoint detection model.
+  """
+
+  @enforce_keys [
+    :x1,
+    :y1,
+    :x2,
+    :y2,
+    :label,
+    :score,
+    :keypoints
+  ]
+  defstruct @enforce_keys
+
+  @typedoc """
+  A type describing the Bounding Box object.
+
+  A bounding box is a rectangle encompassing the region.
+  When used in object detectors, this box will describe the location of the object in the image.
+  It also includes keypoints. Each keypoint has a predefined atom as its name.
+
+  - `x1` - x component of the upper left corner
+  - `y1` - y component of the upper left corner
+  - `x2` - x component of the lower right
+  - `y2` - y component of the lower right
+  - `label` - label assigned to this bounding box
+  - `score` - confidence of the prediction
+  - `keypoints` - a map where keys are predefined names (represented as atoms) denoting the specific keypoints (body parts). The value associated with each key is another map, which contains the following:
+    - `:x`: The x-coordinate of the keypoint
+    - `:y`: The y-coordinate of the keypoint
+    - `:score`: The confidence score of the predicted keypoint
+
+  Keypoint atom names include:
+  - `:nose`
+  - `:left_eye`
+  - `:right_eye`
+  - `:left_ear`
+  - `:right_ear`
+  - `:left_shoulder`
+  - `:right_shoulder`
+  - `:left_elbow`
+  - `:right_elbow`
+  - `:left_wrist`
+  - `:right_wrist`
+  - `:left_hip`
+  - `:right_hip`
+  - `:left_knee`
+  - `:right_knee`
+  - `:left_ankle`
+  - `:right_ankle`
+  """
+  @type t(label_t) :: %__MODULE__{
+          x1: number(),
+          y1: number(),
+          y2: number(),
+          x2: number(),
+          label: label_t,
+          score: number(),
+          keypoints: %{
+            atom() => %{
+              x: number(),
+              y: number(),
+              score: number()
+            }
+          }
+        }
+
+  @typedoc """
+  Exactly like `t:t/1`, but doesn't put any constraints on the `label` field.
+  """
+  @type t() :: t(term())
+
+  @doc """
+  Return the width of the bounding box
+  """
+  @spec width(t()) :: number()
+  def width(%__MODULE__{x1: x1, x2: x2}), do: abs(x2 - x1)
+
+  @doc """
+  Return the height of the bounding box
+  """
+  @spec height(t()) :: number()
+  def height(%__MODULE__{y1: y1, y2: y2}), do: abs(y2 - y1)
+end
diff --git a/lib/ex_vision/types/bboxwithmask.ex b/lib/ex_vision/types/bboxwithmask.ex
index a2e2f4c..3dd5303 100644
--- a/lib/ex_vision/types/bboxwithmask.ex
+++ b/lib/ex_vision/types/bboxwithmask.ex
@@ -26,7 +26,7 @@ defmodule ExVision.Types.BBoxWithMask do
   - `x2` - x componenet of the lower right
   - `y2` - y componenet of the lower right
   - `score` - confidence of the predition
-  - `label` - label assigned to this bounding box.
+ - `label` - label assigned to this bounding box - `mask` - binary mask """ @type t(label_t) :: %__MODULE__{ diff --git a/mix.exs b/mix.exs index d5895ed..3613c90 100644 --- a/mix.exs +++ b/mix.exs @@ -101,7 +101,8 @@ defmodule ExVision.Mixfile do ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3, ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2, ExVision.ObjectDetection.Ssdlite320_MobileNetv3, - ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN + ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN, + ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN ], Types: [ ExVision.Types, @@ -121,7 +122,8 @@ defmodule ExVision.Mixfile do ExVision.Classification, ExVision.SemanticSegmentation, ExVision.InstanceSegmentation, - ExVision.ObjectDetection + ExVision.ObjectDetection, + ExVision.KeypointDetection ], formatters: ["html"], source_ref: "v#{@version}" diff --git a/priv/categories/no_person_or_person.json b/priv/categories/no_person_or_person.json new file mode 100644 index 0000000..0c9fa27 --- /dev/null +++ b/priv/categories/no_person_or_person.json @@ -0,0 +1 @@ +["no person", "person"] \ No newline at end of file diff --git a/python/exports/keypoint_detection.py b/python/exports/keypoint_detection.py new file mode 100644 index 0000000..50dd27e --- /dev/null +++ b/python/exports/keypoint_detection.py @@ -0,0 +1,112 @@ +import argparse +from torchvision.transforms.functional import to_tensor, resize +import torch +import json +from pathlib import Path +import onnx +from onnx import helper, TensorProto +from PIL import Image + + +def export(model_builder, Model_Weights): + base_dir = Path(f"models/keypoint_detection/{model_builder.__name__}") + base_dir.mkdir(parents=True, exist_ok=True) + + model_file = base_dir / "model.onnx" + categories_file = base_dir / "categories.json" + + weights = Model_Weights.DEFAULT + model = model_builder(weights=weights) + model.eval() + + categories = weights.meta["categories"] + transforms = weights.transforms() + + with open(categories_file, "w") as f: + json.dump(categories, f) + + onnx_input = to_tensor(Image.open("test/assets/cat.jpg")).unsqueeze(0) + onnx_input = resize(onnx_input, [224, 224]) + onnx_input = transforms(onnx_input) + + torch.onnx.export( + model, + onnx_input, + str(model_file), + verbose=False, + input_names=["input"], + output_names=["boxes", "labels", "scores", + "keypoints", "keypoints_scores"], + dynamic_axes={ + "boxes": {0: "detections"}, + "labels": {0: "detections"}, + "scores": {0: "detections"}, + "keypoints": {0: "detections"}, + "keypoints_scores": {0: "detections"} + }, + export_params=True, + ) + + output_names = ["boxes", "labels", "scores", + "keypoints", "keypoints_scores"] + + model = onnx.load(str(model_file)) + + nodes = [] + for output_name in output_names: + axes_init = helper.make_tensor( + name=output_name+"_axes", + data_type=TensorProto.INT64, + dims=[1], + vals=[0] + ) + model.graph.initializer.append(axes_init) + + node = helper.make_node( + op_type="Unsqueeze", + inputs=[output_name, output_name+"_axes"], + outputs=[output_name+"_unsqueezed"] + ) + nodes.append(node) + + model.graph.node.extend(nodes) + + new_outputs = [] + for output_name in output_names: + match output_name: + case "boxes": + shape = [1, None, 4] + case "keypoints": + shape = [1, None, 17, 3] + case "keypoints_scores": + shape = [1, None, 17] + case _: + shape = [1, None] + + new_output = helper.make_tensor_value_info( + name=output_name+"_unsqueezed", + elem_type=TensorProto.INT64 if output_name == "labels" else TensorProto.FLOAT, + 
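+            # "labels" holds int64 class indices; every other output stays float32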
shape=shape + ) + new_outputs.append(new_output) + + model.graph.output.extend(new_outputs) + + for output_name in output_names: + old_output = next( + i for i in model.graph.output if i.name == output_name) + model.graph.output.remove(old_output) + + onnx.save(model, str(model_file)) + + +parser = argparse.ArgumentParser() +parser.add_argument("model") +args = parser.parse_args() + +match(args.model): + case "keypointrcnn_resnet50_fpn": + from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights + export(keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights) + case _: + print("Model not found") diff --git a/test/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn_test.exs b/test/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn_test.exs new file mode 100644 index 0000000..f864d78 --- /dev/null +++ b/test/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn_test.exs @@ -0,0 +1,33 @@ +defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPNTest do + use ExVision.Model.Case, module: ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN + use ExVision.TestUtils + alias ExVision.Types.BBoxWithKeypoints + + @impl true + def test_inference_result(result) do + assert [ + %BBoxWithKeypoints{ + x1: 113, + y1: 15, + label: :person, + score: score1, + keypoints: keypoints + }, + %BBoxWithKeypoints{ + x1: 141, + y1: 167, + label: :person, + score: score2 + } + ] = result + + assert_floats_equal(score1, 0.46) + assert_floats_equal(score2, 0.29) + + assert max_keypoint_score(keypoints) < 5 + end + + defp max_keypoint_score(keypoints) do + keypoints |> Enum.map(fn {_name, %{score: score}} -> score end) |> Enum.max() + end +end diff --git a/test/support/exvision/model/case.ex b/test/support/exvision/model/case.ex index dbeb13e..7fe58c2 100644 --- a/test/support/exvision/model/case.ex +++ b/test/support/exvision/model/case.ex @@ -79,7 +79,7 @@ defmodule ExVision.Model.Case do assert unquote(opts[:module]).batched_run( __MODULE__.TestProcess1, - Nx.iota({3, 124, 124}, type: :u8) + unquote(@img_path) ) end end