Skip to content

Commit

Permalink
Merge branch 'feature/add-another-model-of-the-same-kind' into 'main'
Browse files Browse the repository at this point in the history
Add FasterRCNN Detector

See merge request swm-ai/ex_vision!4
  • Loading branch information
daniel-jodlos committed May 21, 2024
2 parents a960f0b + b3d2e79 commit 2e05a56
Show file tree
Hide file tree
Showing 23 changed files with 250 additions and 102 deletions.
10 changes: 5 additions & 5 deletions .credo.exs
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@
priority: :normal, order: ~w/shortdoc moduledoc behaviour use import require alias/a},
{Credo.Check.Consistency.MultiAliasImportRequireUse, false},
{Credo.Check.Consistency.UnusedVariableNames, force: :meaningful},
{Credo.Check.Design.DuplicatedCode, false},
{Credo.Check.Design.DuplicatedCode, []},
{Credo.Check.Readability.AliasAs, false},
{Credo.Check.Readability.MultiAlias, false},
{Credo.Check.Readability.MultiAlias, []},
{Credo.Check.Readability.Specs, []},
{Credo.Check.Readability.SinglePipe, false},
{Credo.Check.Readability.WithCustomTaggedTuple, false},
Expand All @@ -172,10 +172,10 @@
{Credo.Check.Refactor.DoubleBooleanNegation, false},
{Credo.Check.Refactor.ModuleDependencies, false},
{Credo.Check.Refactor.NegatedIsNil, false},
{Credo.Check.Refactor.PipeChainStart, false},
{Credo.Check.Refactor.PipeChainStart, []},
{Credo.Check.Refactor.VariableRebinding, false},
{Credo.Check.Warning.LeakyEnvironment, false},
{Credo.Check.Warning.MapGetUnsafePass, false},
{Credo.Check.Warning.LeakyEnvironment, []},
{Credo.Check.Warning.MapGetUnsafePass, []},
{Credo.Check.Warning.UnsafeToAtom, false}

#
Expand Down
36 changes: 32 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,38 @@ ExVision will take care of all necessary input transformations and covert output
MobileNetV3.run(model, "example/files/cat.jpg") #=> %{cat: 0.98, dog: 0.01, car: 0.00, ...}
```

ExVision is also capable of accepting tensors on input:
ExVision is also capable of accepting tensors and images on input:

```elixir
cat = "example/files/cat.jpg" |> StbImage.read_file!() |> StbImage.to_nx()
cat = Image.open!("example/files/cat.jpg")
{:ok, cat_tensor} = Image.to_nx(cat)
MobileNetV3.run(model, cat) #=> %{cat: 0.98, dog: 0.01, car: 0.00, ...}
MobileNetV3.run(model, cat_tensor) #=> %{cat: 0.98, dog: 0.01, car: 0.00, ...}
```

### Usage in process workflow

All ExVision models are implemented using `Nx.Serving`.
They are therefore compatible with process workflow.

You can start a model's process:

```elixir
{:ok, pid} = MobileNetV3.start_link(name: MyModel)
```

or start it under the supervision tree

```elixir
{:ok, _supervisor_pid} = Supervisor.start_link([
{MobileNetV3, name: MyModel}
], strategy: :one_for_one)
```

After starting, it's immediatelly available for inference using `batched_run/2` function.

```elixir
MobileNetV3.batched_run(MyModel, cat) #=> %{cat: 0.98, dog: 0.01, car: 0.00, ...}
```

## Installation
Expand All @@ -55,12 +82,13 @@ If the model that you would like to use is missing, feel free to open the issue,
- [x] MobileNetV3 Small
- [ ] EfficientNetV2
- [ ] SqueezeNet
- [ ] Object detection
- [x] Object detection
- [x] SSDLite320 - MobileNetV3 Large backbone
- [x] FasterRCNN ResNet50 FPN
- [x] Semantic segmentation
- [x] DeepLabV3 - MobileNetV3
- [ ] Instance segmentation
- [x] Mask R-CNN
- [ ] Mask R-CNN
- [ ] Keypoint Detection
- [ ] Keypoint R-CNN

Expand Down
5 changes: 2 additions & 3 deletions examples/3-membrane.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ Mix.install(
],
config: [
nx: [default_backend: EXLA.Backend]
],
force: true
]
)
```

Expand Down Expand Up @@ -55,7 +54,7 @@ But before we dive into the code, here are a few tips that will make it both eas
defmodule Membrane.ExVision.Detector do
use Membrane.Filter

alias ExVision.Detection.Ssdlite320_MobileNetv3, as: Model
alias ExVision.Detection.FasterRCNN_ResNet50_FPN, as: Model
alias ExVision.Types.BBox

# Define both input and output pads
Expand Down
17 changes: 12 additions & 5 deletions lib/ex_vision/cache.ex
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,21 @@ defmodule ExVision.Cache do
{:ok, Path.t()} | {:error, reason :: any()}
defp download_file(url, cache_path) do
with :ok <- cache_path |> Path.dirname() |> File.mkdir_p(),
target_file = File.stream!(cache_path),
:ok <- do_download_file(url, target_file) do
if File.exists?(cache_path),
do: {:ok, cache_path},
else: {:error, :download_failed}
tmp_file_path = cache_path <> ".unconfirmed",
tmp_file = File.stream!(tmp_file_path),
:ok <- do_download_file(url, tmp_file),
:ok <- validate_download(tmp_file_path),
:ok <- File.rename(tmp_file_path, cache_path) do
{:ok, cache_path}
end
end

defp validate_download(path) do
if File.exists?(path),
do: :ok,
else: {:error, :download_failed}
end

@spec do_download_file(URI.t(), File.Stream.t()) :: :ok | {:error, reason :: any()}
defp do_download_file(%URI{} = url, %File.Stream{path: target_file_path} = target_file) do
Logger.debug("Downloading file from `#{url}` and saving to `#{target_file_path}`")
Expand Down
2 changes: 1 addition & 1 deletion lib/ex_vision/classification/mobilenet_v3_small.ex
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ defmodule ExVision.Classification.MobileNetV3Small do
end

@impl true
def postprocessing({scores}, _metadata) do
def postprocessing(%{"output" => scores}, _metadata) do
scores
|> Nx.backend_transfer()
|> Nx.flatten()
Expand Down
22 changes: 22 additions & 0 deletions lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
defmodule ExVision.Detection.FasterRCNN_ResNet50_FPN do
@moduledoc """
FasterRCNN object detector with ResNet50 backbone and FPN detection head, exported from torchvision.
"""
use ExVision.Model.Definition.Ortex, base_dir: "detection/fasterrcnn_resnet50_fpn"
use ExVision.Detection.GenericDetector

require Logger

@impl true
def load(options \\ []) do
if Keyword.has_key?(options, :batch_size) do
Logger.warning(
"`:max_batch_size` was given, but this model can only process batch of size 1. Overriding"
)
end

options
|> Keyword.put(:batch_size, 1)
|> default_model_load()
end
end
77 changes: 77 additions & 0 deletions lib/ex_vision/detection/generic_detector.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
defmodule ExVision.Detection.GenericDetector do
@moduledoc false

# Contains a default implementation of pre and post processing for TorchVision detectors
# To use: `use ExVision.Detection.GenericDetector`

require Logger

alias ExVision.Types.{BBox, ImageMetadata}

@typep output_t() :: [BBox.t()]

@spec preprocessing(Nx.Tensor.t(), ImageMetadata.t()) :: Nx.Tensor.t()
def preprocessing(img, _metadata) do
ExVision.Utils.resize(img, {224, 224})
end

@spec postprocessing({Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()}, ImageMetadata.t(), [atom()]) ::
output_t()
def postprocessing(
%{"boxes" => bboxes, "scores" => scores, "labels" => labels},
metadata,
categories
) do
{h, w} = metadata.original_size
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

scores = scores |> Nx.to_list()
labels = labels |> Nx.to_list()

[bboxes, scores, labels]
|> Enum.zip()
|> Enum.filter(fn {_bbox, score, _label} -> score > 0.1 end)
|> Enum.map(fn {[x1, y1, x2, y2], score, label} ->
%BBox{
x1: x1,
x2: x2,
y1: y1,
y2: y2,
score: score,
label: Enum.at(categories, label)
}
end)
end

defmacro __using__(_opts) do
quote do
@typedoc """
A type describing output of `run/2` as a list of a bounding boxes.
Each bounding box describes the location of the object indicated by the `label`.
It also provides the `score` field marking the probability of the prediction.
Bounding boxes with very low scores should most likely be ignored.
"""
@type output_t() :: [BBox.t()]

@impl true
defdelegate preprocessing(image, metadata), to: ExVision.Detection.GenericDetector

@impl true
@spec postprocessing(tuple(), ExVision.Types.ImageMetadata.t()) :: output_t()
def postprocessing(output, metadata) do
ExVision.Detection.GenericDetector.postprocessing(output, metadata, categories())
end

defoverridable preprocessing: 2, postprocessing: 2
end
end
end
48 changes: 1 addition & 47 deletions lib/ex_vision/detection/ssdlite320_mobilenetv3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,10 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3 do
SSDLite320 object detector with MobileNetV3 Large architecture, exported from torchvision.
"""
use ExVision.Model.Definition.Ortex, base_dir: "detection/ssdlite320_mobilenetv3"
use ExVision.Detection.GenericDetector

require Logger

alias ExVision.Types.BBox

@typedoc """
A type describing output of `run/2` as a list of a bounding boxes.
Each bounding box describes the location of the object indicated by the `label`.
It also provides the `score` field marking the probability of the prediction.
Bounding boxes with very low scores should most likely be ignored.
"""
@type output_t() :: [BBox.t(category_t())]

@impl true
def load(options \\ []) do
if Keyword.has_key?(options, :batch_size) do
Expand All @@ -29,40 +19,4 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3 do
|> Keyword.put(:batch_size, 1)
|> default_model_load()
end

@impl true
def preprocessing(img, _metadata) do
ExVision.Utils.resize(img, {224, 224})
end

@impl true
def postprocessing({bboxes, scores, labels}, metadata) do
{h, w} = metadata.original_size
scale_x = w / 224
scale_y = h / 224

bboxes =
bboxes
|> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y]))
|> Nx.round()
|> Nx.as_type(:s64)
|> Nx.to_list()

scores = scores |> Nx.to_list()
labels = labels |> Nx.to_list()

[bboxes, scores, labels]
|> Enum.zip()
|> Enum.filter(fn {_bbox, score, _label} -> score > 0.1 end)
|> Enum.map(fn {[x1, y1, x2, y2], score, label} ->
%BBox{
x1: x1,
x2: x2,
y1: y1,
y2: y2,
score: score,
label: Enum.at(categories(), label)
}
end)
end
end
14 changes: 10 additions & 4 deletions lib/ex_vision/model/definition/ortex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ defmodule ExVision.Model.Definition.Ortex do
end
end

defmacrop get_client_postprocessing(module) do
defmacrop get_client_postprocessing(module, output_names) do
quote do
fn {result, _server_metadata}, metadata ->
result
|> split_onnx_result()
|> split_onnx_result(unquote(output_names))
|> Enum.zip(metadata)
|> Enum.map(fn {result, metadata} -> unquote(module).postprocessing(result, metadata) end)
end
Expand All @@ -93,11 +93,13 @@ defmodule ExVision.Model.Definition.Ortex do
cache_options = Keyword.take(options, [:cache_path, :file_path]),
{:ok, path} <- ExVision.Cache.lazy_get(model_path, cache_options),
{:ok, model} <- do_load_model(path, options[:providers]) do
output_names = ExVision.Utils.onnx_output_names(model)

model
|> then(&Nx.Serving.new(Ortex.Serving, &1))
|> Nx.Serving.batch_size(options[:batch_size])
|> Nx.Serving.client_preprocessing(get_client_preprocessing(module))
|> Nx.Serving.client_postprocessing(get_client_postprocessing(module))
|> Nx.Serving.client_postprocessing(get_client_postprocessing(module, output_names))
|> then(&{:ok, struct!(module, serving: &1)})
end
end
Expand All @@ -113,13 +115,17 @@ defmodule ExVision.Model.Definition.Ortex do
end
end

defp split_onnx_result(tuple) do
defp split_onnx_result(tuple, outputs) do
tuple
|> Tuple.to_list()
|> Enum.map(fn x ->
# Do a backend transfer and also return a list of batches here
x |> Nx.backend_transfer() |> Nx.to_batched(1)
end)
|> Enum.zip()
|> Enum.map(fn parts ->
parts |> Tuple.to_list() |> then(&Enum.zip(outputs, &1)) |> Enum.into(%{})
end)
end

@type using_option_t() :: {:base_dir, Path.t()} | {:name, String.t()}
Expand Down
2 changes: 1 addition & 1 deletion lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ defmodule ExVision.Segmentation.DeepLabV3_MobileNetV3 do
end

@impl true
def postprocessing({out, _aux}, metadata) do
def postprocessing(%{"output" => out}, metadata) do
cls_per_pixel =
out
|> Nx.backend_transfer()
Expand Down
5 changes: 4 additions & 1 deletion lib/ex_vision/types/bbox.ex
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ defmodule ExVision.Types.BBox do
score: number()
}

@typep t() :: t(term())
@typedoc """
Exactly like `t:t/1`, but doesn't put any constraints on the `label` field:
"""
@type t() :: t(term())

@doc """
Return the width of the bounding box
Expand Down
12 changes: 10 additions & 2 deletions lib/ex_vision/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,24 @@ defmodule ExVision.Utils do
def onnx_result_backend_transfer(tuple),
do: tuple |> Tuple.to_list() |> Enum.map(&Nx.backend_transfer/1) |> List.to_tuple()

@spec onnx_input_shape(struct()) :: tuple()
@spec onnx_input_shape(Ortex.Model.t()) :: tuple()
def onnx_input_shape(%Ortex.Model{reference: r}) do
["input", "Float32", shape] =
Ortex.Native.show_session(r)
r
|> Ortex.Native.show_session()
|> Enum.find(fn [name, _type, _shape] -> name == "input" end)
|> hd()

List.to_tuple(shape)
end

@spec onnx_output_names(Ortex.Model.t()) :: [String.t()]
def onnx_output_names(%Ortex.Model{reference: r}) do
{_inputs, outputs} = Ortex.Native.show_session(r)

Enum.map(outputs, fn {name, _type, _shape} -> name end)
end

defn softmax(x) do
Nx.divide(Nx.exp(x), Nx.sum(Nx.exp(x)))
end
Expand Down
Loading

0 comments on commit 2e05a56

Please sign in to comment.