diff --git a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex index 17d460f..4c421f5 100644 --- a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex +++ b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex @@ -6,10 +6,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do model: "maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx", categories: "priv/categories/coco_categories.json" - require Logger - import ExVision.Utils + require Logger + alias ExVision.Types.BBoxWithMask @type output_t() :: [BBoxWithMask.t()] diff --git a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex index e687616..ef1d071 100644 --- a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex +++ b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex @@ -6,10 +6,10 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do model: "keypointrcnn_resnet50_fpn_keypoint_detector.onnx", categories: "priv/categories/no_person_or_person.json" - require Logger - import ExVision.Utils + require Logger + alias ExVision.Types.BBoxWithKeypoints @typep output_t() :: [BBoxWithKeypoints.t()] diff --git a/lib/ex_vision/object_detection/generic_detector.ex b/lib/ex_vision/object_detection/generic_detector.ex index 0c4969d..6b6fb05 100644 --- a/lib/ex_vision/object_detection/generic_detector.ex +++ b/lib/ex_vision/object_detection/generic_detector.ex @@ -4,10 +4,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do # Contains a default implementation of pre and post processing for TorchVision detectors # To use: `use ExVision.ObjectDetection.GenericDetector` - require Logger - import ExVision.Utils + require Logger + alias ExVision.Types.{BBox, ImageMetadata} @typep output_t() :: [BBox.t()] diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index 0cf873e..71ad709 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -156,16 +156,18 @@ defmodule ExVision.Utils do process_name |> batched_run([input]) |> hd() end - def process_bbox(bbox, scales, axes \\ [0]) do + @spec process_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()] + def process_bbox(bbox, scales) do bbox - |> Nx.squeeze(axes: axes) + |> Nx.squeeze(axes: [0]) |> Nx.multiply(scales) |> Nx.round() |> Nx.as_type(:s64) |> Nx.to_list() end - def unbatch(batched_value, axes \\ [0]) do - batched_value |> Nx.squeeze(axes: axes) |> Nx.to_list() + @spec unbatch(Nx.Tensor.t()) :: [number()] + def unbatch(batched_value) do + batched_value |> Nx.squeeze(axes: [0]) |> Nx.to_list() end end