diff --git a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex index 95af79a..5197d43 100644 --- a/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex +++ b/lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex @@ -48,10 +48,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do scale_x = w / 224 scale_y = h / 224 - bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) + bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) - scores = unbatch(scores) - labels = unbatch(labels) + scores = squeeze_and_listify(scores) + labels = squeeze_and_listify(labels) masks = masks diff --git a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex index 79eaac5..170728e 100644 --- a/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex +++ b/lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex @@ -69,14 +69,14 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do scale_x = w / 224 scale_y = h / 224 - bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) + bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) - scores = unbatch(scores) - labels = unbatch(labels) + scores = squeeze_and_listify(scores) + labels = squeeze_and_listify(labels) - keypoints_list = process_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1])) + keypoints_list = scale_and_listify_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1])) - keypoints_scores_list = unbatch(keypoints_scores_list) + keypoints_scores_list = squeeze_and_listify(keypoints_scores_list) [bboxes, scores, labels, keypoints_list, keypoints_scores_list] |> Enum.zip() diff --git a/lib/ex_vision/object_detection/generic_detector.ex b/lib/ex_vision/object_detection/generic_detector.ex index a72b0d9..26e5614 100644 --- a/lib/ex_vision/object_detection/generic_detector.ex +++ b/lib/ex_vision/object_detection/generic_detector.ex @@ -31,10 +31,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do scale_x = w / 224 scale_y = h / 224 - bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) + bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y])) - scores = unbatch(scores) - labels = unbatch(labels) + scores = squeeze_and_listify(scores) + labels = squeeze_and_listify(labels) [bboxes, scores, labels] |> Enum.zip() diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index 71ad709..f3cd541 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -156,8 +156,8 @@ defmodule ExVision.Utils do process_name |> batched_run([input]) |> hd() end - @spec process_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()] - def process_bbox(bbox, scales) do + @spec scale_and_listify_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()] + def scale_and_listify_bbox(bbox, scales) do bbox |> Nx.squeeze(axes: [0]) |> Nx.multiply(scales) @@ -166,8 +166,8 @@ defmodule ExVision.Utils do |> Nx.to_list() end - @spec unbatch(Nx.Tensor.t()) :: [number()] - def unbatch(batched_value) do + @spec squeeze_and_listify(Nx.Tensor.t()) :: [number()] + def squeeze_and_listify(batched_value) do batched_value |> Nx.squeeze(axes: [0]) |> Nx.to_list() end end