diff --git a/lib/ex_vision/classification/efficientnet_v2_l.ex b/lib/ex_vision/classification/efficientnet_v2_l.ex
index a4e86db..1990530 100644
--- a/lib/ex_vision/classification/efficientnet_v2_l.ex
+++ b/lib/ex_vision/classification/efficientnet_v2_l.ex
@@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_L do
     image
     |> ExVision.Utils.resize({480, 480})
     |> NxImage.normalize(
-      Nx.tensor([0.5, 0.5, 0.5]),
-      Nx.tensor([0.5, 0.5, 0.5]),
+      Nx.f32([0.5, 0.5, 0.5]),
+      Nx.f32([0.5, 0.5, 0.5]),
       channels: :first
     )
   end
diff --git a/lib/ex_vision/classification/efficientnet_v2_m.ex b/lib/ex_vision/classification/efficientnet_v2_m.ex
index 83ccafd..eb9ec43 100644
--- a/lib/ex_vision/classification/efficientnet_v2_m.ex
+++ b/lib/ex_vision/classification/efficientnet_v2_m.ex
@@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_M do
     image
     |> ExVision.Utils.resize({480, 480})
     |> NxImage.normalize(
-      Nx.tensor([0.485, 0.456, 0.406]),
-      Nx.tensor([0.229, 0.224, 0.225]),
+      Nx.f32([0.485, 0.456, 0.406]),
+      Nx.f32([0.229, 0.224, 0.225]),
       channels: :first
     )
   end
diff --git a/lib/ex_vision/classification/efficientnet_v2_s.ex b/lib/ex_vision/classification/efficientnet_v2_s.ex
index 277f9ce..d11da95 100644
--- a/lib/ex_vision/classification/efficientnet_v2_s.ex
+++ b/lib/ex_vision/classification/efficientnet_v2_s.ex
@@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_S do
     image
     |> ExVision.Utils.resize({384, 384})
     |> NxImage.normalize(
-      Nx.tensor([0.485, 0.456, 0.406]),
-      Nx.tensor([0.229, 0.224, 0.225]),
+      Nx.f32([0.485, 0.456, 0.406]),
+      Nx.f32([0.229, 0.224, 0.225]),
       channels: :first
     )
   end
diff --git a/lib/ex_vision/classification/generic_classifier.ex b/lib/ex_vision/classification/generic_classifier.ex
index 5e7b9d1..0d46f44 100644
--- a/lib/ex_vision/classification/generic_classifier.ex
+++ b/lib/ex_vision/classification/generic_classifier.ex
@@ -4,8 +4,6 @@ defmodule ExVision.Classification.GenericClassifier do
   # Contains a default implementation of post processing for TorchVision classifiers
   # To use: `use ExVision.Classification.GenericClassifier`
 
-  alias ExVision.Utils
-
   alias ExVision.Types.ImageMetadata
 
   @typep output_t() :: %{atom() => number()}
@@ -15,7 +13,7 @@ defmodule ExVision.Classification.GenericClassifier do
     scores
     |> Nx.backend_transfer()
     |> Nx.flatten()
-    |> Utils.softmax()
+    |> Axon.Activations.softmax(axis: [0])
     |> Nx.to_flat_list()
     |> then(&Enum.zip(categories, &1))
     |> Map.new()
diff --git a/lib/ex_vision/classification/mobilenet_v3_small.ex b/lib/ex_vision/classification/mobilenet_v3_small.ex
index 2c4d207..e65d0bd 100644
--- a/lib/ex_vision/classification/mobilenet_v3_small.ex
+++ b/lib/ex_vision/classification/mobilenet_v3_small.ex
@@ -15,8 +15,8 @@ defmodule ExVision.Classification.MobileNetV3Small do
     image
     |> ExVision.Utils.resize({224, 224})
     |> NxImage.normalize(
-      Nx.tensor([0.485, 0.456, 0.406]),
-      Nx.tensor([0.229, 0.224, 0.225]),
+      Nx.f32([0.485, 0.456, 0.406]),
+      Nx.f32([0.229, 0.224, 0.225]),
       channels: :first
     )
   end
diff --git a/lib/ex_vision/classification/squeezenet1_1.ex b/lib/ex_vision/classification/squeezenet1_1.ex
index 6865847..3711491 100644
--- a/lib/ex_vision/classification/squeezenet1_1.ex
+++ b/lib/ex_vision/classification/squeezenet1_1.ex
@@ -15,8 +15,8 @@ defmodule ExVision.Classification.SqueezeNet1_1 do
     image
     |> ExVision.Utils.resize({224, 224})
     |> NxImage.normalize(
-      Nx.tensor([0.485, 0.456, 0.406]),
-      Nx.tensor([0.229, 0.224, 0.225]),
+      Nx.f32([0.485, 0.456, 0.406]),
+      Nx.f32([0.229, 0.224, 0.225]),
       channels: :first
     )
   end
diff --git a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
index d3ebb33..5197d43 100644
--- a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
+++ b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
@@ -6,6 +6,8 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
     model: "maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx",
     categories: "priv/categories/coco_categories.json"
 
+  import ExVision.Utils
+
   require Logger
 
   alias ExVision.Types.BBoxWithMask
@@ -46,16 +48,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
     scale_x = w / 224
     scale_y = h / 224
 
-    bboxes =
-      bboxes
-      |> Nx.squeeze(axes: [0])
-      |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
-      |> Nx.round()
-      |> Nx.as_type(:s64)
-      |> Nx.to_list()
+    bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))
 
-    scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
-    labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
+    scores = squeeze_and_listify(scores)
+    labels = squeeze_and_listify(labels)
 
     masks =
       masks
diff --git a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
index e8e7dd9..170728e 100644
--- a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
+++ b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
@@ -6,6 +6,8 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
     model: "keypointrcnn_resnet50_fpn_keypoint_detector.onnx",
     categories: "priv/categories/no_person_or_person.json"
 
+  import ExVision.Utils
+
   require Logger
 
   alias ExVision.Types.BBoxWithKeypoints
@@ -67,26 +69,14 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
     scale_x = w / 224
     scale_y = h / 224
 
-    bboxes =
-      bboxes
-      |> Nx.squeeze(axes: [0])
-      |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
-      |> Nx.round()
-      |> Nx.as_type(:s64)
-      |> Nx.to_list()
-
-    scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
-    labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
-
-    keypoints_list =
-      keypoints_list
-      |> Nx.squeeze(axes: [0])
-      |> Nx.multiply(Nx.tensor([scale_x, scale_y, 1]))
-      |> Nx.round()
-      |> Nx.as_type(:s64)
-      |> Nx.to_list()
-
-    keypoints_scores_list = keypoints_scores_list |> Nx.squeeze(axes: [0]) |> Nx.to_list()
+    bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))
+
+    scores = squeeze_and_listify(scores)
+    labels = squeeze_and_listify(labels)
+
+    keypoints_list = scale_and_listify_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1]))
+
+    keypoints_scores_list = squeeze_and_listify(keypoints_scores_list)
 
     [bboxes, scores, labels, keypoints_list, keypoints_scores_list]
     |> Enum.zip()
diff --git a/lib/ex_vision/object_detection/generic_detector.ex b/lib/ex_vision/object_detection/generic_detector.ex
index 0214e9b..26e5614 100644
--- a/lib/ex_vision/object_detection/generic_detector.ex
+++ b/lib/ex_vision/object_detection/generic_detector.ex
@@ -4,6 +4,8 @@ defmodule ExVision.ObjectDetection.GenericDetector do
   # Contains a default implementation of pre and post processing for TorchVision detectors
   # To use: `use ExVision.ObjectDetection.GenericDetector`
 
+  import ExVision.Utils
+
   require Logger
 
   alias ExVision.Types.{BBox, ImageMetadata}
@@ -29,16 +31,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do
     scale_x = w / 224
     scale_y = h / 224
 
-    bboxes =
-      bboxes
-      |> Nx.squeeze(axes: [0])
-      |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
-      |> Nx.round()
-      |> Nx.as_type(:s64)
-      |> Nx.to_list()
+    bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))
 
-    scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
-    labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
+    scores = squeeze_and_listify(scores)
+    labels = squeeze_and_listify(labels)
 
     [bboxes, scores, labels]
     |> Enum.zip()
diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex
index 15e3482..f3cd541 100644
--- a/lib/ex_vision/utils.ex
+++ b/lib/ex_vision/utils.ex
@@ -1,7 +1,6 @@
 defmodule ExVision.Utils do
   @moduledoc false
 
-  import Nx.Defn
   require Nx
   require Image
   alias ExVision.Types
@@ -86,8 +85,7 @@ defmodule ExVision.Utils do
 
   defp ensure_grad_3(tensor) do
     tensor
-    |> Nx.shape()
-    |> tuple_size()
+    |> Nx.rank()
     |> case do
       3 -> [tensor]
       4 -> tensor |> Nx.to_batched(1) |> Stream.map(&Nx.squeeze(&1, axes: [0])) |> Enum.to_list()
@@ -149,10 +147,6 @@ defmodule ExVision.Utils do
     Enum.map(outputs, fn {name, _type, _shape} -> name end)
   end
 
-  defn softmax(x) do
-    Nx.divide(Nx.exp(x), Nx.sum(Nx.exp(x)))
-  end
-
   @spec batched_run(atom(), ExVision.Model.input_t()) :: ExVision.Model.output_t()
   def batched_run(process_name, input) when is_list(input) do
     Nx.Serving.batched_run(process_name, input)
@@ -161,4 +155,19 @@ defmodule ExVision.Utils do
   def batched_run(process_name, input) do
     process_name |> batched_run([input]) |> hd()
   end
+
+  @spec scale_and_listify_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()]
+  def scale_and_listify_bbox(bbox, scales) do
+    bbox
+    |> Nx.squeeze(axes: [0])
+    |> Nx.multiply(scales)
+    |> Nx.round()
+    |> Nx.as_type(:s64)
+    |> Nx.to_list()
+  end
+
+  @spec squeeze_and_listify(Nx.Tensor.t()) :: [number()]
+  def squeeze_and_listify(batched_value) do
+    batched_value |> Nx.squeeze(axes: [0]) |> Nx.to_list()
+  end
 end
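Usage sketch (not part of the patch): a minimal, illustrative exercise of the two helpers added to ExVision.Utils. The tensor values, the 672x448 image size, and the single-detection shapes are made up; they only assume the detector outputs carry a leading batch axis of size 1, which is what the modules above squeeze away.

# Hypothetical inputs shaped like TorchVision detector outputs.
bboxes = Nx.tensor([[[10.0, 20.0, 110.0, 220.0]]])  # {1, 1, 4}: batch, detection, [x1, y1, x2, y2]
scores = Nx.tensor([[0.98]])                        # {1, 1}: batch, detection

# Rescale boxes from the 224x224 model space back to a 672x448 (w x h) image
# and round to integer pixel coordinates; the batch axis is dropped.
ExVision.Utils.scale_and_listify_bbox(bboxes, Nx.f32([672 / 224, 448 / 224, 672 / 224, 448 / 224]))
# => [[30, 40, 330, 440]]

# Drop the batch axis and convert to a plain list.
ExVision.Utils.squeeze_and_listify(scores)
# => [0.98] (up to f32 rounding)

Inside the detector modules themselves the calls are unqualified, since the patch brings the helpers in with `import ExVision.Utils`.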