Skip to content

Commit

Permalink
Remove some magic from handling categories files
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-jodlos committed May 23, 2024
1 parent 6d21ea1 commit 7c04591
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 52 deletions.
2 changes: 1 addition & 1 deletion config/dev.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ config :ortex, Ortex.Native, features: ["coreml"]

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache")
cache_path: System.get_env("EX_VISION_CACHE_DIR", "models")
2 changes: 2 additions & 0 deletions config/test.exs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
import Config

config :ex_vision, cache_path: "models"
4 changes: 3 additions & 1 deletion lib/ex_vision/classification/mobilenet_v3_small.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ defmodule ExVision.Classification.MobileNetV3Small do
Exported from `torchvision`.
Weights from Imagenet 1k.
"""
use ExVision.Model.Definition.Ortex, base_dir: "classification/mobilenetv3small"
use ExVision.Model.Definition.Ortex,
model: "classification/mobilenetv3small/model.onnx",
categories: "classification/mobilenetv3small/categories.json"

require Bunch.Typespec
alias ExVision.Utils
Expand Down
5 changes: 4 additions & 1 deletion lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ defmodule ExVision.Detection.FasterRCNN_ResNet50_FPN do
@moduledoc """
FasterRCNN object detector with ResNet50 backbone and FPN detection head, exported from torchvision.
"""
use ExVision.Model.Definition.Ortex, base_dir: "detection/fasterrcnn_resnet50_fpn"
use ExVision.Model.Definition.Ortex,
model: "detection/fasterrcnn_resnet50_fpn/model.onnx",
categories: "detection/fasterrcnn_resnet50_fpn/categories.json"

use ExVision.Detection.GenericDetector

require Logger
Expand Down
5 changes: 4 additions & 1 deletion lib/ex_vision/detection/ssdlite320_mobilenetv3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3 do
@moduledoc """
SSDLite320 object detector with MobileNetV3 Large architecture, exported from torchvision.
"""
use ExVision.Model.Definition.Ortex, base_dir: "detection/ssdlite320_mobilenetv3"
use ExVision.Model.Definition.Ortex,
model: "detection/ssdlite320_mobilenetv3/model.onnx",
categories: "detection/ssdlite320_mobilenetv3/categories.json"

use ExVision.Detection.GenericDetector

require Logger
Expand Down
42 changes: 6 additions & 36 deletions lib/ex_vision/model/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ defmodule ExVision.Model.Definition do
"""

require Bunch.Typespec
alias ExVision.{Cache, Utils}

@callback load(keyword()) :: {:ok, ExVision.Model.t()} | {:error, reason :: atom()}
@callback run(ExVision.Model.t(), ExVision.Model.input_t()) :: any()
Expand All @@ -23,35 +22,21 @@ defmodule ExVision.Model.Definition do
end)

defmacro __using__(options) do
Application.ensure_all_started(:req)

options =
Keyword.validate!(options, [
:base_dir,
:categories,
name: module_to_name(__CALLER__.module)
])

model_path = Path.join(options[:base_dir], "model.onnx")

categories =
options[:base_dir]
|> Path.join("categories.json")
|> Cache.lazy_get(cache_path: "models")
|> case do
{:ok, categories_file} ->
Utils.load_categories(categories_file)

{:error, _reason} ->
nil
quote do
unless is_nil(unquote(options[:categories])) do
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)
end

categories_spec =
unless is_nil(categories),
do: categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
@behaviour ExVision.Model.Definition

@model_path unquote(model_path)

@derive [ExVision.Model]
@enforce_keys [:serving]
defstruct [:serving]
Expand Down Expand Up @@ -130,21 +115,6 @@ defmodule ExVision.Model.Definition do
child_spec: 0,
start_link: 0,
start_link: 1

unless is_nil(unquote(categories)) do
require Bunch.Typespec

@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(categories_spec)

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
end
end
end
end
16 changes: 5 additions & 11 deletions lib/ex_vision/model/definition/ortex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,14 @@ defmodule ExVision.Model.Definition.Ortex do
@type using_option_t() :: {:base_dir, Path.t()} | {:name, String.t()}
@spec __using__([using_option_t()]) :: Macro.t()
defmacro __using__(opts) do
Application.ensure_all_started(:req)

opts = Keyword.validate!(opts, [:base_dir, :name])
base_dir = opts[:base_dir]

model_path = Path.join(base_dir, "model.onnx")
{opts, generic_opts} = Keyword.split(opts, [:model])
opts = Keyword.validate!(opts, [:model])
model_path = Keyword.fetch!(opts, :model)

quote do
use ExVision.Model.Definition, unquote(Keyword.take(opts, [:base_dir, :name]))
use ExVision.Model.Definition, unquote(generic_opts)
@behaviour ExVision.Model.Definition.Ortex

@model_name unquote(opts[:name])
@model_path unquote(model_path)

@doc """
Creates the model instance
"""
Expand All @@ -156,7 +150,7 @@ defmodule ExVision.Model.Definition.Ortex do
end

defp default_model_load(options) do
ExVision.Model.Definition.Ortex.load_ortex_model(__MODULE__, @model_path, options)
ExVision.Model.Definition.Ortex.load_ortex_model(__MODULE__, unquote(model_path), options)
end

@impl true
Expand Down
39 changes: 39 additions & 0 deletions lib/ex_vision/model/definition/parts/with_categories.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
defmodule ExVision.Model.Definition.Parts.WithCategories do
@moduledoc false
require Logger
alias ExVision.{Cache, Utils}

defp get_categories(file) do
file
|> Cache.lazy_get()
|> case do
{:ok, file} ->
Utils.load_categories(file)

error ->
Logger.error("Failed to load categories from #{file} due to #{inspect(error)}")
raise "Failed to load categories from #{file}"
end
end

defmacro __using__(options) do
options = Keyword.validate!(options, [:name, :categories])
categories = options |> Keyword.fetch!(:categories) |> get_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
require Bunch.Typespec

@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
end
end
end
4 changes: 3 additions & 1 deletion lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ defmodule ExVision.Segmentation.DeepLabV3_MobileNetV3 do
@moduledoc """
A semantic segmentation model for MobileNetV3 Backbone. Exported from torchvision.
"""
use ExVision.Model.Definition.Ortex, base_dir: "segmentation/deeplab_v3_mobilenetv3"
use ExVision.Model.Definition.Ortex,
model: "segmentation/deeplab_v3_mobilenetv3/model.onnx",
categories: "segmentation/deeplab_v3_mobilenetv3/categories.json"

@type output_t() :: %{category_t() => Nx.Tensor.t()}

Expand Down

0 comments on commit 7c04591

Please sign in to comment.