From 81346e3938a350cec781f48e7edc00c7e18d8a8b Mon Sep 17 00:00:00 2001 From: msluszniak Date: Wed, 3 Jul 2024 13:28:51 +0000 Subject: [PATCH 1/9] Clean up code --- lib/ex_vision/classification/efficientnet_v2_l.ex | 4 ++-- lib/ex_vision/classification/efficientnet_v2_m.ex | 4 ++-- lib/ex_vision/classification/efficientnet_v2_s.ex | 4 ++-- lib/ex_vision/classification/generic_classifier.ex | 2 +- lib/ex_vision/classification/mobilenet_v3_small.ex | 4 ++-- lib/ex_vision/classification/squeezenet1_1.ex | 4 ++-- lib/ex_vision/utils.ex | 8 ++------ 7 files changed, 13 insertions(+), 17 deletions(-) diff --git a/lib/ex_vision/classification/efficientnet_v2_l.ex b/lib/ex_vision/classification/efficientnet_v2_l.ex index a4e86db..1990530 100644 --- a/lib/ex_vision/classification/efficientnet_v2_l.ex +++ b/lib/ex_vision/classification/efficientnet_v2_l.ex @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_L do image |> ExVision.Utils.resize({480, 480}) |> NxImage.normalize( - Nx.tensor([0.5, 0.5, 0.5]), - Nx.tensor([0.5, 0.5, 0.5]), + Nx.f32([0.5, 0.5, 0.5]), + Nx.f32([0.5, 0.5, 0.5]), channels: :first ) end diff --git a/lib/ex_vision/classification/efficientnet_v2_m.ex b/lib/ex_vision/classification/efficientnet_v2_m.ex index 83ccafd..eb9ec43 100644 --- a/lib/ex_vision/classification/efficientnet_v2_m.ex +++ b/lib/ex_vision/classification/efficientnet_v2_m.ex @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_M do image |> ExVision.Utils.resize({480, 480}) |> NxImage.normalize( - Nx.tensor([0.485, 0.456, 0.406]), - Nx.tensor([0.229, 0.224, 0.225]), + Nx.f32([0.485, 0.456, 0.406]), + Nx.f32([0.229, 0.224, 0.225]), channels: :first ) end diff --git a/lib/ex_vision/classification/efficientnet_v2_s.ex b/lib/ex_vision/classification/efficientnet_v2_s.ex index 277f9ce..d11da95 100644 --- a/lib/ex_vision/classification/efficientnet_v2_s.ex +++ b/lib/ex_vision/classification/efficientnet_v2_s.ex @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_S do image |> ExVision.Utils.resize({384, 384}) |> NxImage.normalize( - Nx.tensor([0.485, 0.456, 0.406]), - Nx.tensor([0.229, 0.224, 0.225]), + Nx.f32([0.485, 0.456, 0.406]), + Nx.f32([0.229, 0.224, 0.225]), channels: :first ) end diff --git a/lib/ex_vision/classification/generic_classifier.ex b/lib/ex_vision/classification/generic_classifier.ex index 5e7b9d1..d59b2f5 100644 --- a/lib/ex_vision/classification/generic_classifier.ex +++ b/lib/ex_vision/classification/generic_classifier.ex @@ -15,7 +15,7 @@ defmodule ExVision.Classification.GenericClassifier do scores |> Nx.backend_transfer() |> Nx.flatten() - |> Utils.softmax() + |> Axon.Activations.softmax(axis: [0]) |> Nx.to_flat_list() |> then(&Enum.zip(categories, &1)) |> Map.new() diff --git a/lib/ex_vision/classification/mobilenet_v3_small.ex b/lib/ex_vision/classification/mobilenet_v3_small.ex index 2c4d207..e65d0bd 100644 --- a/lib/ex_vision/classification/mobilenet_v3_small.ex +++ b/lib/ex_vision/classification/mobilenet_v3_small.ex @@ -15,8 +15,8 @@ defmodule ExVision.Classification.MobileNetV3Small do image |> ExVision.Utils.resize({224, 224}) |> NxImage.normalize( - Nx.tensor([0.485, 0.456, 0.406]), - Nx.tensor([0.229, 0.224, 0.225]), + Nx.f32([0.485, 0.456, 0.406]), + Nx.f32([0.229, 0.224, 0.225]), channels: :first ) end diff --git a/lib/ex_vision/classification/squeezenet1_1.ex b/lib/ex_vision/classification/squeezenet1_1.ex index 6865847..3711491 100644 --- a/lib/ex_vision/classification/squeezenet1_1.ex +++ b/lib/ex_vision/classification/squeezenet1_1.ex @@ -15,8 +15,8 @@ defmodule ExVision.Classification.SqueezeNet1_1 do image |> ExVision.Utils.resize({224, 224}) |> NxImage.normalize( - Nx.tensor([0.485, 0.456, 0.406]), - Nx.tensor([0.229, 0.224, 0.225]), + Nx.f32([0.485, 0.456, 0.406]), + Nx.f32([0.229, 0.224, 0.225]), channels: :first ) end diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index 15e3482..1003257 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -86,8 +86,8 @@ defmodule ExVision.Utils do defp ensure_grad_3(tensor) do tensor - |> Nx.shape() - |> tuple_size() + + Nx.rank() |> case do 3 -> [tensor] 4 -> tensor |> Nx.to_batched(1) |> Stream.map(&Nx.squeeze(&1, axes: [0])) |> Enum.to_list() @@ -149,10 +149,6 @@ defmodule ExVision.Utils do Enum.map(outputs, fn {name, _type, _shape} -> name end) end - defn softmax(x) do - Nx.divide(Nx.exp(x), Nx.sum(Nx.exp(x))) - end - @spec batched_run(atom(), ExVision.Model.input_t()) :: ExVision.Model.output_t() def batched_run(process_name, input) when is_list(input) do Nx.Serving.batched_run(process_name, input) From cc898ade560816c39ad1f446a67cc0d5d40c46ad Mon Sep 17 00:00:00 2001 From: msluszniak Date: Wed, 3 Jul 2024 13:46:50 +0000 Subject: [PATCH 2/9] Add pipe symbol --- lib/ex_vision/utils.ex | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index 1003257..e8813c3 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -86,8 +86,7 @@ defmodule ExVision.Utils do defp ensure_grad_3(tensor) do tensor - - Nx.rank() + |> Nx.rank() |> case do 3 -> [tensor] 4 -> tensor |> Nx.to_batched(1) |> Stream.map(&Nx.squeeze(&1, axes: [0])) |> Enum.to_list() From 6999542b8a6ad27db1f62d5fbb1174d3df229c97 Mon Sep 17 00:00:00 2001 From: msluszniak Date: Wed, 3 Jul 2024 14:06:05 +0000 Subject: [PATCH 3/9] Remove unused alias --- lib/ex_vision/classification/generic_classifier.ex | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/ex_vision/classification/generic_classifier.ex b/lib/ex_vision/classification/generic_classifier.ex index d59b2f5..0d46f44 100644 --- a/lib/ex_vision/classification/generic_classifier.ex +++ b/lib/ex_vision/classification/generic_classifier.ex @@ -4,8 +4,6 @@ defmodule ExVision.Classification.GenericClassifier do # Contains a default implementation of post processing for TorchVision classifiers # To use: `use ExVision.Classification.GenericClassifier` - alias ExVision.Utils - alias ExVision.Types.ImageMetadata @typep output_t() :: %{atom() => number()} From b9a9db7223994058fcfce646d16aabd022a9d726 Mon Sep 17 00:00:00 2001 From: msluszniak Date: Wed, 3 Jul 2024 14:10:53 +0000 Subject: [PATCH 4/9] Remove unused import --- lib/ex_vision/utils.ex | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index e8813c3..bf8655a 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -1,7 +1,6 @@ defmodule ExVision.Utils do @moduledoc false - import Nx.Defn require Nx require Image alias ExVision.Types From c973b8b399db9766e38f9d5a5e674b3f5df1f2ed Mon Sep 17 00:00:00 2001 From: msluszniak Date: Thu, 4 Jul 2024 06:53:07 +0000 Subject: [PATCH 5/9] Move processing bounding boxes to utils --- .../maskrcnn_resnet50_fpn_v2.ex | 14 ++++----- .../keypointrcnn_resnet50_fpn.ex | 30 +++++++------------ .../object_detection/generic_detector.ex | 14 ++++----- lib/ex_vision/utils.ex | 13 ++++++++ 4 files changed, 33 insertions(+), 38 deletions(-) diff --git a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex index d3ebb33..17d460f 100644 --- a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex +++ b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex @@ -8,6 +8,8 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do require Logger + import ExVision.Utils + alias ExVision.Types.BBoxWithMask @type output_t() :: [BBoxWithMask.t()] @@ -46,16 +48,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do scale_x = w / 224 scale_y = h / 224 - bboxes = - bboxes - |> Nx.squeeze(axes: [0]) - |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y])) - |> Nx.round() - |> Nx.as_type(:s64) - |> Nx.to_list() + bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y])) - scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list() - labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list() + scores = unbatch(scores) + labels = unbatch(labels) masks = masks diff --git a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex index e8e7dd9..e687616 100644 --- a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex +++ b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex @@ -8,6 +8,8 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do require Logger + import ExVision.Utils + alias ExVision.Types.BBoxWithKeypoints @typep output_t() :: [BBoxWithKeypoints.t()] @@ -67,26 +69,14 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do scale_x = w / 224 scale_y = h / 224 - bboxes = - bboxes - |> Nx.squeeze(axes: [0]) - |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y])) - |> Nx.round() - |> Nx.as_type(:s64) - |> Nx.to_list() - - scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list() - labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list() - - keypoints_list = - keypoints_list - |> Nx.squeeze(axes: [0]) - |> Nx.multiply(Nx.tensor([scale_x, scale_y, 1])) - |> Nx.round() - |> Nx.as_type(:s64) - |> Nx.to_list() - - keypoints_scores_list = keypoints_scores_list |> Nx.squeeze(axes: [0]) |> Nx.to_list() + bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y])) + + scores = unbatch(scores) + labels = unbatch(labels) + + keypoints_list = process_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1])) + + keypoints_scores_list = unbatch(keypoints_scores_list) [bboxes, scores, labels, keypoints_list, keypoints_scores_list] |> Enum.zip() diff --git a/lib/ex_vision/object_detection/generic_detector.ex b/lib/ex_vision/object_detection/generic_detector.ex index 0214e9b..0c4969d 100644 --- a/lib/ex_vision/object_detection/generic_detector.ex +++ b/lib/ex_vision/object_detection/generic_detector.ex @@ -6,6 +6,8 @@ defmodule ExVision.ObjectDetection.GenericDetector do require Logger + import ExVision.Utils + alias ExVision.Types.{BBox, ImageMetadata} @typep output_t() :: [BBox.t()] @@ -29,16 +31,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do scale_x = w / 224 scale_y = h / 224 - bboxes = - bboxes - |> Nx.squeeze(axes: [0]) - |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y])) - |> Nx.round() - |> Nx.as_type(:s64) - |> Nx.to_list() + bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y])) - scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list() - labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list() + scores = unbatch(scores) + labels = unbatch(labels) [bboxes, scores, labels] |> Enum.zip() diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index bf8655a..d801618 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -155,4 +155,17 @@ defmodule ExVision.Utils do def batched_run(process_name, input) do process_name |> batched_run([input]) |> hd() end + + defp process_bbox(bbox, scales, axes \\ [0]) do + bbox + |> Nx.squeeze(axes: axes) + |> Nx.multiply(scales) + |> Nx.round() + |> Nx.as_type(:s64) + |> Nx.to_list() + end + + defp unbatch(batched_value, axes \\ [0]) do + batched_value |> Nx.squeeze(axes: axes) |> Nx.to_list() + end end From 294a8cd05368ca42407106418dab0bb1ea771a0c Mon Sep 17 00:00:00 2001 From: msluszniak Date: Thu, 4 Jul 2024 07:01:32 +0000 Subject: [PATCH 6/9] Make added functionalities public --- lib/ex_vision/utils.ex | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index d801618..0cf873e 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -156,7 +156,7 @@ defmodule ExVision.Utils do process_name |> batched_run([input]) |> hd() end - defp process_bbox(bbox, scales, axes \\ [0]) do + def process_bbox(bbox, scales, axes \\ [0]) do bbox |> Nx.squeeze(axes: axes) |> Nx.multiply(scales) @@ -165,7 +165,7 @@ defmodule ExVision.Utils do |> Nx.to_list() end - defp unbatch(batched_value, axes \\ [0]) do + def unbatch(batched_value, axes \\ [0]) do batched_value |> Nx.squeeze(axes: axes) |> Nx.to_list() end end From b90f6cf4d4ad9770c9de059b416e2c983709685b Mon Sep 17 00:00:00 2001 From: msluszniak Date: Thu, 4 Jul 2024 07:28:28 +0000 Subject: [PATCH 7/9] Move credo and remove axes as argument --- .../instance_segmentation/maskrcnn_resnet50_fpn_v2.ex | 4 ++-- .../keypoint_detection/keypointrcnn_resnet50_fpn.ex | 4 ++-- lib/ex_vision/object_detection/generic_detector.ex | 4 ++-- lib/ex_vision/utils.ex | 10 ++++++---- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex index 17d460f..4c421f5 100644 --- a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex +++ b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex @@ -6,10 +6,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do model: "maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx", categories: "priv/categories/coco_categories.json" - require Logger - import ExVision.Utils + require Logger + alias ExVision.Types.BBoxWithMask @type output_t() :: [BBoxWithMask.t()] diff --git a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex index e687616..ef1d071 100644 --- a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex +++ b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex @@ -6,10 +6,10 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do model: "keypointrcnn_resnet50_fpn_keypoint_detector.onnx", categories: "priv/categories/no_person_or_person.json" - require Logger - import ExVision.Utils + require Logger + alias ExVision.Types.BBoxWithKeypoints @typep output_t() :: [BBoxWithKeypoints.t()] diff --git a/lib/ex_vision/object_detection/generic_detector.ex b/lib/ex_vision/object_detection/generic_detector.ex index 0c4969d..6b6fb05 100644 --- a/lib/ex_vision/object_detection/generic_detector.ex +++ b/lib/ex_vision/object_detection/generic_detector.ex @@ -4,10 +4,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do # Contains a default implementation of pre and post processing for TorchVision detectors # To use: `use ExVision.ObjectDetection.GenericDetector` - require Logger - import ExVision.Utils + require Logger + alias ExVision.Types.{BBox, ImageMetadata} @typep output_t() :: [BBox.t()] diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index 0cf873e..71ad709 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -156,16 +156,18 @@ defmodule ExVision.Utils do process_name |> batched_run([input]) |> hd() end - def process_bbox(bbox, scales, axes \\ [0]) do + @spec process_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()] + def process_bbox(bbox, scales) do bbox - |> Nx.squeeze(axes: axes) + |> Nx.squeeze(axes: [0]) |> Nx.multiply(scales) |> Nx.round() |> Nx.as_type(:s64) |> Nx.to_list() end - def unbatch(batched_value, axes \\ [0]) do - batched_value |> Nx.squeeze(axes: axes) |> Nx.to_list() + @spec unbatch(Nx.Tensor.t()) :: [number()] + def unbatch(batched_value) do + batched_value |> Nx.squeeze(axes: [0]) |> Nx.to_list() end end From df9f8df1883253dd57b4fbf00ee1df586c92141d Mon Sep 17 00:00:00 2001 From: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com> Date: Thu, 4 Jul 2024 11:11:14 +0200 Subject: [PATCH 8/9] Apply suggestions from code review Co-authored-by: mkopcins <120639731+mkopcins@users.noreply.github.com> --- lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex | 2 +- lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex | 2 +- lib/ex_vision/object_detection/generic_detector.ex | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex index 4c421f5..95af79a 100644 --- a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex +++ b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex @@ -48,7 +48,7 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do scale_x = w / 224 scale_y = h / 224 - bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y])) + bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) scores = unbatch(scores) labels = unbatch(labels) diff --git a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex index ef1d071..79eaac5 100644 --- a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex +++ b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex @@ -69,7 +69,7 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do scale_x = w / 224 scale_y = h / 224 - bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y])) + bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) scores = unbatch(scores) labels = unbatch(labels) diff --git a/lib/ex_vision/object_detection/generic_detector.ex b/lib/ex_vision/object_detection/generic_detector.ex index 6b6fb05..a72b0d9 100644 --- a/lib/ex_vision/object_detection/generic_detector.ex +++ b/lib/ex_vision/object_detection/generic_detector.ex @@ -31,7 +31,7 @@ defmodule ExVision.ObjectDetection.GenericDetector do scale_x = w / 224 scale_y = h / 224 - bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y])) + bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) scores = unbatch(scores) labels = unbatch(labels) From 0e1df5eb2a2b0876bca3d74c1fe45d0cf7f7a1df Mon Sep 17 00:00:00 2001 From: msluszniak Date: Thu, 4 Jul 2024 09:24:22 +0000 Subject: [PATCH 9/9] Apply suggestions from code review --- .../instance_segmentation/maskrcnn_resnet50_fpn_v2.ex | 6 +++--- .../keypoint_detection/keypointrcnn_resnet50_fpn.ex | 10 +++++----- lib/ex_vision/object_detection/generic_detector.ex | 6 +++--- lib/ex_vision/utils.ex | 8 ++++---- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex index 95af79a..5197d43 100644 --- a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex +++ b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex @@ -48,10 +48,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do scale_x = w / 224 scale_y = h / 224 - bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) + bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) - scores = unbatch(scores) - labels = unbatch(labels) + scores = squeeze_and_listify(scores) + labels = squeeze_and_listify(labels) masks = masks diff --git a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex index 79eaac5..170728e 100644 --- a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex +++ b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex @@ -69,14 +69,14 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do scale_x = w / 224 scale_y = h / 224 - bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) + bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) - scores = unbatch(scores) - labels = unbatch(labels) + scores = squeeze_and_listify(scores) + labels = squeeze_and_listify(labels) - keypoints_list = process_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1])) + keypoints_list = scale_and_listify_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1])) - keypoints_scores_list = unbatch(keypoints_scores_list) + keypoints_scores_list = squeeze_and_listify(keypoints_scores_list) [bboxes, scores, labels, keypoints_list, keypoints_scores_list] |> Enum.zip() diff --git a/lib/ex_vision/object_detection/generic_detector.ex b/lib/ex_vision/object_detection/generic_detector.ex index a72b0d9..26e5614 100644 --- a/lib/ex_vision/object_detection/generic_detector.ex +++ b/lib/ex_vision/object_detection/generic_detector.ex @@ -31,10 +31,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do scale_x = w / 224 scale_y = h / 224 - bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) + bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) - scores = unbatch(scores) - labels = unbatch(labels) + scores = squeeze_and_listify(scores) + labels = squeeze_and_listify(labels) [bboxes, scores, labels] |> Enum.zip() diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index 71ad709..f3cd541 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -156,8 +156,8 @@ defmodule ExVision.Utils do process_name |> batched_run([input]) |> hd() end - @spec process_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()] - def process_bbox(bbox, scales) do + @spec scale_and_listify_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()] + def scale_and_listify_bbox(bbox, scales) do bbox |> Nx.squeeze(axes: [0]) |> Nx.multiply(scales) @@ -166,8 +166,8 @@ defmodule ExVision.Utils do |> Nx.to_list() end - @spec unbatch(Nx.Tensor.t()) :: [number()] - def unbatch(batched_value) do + @spec squeeze_and_listify(Nx.Tensor.t()) :: [number()] + def squeeze_and_listify(batched_value) do batched_value |> Nx.squeeze(axes: [0]) |> Nx.to_list() end end