Skip to content

Commit

Permalink
Cleanup categories (#2)
Browse files Browse the repository at this point in the history
* Fix DeepLabV3 models

* Remove some magic from handling categories files

* Flatten the structure of the model directory
  • Loading branch information
daniel-jodlos authored May 23, 2024
1 parent 553b021 commit c1f811c
Show file tree
Hide file tree
Showing 25 changed files with 1,192 additions and 64 deletions.
4 changes: 4 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
models/**/*.onnx filter=lfs diff=lfs merge=lfs -text
models/deeplab_v3_mobilenetv3_segmentation.onnx filter=lfs diff=lfs merge=lfs -text
models/fasterrcnn_resnet50_fpn_detector.onnx filter=lfs diff=lfs merge=lfs -text
models/mobilenetv3small-classifier.onnx filter=lfs diff=lfs merge=lfs -text
models/ssdlite320_mobilenetv3_detector.onnx filter=lfs diff=lfs merge=lfs -text
2 changes: 1 addition & 1 deletion config/dev.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ config :ortex, Ortex.Native, features: ["coreml"]

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache")
cache_path: System.get_env("EX_VISION_CACHE_DIR", "models")
2 changes: 2 additions & 0 deletions config/test.exs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
import Config

config :ex_vision, cache_path: "models"
4 changes: 3 additions & 1 deletion lib/ex_vision/classification/mobilenet_v3_small.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ defmodule ExVision.Classification.MobileNetV3Small do
Exported from `torchvision`.
Weights from Imagenet 1k.
"""
use ExVision.Model.Definition.Ortex, base_dir: "classification/mobilenetv3small"
use ExVision.Model.Definition.Ortex,
model: "mobilenetv3small-classifier.onnx",
categories: "imagenet_v2_categories.json"

require Bunch.Typespec
alias ExVision.Utils
Expand Down
5 changes: 4 additions & 1 deletion lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ defmodule ExVision.Detection.FasterRCNN_ResNet50_FPN do
@moduledoc """
FasterRCNN object detector with ResNet50 backbone and FPN detection head, exported from torchvision.
"""
use ExVision.Model.Definition.Ortex, base_dir: "detection/fasterrcnn_resnet50_fpn"
use ExVision.Model.Definition.Ortex,
model: "fasterrcnn_resnet50_fpn_detector.onnx",
categories: "coco_categories.json"

use ExVision.Detection.GenericDetector

require Logger
Expand Down
3 changes: 1 addition & 2 deletions lib/ex_vision/detection/generic_detector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ defmodule ExVision.Detection.GenericDetector do
ExVision.Utils.resize(img, {224, 224})
end

@spec postprocessing({Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()}, ImageMetadata.t(), [atom()]) ::
output_t()
@spec postprocessing(map(), ImageMetadata.t(), [atom()]) :: output_t()
def postprocessing(
%{"boxes" => bboxes, "scores" => scores, "labels" => labels},
metadata,
Expand Down
5 changes: 4 additions & 1 deletion lib/ex_vision/detection/ssdlite320_mobilenetv3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3 do
@moduledoc """
SSDLite320 object detector with MobileNetV3 Large architecture, exported from torchvision.
"""
use ExVision.Model.Definition.Ortex, base_dir: "detection/ssdlite320_mobilenetv3"
use ExVision.Model.Definition.Ortex,
model: "ssdlite320_mobilenetv3_detector.onnx",
categories: "coco_categories.json"

use ExVision.Detection.GenericDetector

require Logger
Expand Down
42 changes: 6 additions & 36 deletions lib/ex_vision/model/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ defmodule ExVision.Model.Definition do
"""

require Bunch.Typespec
alias ExVision.{Cache, Utils}

@callback load(keyword()) :: {:ok, ExVision.Model.t()} | {:error, reason :: atom()}
@callback run(ExVision.Model.t(), ExVision.Model.input_t()) :: any()
Expand All @@ -23,35 +22,21 @@ defmodule ExVision.Model.Definition do
end)

defmacro __using__(options) do
Application.ensure_all_started(:req)

options =
Keyword.validate!(options, [
:base_dir,
:categories,
name: module_to_name(__CALLER__.module)
])

model_path = Path.join(options[:base_dir], "model.onnx")

categories =
options[:base_dir]
|> Path.join("categories.json")
|> Cache.lazy_get(cache_path: "models")
|> case do
{:ok, categories_file} ->
Utils.load_categories(categories_file)

{:error, _reason} ->
nil
quote do
unless is_nil(unquote(options[:categories])) do
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)
end

categories_spec =
unless is_nil(categories),
do: categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
@behaviour ExVision.Model.Definition

@model_path unquote(model_path)

@derive [ExVision.Model]
@enforce_keys [:serving]
defstruct [:serving]
Expand Down Expand Up @@ -130,21 +115,6 @@ defmodule ExVision.Model.Definition do
child_spec: 0,
start_link: 0,
start_link: 1

unless is_nil(unquote(categories)) do
require Bunch.Typespec

@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(categories_spec)

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
end
end
end
end
18 changes: 6 additions & 12 deletions lib/ex_vision/model/definition/ortex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ defmodule ExVision.Model.Definition.Ortex do
In this callback, you should transform the output to match your desired format.
"""
@callback postprocessing(tuple(), ImageMetadata.t()) :: ExVision.Model.output_t()
@callback postprocessing(map(), ImageMetadata.t()) :: ExVision.Model.output_t()

@typedoc """
A type describing ONNX provider that can be used with ExVision.
Expand Down Expand Up @@ -131,20 +131,14 @@ defmodule ExVision.Model.Definition.Ortex do
@type using_option_t() :: {:base_dir, Path.t()} | {:name, String.t()}
@spec __using__([using_option_t()]) :: Macro.t()
defmacro __using__(opts) do
Application.ensure_all_started(:req)

opts = Keyword.validate!(opts, [:base_dir, :name])
base_dir = opts[:base_dir]

model_path = Path.join(base_dir, "model.onnx")
{opts, generic_opts} = Keyword.split(opts, [:model])
opts = Keyword.validate!(opts, [:model])
model_path = Keyword.fetch!(opts, :model)

quote do
use ExVision.Model.Definition, unquote(Keyword.take(opts, [:base_dir, :name]))
use ExVision.Model.Definition, unquote(generic_opts)
@behaviour ExVision.Model.Definition.Ortex

@model_name unquote(opts[:name])
@model_path unquote(model_path)

@doc """
Creates the model instance
"""
Expand All @@ -156,7 +150,7 @@ defmodule ExVision.Model.Definition.Ortex do
end

defp default_model_load(options) do
ExVision.Model.Definition.Ortex.load_ortex_model(__MODULE__, @model_path, options)
ExVision.Model.Definition.Ortex.load_ortex_model(__MODULE__, unquote(model_path), options)
end

@impl true
Expand Down
39 changes: 39 additions & 0 deletions lib/ex_vision/model/definition/parts/with_categories.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
defmodule ExVision.Model.Definition.Parts.WithCategories do
@moduledoc false
require Logger
alias ExVision.{Cache, Utils}

defp get_categories(file) do
file
|> Cache.lazy_get()
|> case do
{:ok, file} ->
Utils.load_categories(file)

error ->
Logger.error("Failed to load categories from #{file} due to #{inspect(error)}")
raise "Failed to load categories from #{file}"
end
end

defmacro __using__(options) do
options = Keyword.validate!(options, [:name, :categories])
categories = options |> Keyword.fetch!(:categories) |> get_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
require Bunch.Typespec

@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
end
end
end
4 changes: 3 additions & 1 deletion lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ defmodule ExVision.Segmentation.DeepLabV3_MobileNetV3 do
@moduledoc """
A semantic segmentation model for MobileNetV3 Backbone. Exported from torchvision.
"""
use ExVision.Model.Definition.Ortex, base_dir: "segmentation/deeplabv3_mobilenetv3"
use ExVision.Model.Definition.Ortex,
model: "deeplab_v3_mobilenetv3_segmentation.onnx",
categories: "coco_with_voc_labels_categories.json"

@type output_t() :: %{category_t() => Nx.Tensor.t()}

Expand Down
Loading

0 comments on commit c1f811c

Please sign in to comment.