Skip to content

Commit

Permalink
fixed conditional categories
Browse files Browse the repository at this point in the history
  • Loading branch information
Mateusz Kopcinski authored and Mateusz Kopcinski committed Aug 14, 2024
1 parent d348452 commit 310fa48
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 17 deletions.
6 changes: 3 additions & 3 deletions lib/ex_vision/model/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ defmodule ExVision.Model.Definition do
])

quote do
unless is_nil(unquote(options[:categories])) do
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)
end
# conditional defintion based on whether `categories` option is present has to be moved inside __using__ macro
# here is explenation https://cocoa-research.works/2022/10/conditional-compliation-with-if-and-use-in-elixir/
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)

@behaviour ExVision.Model.Definition

Expand Down
28 changes: 15 additions & 13 deletions lib/ex_vision/model/definition/parts/with_categories.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@ defmodule ExVision.Model.Definition.Parts.WithCategories do

defmacro __using__(options) do
options = Keyword.validate!(options, [:name, :categories])
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()
unless is_nil(options |> Keyword.fetch!(:categories)) do
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
require Bunch.Typespec
quote do
require Bunch.Typespec

@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)
@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
end
end
end
end
2 changes: 1 addition & 1 deletion lib/ex_vision/style_transfer/style_transfer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ for {module, opts} <- Configuration.configuration() do
require Logger
@type output_t() :: [Nx.Tensor.t()]

use ExVision.Model.Definition.Ortex, model: unquote(opts[:model]), categories: "priv/categories/coco_categories.json"
use ExVision.Model.Definition.Ortex, model: unquote(opts[:model])

Check warning on line 26 in lib/ex_vision/style_transfer/style_transfer.ex

View workflow job for this annotation

GitHub Actions / Build and test

use must appear before type

@impl true
def load(options \\ []) do
Expand Down

0 comments on commit 310fa48

Please sign in to comment.