Skip to content

Commit

Permalink
Merge pull request #8 from msluszniak/ms-cleanup
Browse files Browse the repository at this point in the history
Replace softmax with stable version + small corrections
  • Loading branch information
msluszniak authored Jul 4, 2024
2 parents 964760e + 0e1df5e commit 46b111b
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 58 deletions.
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_l.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_L do
image
|> ExVision.Utils.resize({480, 480})
|> NxImage.normalize(
Nx.tensor([0.5, 0.5, 0.5]),
Nx.tensor([0.5, 0.5, 0.5]),
Nx.f32([0.5, 0.5, 0.5]),
Nx.f32([0.5, 0.5, 0.5]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_m.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_M do
image
|> ExVision.Utils.resize({480, 480})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_s.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_S do
image
|> ExVision.Utils.resize({384, 384})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 1 addition & 3 deletions lib/ex_vision/classification/generic_classifier.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ defmodule ExVision.Classification.GenericClassifier do
# Contains a default implementation of post processing for TorchVision classifiers
# To use: `use ExVision.Classification.GenericClassifier`

alias ExVision.Utils

alias ExVision.Types.ImageMetadata

@typep output_t() :: %{atom() => number()}
Expand All @@ -15,7 +13,7 @@ defmodule ExVision.Classification.GenericClassifier do
scores
|> Nx.backend_transfer()
|> Nx.flatten()
|> Utils.softmax()
|> Axon.Activations.softmax(axis: [0])
|> Nx.to_flat_list()
|> then(&Enum.zip(categories, &1))
|> Map.new()
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/mobilenet_v3_small.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.MobileNetV3Small do
image
|> ExVision.Utils.resize({224, 224})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/squeezenet1_1.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.SqueezeNet1_1 do
image
|> ExVision.Utils.resize({224, 224})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
14 changes: 5 additions & 9 deletions lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
model: "maskrcnn_resnet50_fpn_v2_instance_segmentation.onnx",
categories: "priv/categories/coco_categories.json"

import ExVision.Utils

require Logger

alias ExVision.Types.BBoxWithMask
Expand Down Expand Up @@ -46,16 +48,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
scores = squeeze_and_listify(scores)
labels = squeeze_and_listify(labels)

masks =
masks
Expand Down
30 changes: 10 additions & 20 deletions lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
model: "keypointrcnn_resnet50_fpn_keypoint_detector.onnx",
categories: "priv/categories/no_person_or_person.json"

import ExVision.Utils

require Logger

alias ExVision.Types.BBoxWithKeypoints
Expand Down Expand Up @@ -67,26 +69,14 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()

keypoints_list =
keypoints_list
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, 1]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

keypoints_scores_list = keypoints_scores_list |> Nx.squeeze(axes: [0]) |> Nx.to_list()
bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))

scores = squeeze_and_listify(scores)
labels = squeeze_and_listify(labels)

keypoints_list = scale_and_listify_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1]))

keypoints_scores_list = squeeze_and_listify(keypoints_scores_list)

[bboxes, scores, labels, keypoints_list, keypoints_scores_list]
|> Enum.zip()
Expand Down
14 changes: 5 additions & 9 deletions lib/ex_vision/object_detection/generic_detector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ defmodule ExVision.ObjectDetection.GenericDetector do
# Contains a default implementation of pre and post processing for TorchVision detectors
# To use: `use ExVision.ObjectDetection.GenericDetector`

import ExVision.Utils

require Logger

alias ExVision.Types.{BBox, ImageMetadata}
Expand All @@ -29,16 +31,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
bboxes = scale_and_listify_bbox(bboxes, Nx.f32([scale_x, scale_y, scale_x, scale_y]))

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
scores = squeeze_and_listify(scores)
labels = squeeze_and_listify(labels)

[bboxes, scores, labels]
|> Enum.zip()
Expand Down
23 changes: 16 additions & 7 deletions lib/ex_vision/utils.ex
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
defmodule ExVision.Utils do
@moduledoc false

import Nx.Defn
require Nx
require Image
alias ExVision.Types
Expand Down Expand Up @@ -86,8 +85,7 @@ defmodule ExVision.Utils do

defp ensure_grad_3(tensor) do
tensor
|> Nx.shape()
|> tuple_size()
|> Nx.rank()
|> case do
3 -> [tensor]
4 -> tensor |> Nx.to_batched(1) |> Stream.map(&Nx.squeeze(&1, axes: [0])) |> Enum.to_list()
Expand Down Expand Up @@ -149,10 +147,6 @@ defmodule ExVision.Utils do
Enum.map(outputs, fn {name, _type, _shape} -> name end)
end

defn softmax(x) do
Nx.divide(Nx.exp(x), Nx.sum(Nx.exp(x)))
end

@spec batched_run(atom(), ExVision.Model.input_t()) :: ExVision.Model.output_t()
def batched_run(process_name, input) when is_list(input) do
Nx.Serving.batched_run(process_name, input)
Expand All @@ -161,4 +155,19 @@ defmodule ExVision.Utils do
def batched_run(process_name, input) do
process_name |> batched_run([input]) |> hd()
end

@spec scale_and_listify_bbox(Nx.Tensor.t(), Nx.Tensor.t()) :: [integer()]
def scale_and_listify_bbox(bbox, scales) do
bbox
|> Nx.squeeze(axes: [0])
|> Nx.multiply(scales)
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
end

@spec squeeze_and_listify(Nx.Tensor.t()) :: [number()]
def squeeze_and_listify(batched_value) do
batched_value |> Nx.squeeze(axes: [0]) |> Nx.to_list()
end
end

0 comments on commit 46b111b

Please sign in to comment.