Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace softmax with stable version + small corrections #8

Merged
merged 9 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_l.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_L do
image
|> ExVision.Utils.resize({480, 480})
|> NxImage.normalize(
Nx.tensor([0.5, 0.5, 0.5]),
Nx.tensor([0.5, 0.5, 0.5]),
Nx.f32([0.5, 0.5, 0.5]),
Nx.f32([0.5, 0.5, 0.5]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_m.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_M do
image
|> ExVision.Utils.resize({480, 480})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_s.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_S do
image
|> ExVision.Utils.resize({384, 384})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 1 addition & 3 deletions lib/ex_vision/classification/generic_classifier.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ defmodule ExVision.Classification.GenericClassifier do
# Contains a default implementation of post processing for TorchVision classifiers
# To use: `use ExVision.Classification.GenericClassifier`

alias ExVision.Utils

alias ExVision.Types.ImageMetadata

@typep output_t() :: %{atom() => number()}
Expand All @@ -15,7 +13,7 @@ defmodule ExVision.Classification.GenericClassifier do
scores
|> Nx.backend_transfer()
|> Nx.flatten()
|> Utils.softmax()
|> Axon.Activations.softmax(axis: [0])
|> Nx.to_flat_list()
|> then(&Enum.zip(categories, &1))
|> Map.new()
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/mobilenet_v3_small.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.MobileNetV3Small do
image
|> ExVision.Utils.resize({224, 224})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/squeezenet1_1.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.SqueezeNet1_1 do
image
|> ExVision.Utils.resize({224, 224})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
14 changes: 5 additions & 9 deletions lib/ex_vision/instance_segmentation/maskrcnn_resnet50_fpn_v2.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do

require Logger

import ExVision.Utils

alias ExVision.Types.BBoxWithMask

@type output_t() :: [BBoxWithMask.t()]
Expand Down Expand Up @@ -46,16 +48,10 @@ defmodule ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2 do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
msluszniak marked this conversation as resolved.
Show resolved Hide resolved

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
scores = unbatch(scores)
labels = unbatch(labels)

masks =
msluszniak marked this conversation as resolved.
Show resolved Hide resolved
masks
Expand Down
30 changes: 10 additions & 20 deletions lib/ex_vision/keypoint_detection/keypointrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do

require Logger

import ExVision.Utils

alias ExVision.Types.BBoxWithKeypoints

@typep output_t() :: [BBoxWithKeypoints.t()]
Expand Down Expand Up @@ -67,26 +69,14 @@ defmodule ExVision.KeypointDetection.KeypointRCNN_ResNet50_FPN do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()

keypoints_list =
keypoints_list
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, 1]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

keypoints_scores_list = keypoints_scores_list |> Nx.squeeze(axes: [0]) |> Nx.to_list()
bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
msluszniak marked this conversation as resolved.
Show resolved Hide resolved

scores = unbatch(scores)
labels = unbatch(labels)

keypoints_list = process_bbox(keypoints_list, Nx.tensor([scale_x, scale_y, 1]))

keypoints_scores_list = unbatch(keypoints_scores_list)

[bboxes, scores, labels, keypoints_list, keypoints_scores_list]
|> Enum.zip()
Expand Down
14 changes: 5 additions & 9 deletions lib/ex_vision/object_detection/generic_detector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ defmodule ExVision.ObjectDetection.GenericDetector do

require Logger

import ExVision.Utils

alias ExVision.Types.{BBox, ImageMetadata}

@typep output_t() :: [BBox.t()]
Expand All @@ -29,16 +31,10 @@ defmodule ExVision.ObjectDetection.GenericDetector do
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.squeeze(axes: [0])
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
bboxes = process_bbox(bboxes, Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
msluszniak marked this conversation as resolved.
Show resolved Hide resolved

scores = scores |> Nx.squeeze(axes: [0]) |> Nx.to_list()
labels = labels |> Nx.squeeze(axes: [0]) |> Nx.to_list()
scores = unbatch(scores)
labels = unbatch(labels)

[bboxes, scores, labels]
|> Enum.zip()
Expand Down
21 changes: 14 additions & 7 deletions lib/ex_vision/utils.ex
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
defmodule ExVision.Utils do
@moduledoc false

import Nx.Defn
require Nx
require Image
alias ExVision.Types
Expand Down Expand Up @@ -86,8 +85,7 @@ defmodule ExVision.Utils do

defp ensure_grad_3(tensor) do
tensor
|> Nx.shape()
|> tuple_size()
|> Nx.rank()
|> case do
3 -> [tensor]
4 -> tensor |> Nx.to_batched(1) |> Stream.map(&Nx.squeeze(&1, axes: [0])) |> Enum.to_list()
Expand Down Expand Up @@ -149,10 +147,6 @@ defmodule ExVision.Utils do
Enum.map(outputs, fn {name, _type, _shape} -> name end)
end

defn softmax(x) do
Nx.divide(Nx.exp(x), Nx.sum(Nx.exp(x)))
end

@spec batched_run(atom(), ExVision.Model.input_t()) :: ExVision.Model.output_t()
def batched_run(process_name, input) when is_list(input) do
Nx.Serving.batched_run(process_name, input)
Expand All @@ -161,4 +155,17 @@ defmodule ExVision.Utils do
def batched_run(process_name, input) do
process_name |> batched_run([input]) |> hd()
end

defp process_bbox(bbox, scales, axes \\ [0]) do
bbox
|> Nx.squeeze(axes: axes)
|> Nx.multiply(scales)
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()
end

defp unbatch(batched_value, axes \\ [0]) do
batched_value |> Nx.squeeze(axes: axes) |> Nx.to_list()
end
end
Loading