Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak committed Jul 4, 2024
1 parent df9f8df commit 0e1df5e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
scale_x = w / 224
scale_y = h / 224

bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))
bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))

scores = unbatch(scores)
labels = unbatch(labels)
scores = squeeze_and_listify(scores)
labels = squeeze_and_listify(labels)

masks =
masks
Expand Down
10 changes: 5 additions & 5 deletions lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
scale_x = w / 224
scale_y = h / 224

bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))
bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))

scores = unbatch(scores)
labels = unbatch(labels)
scores = squeeze_and_listify(scores)
labels = squeeze_and_listify(labels)

keypoints_list = process_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1]))
keypoints_list = scale_and_listify_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1]))

keypoints_scores_list = unbatch(keypoints_scores_list)
keypoints_scores_list = squeeze_and_listify(keypoints_scores_list)

[bboxes, scores, labels, keypoints_list, keypoints_scores_list]
|> Enum.zip()
Expand Down
6 changes: 3 additions & 3 deletions lib/ex_vision/object_detection/generic_detector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do
scale_x = w / 224
scale_y = h / 224

bboxes = process_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))
bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))

scores = unbatch(scores)
labels = unbatch(labels)
scores = squeeze_and_listify(scores)
labels = squeeze_and_listify(labels)

[bboxes, scores, labels]
|> Enum.zip()
Expand Down
8 changes: 4 additions & 4 deletions lib/ex_vision/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ defmodule ExVision.Utils do
process_name |> batched_run([input]) |> hd()
end

@spec process_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()]
def process_bbox(bbox, scales) do
@spec scale_and_listify_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()]
def scale_and_listify_bbox(bbox, scales) do
bbox
|> Nx.squeeze(axes: [0])
|> Nx.multiply(scales)
Expand All @@ -166,8 +166,8 @@ defmodule ExVision.Utils do
|> Nx.to_list()
end

@spec unbatch(Nx.Tensor.t()) :: [number()]
def unbatch(batched_value) do
@spec squeeze_and_listify(Nx.Tensor.t()) :: [number()]
def squeeze_and_listify(batched_value) do
batched_value |> Nx.squeeze(axes: [0]) |> Nx.to_list()
end
end

0 comments on commit 0e1df5e

Please sign in to comment.