Skip to content

Commit

Permalink
Move credo and remove axes as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak committed Jul 4, 2024
1 parent 294a8cd commit b90f6cf
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
model: "maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx",
categories: "priv/categories/coco_categories.json"

require Logger

import ExVision.Utils

require Logger

alias ExVision.Types.BBoxWithMask

@type output_t() :: [BBoxWithMask.t()]
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
model: "keypointrcnn_resnet50_fpn_keypoint_detector.onnx",
categories: "priv/categories/no_person_or_person.json"

require Logger

import ExVision.Utils

require Logger

alias ExVision.Types.BBoxWithKeypoints

@typep output_t() :: [BBoxWithKeypoints.t()]
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/object_detection/generic_detector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do
# Contains a default implementation of pre and post processing for TorchVision detectors
# To use: `use ExVision.ObjectDetection.GenericDetector`

require Logger

import ExVision.Utils

require Logger

alias ExVision.Types.{BBox, ImageMetadata}

@typep output_t() :: [BBox.t()]
Expand Down
10 changes: 6 additions & 4 deletions lib/ex_vision/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,18 @@ defmodule ExVision.Utils do
process_name |> batched_run([input]) |> hd()
end

def process_bbox(bbox, scales, axes \\ [0]) do
@spec process_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()]
def process_bbox(bbox, scales) do
bbox
|> Nx.squeeze(axes: axes)
|> Nx.squeeze(axes: [0])
|> Nx.multiply(scales)
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
end

def unbatch(batched_value, axes \\ [0]) do
batched_value |> Nx.squeeze(axes: axes) |> Nx.to_list()
@spec unbatch(Nx.Tensor.t()) :: [number()]
def unbatch(batched_value) do
batched_value |> Nx.squeeze(axes: [0]) |> Nx.to_list()
end
end

0 comments on commit b90f6cf

Please sign in to comment.