From 310fa48611e43a7ec547166d9a04a7365a2dd0e5 Mon Sep 17 00:00:00 2001 From: Mateusz Kopcinski Date: Wed, 14 Aug 2024 15:14:19 +0200 Subject: [PATCH] fixed conditional categories --- lib/ex_vision/model/definition.ex | 6 ++-- .../model/definition/parts/with_categories.ex | 28 ++++++++++--------- .../style_transfer/style_transfer.ex | 2 +- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/lib/ex_vision/model/definition.ex b/lib/ex_vision/model/definition.ex index 65f1b70..2481db3 100644 --- a/lib/ex_vision/model/definition.ex +++ b/lib/ex_vision/model/definition.ex @@ -31,9 +31,9 @@ defmodule ExVision.Model.Definition do ]) quote do - unless is_nil(unquote(options[:categories])) do - use ExVision.Model.Definition.Parts.WithCategories, unquote(options) - end + # conditional defintion based on whether `categories` option is present has to be moved inside __using__ macro + # here is explenation https://cocoa-research.works/2022/10/conditional-compliation-with-if-and-use-in-elixir/ + use ExVision.Model.Definition.Parts.WithCategories, unquote(options) @behaviour ExVision.Model.Definition diff --git a/lib/ex_vision/model/definition/parts/with_categories.ex b/lib/ex_vision/model/definition/parts/with_categories.ex index 81d552b..237d606 100644 --- a/lib/ex_vision/model/definition/parts/with_categories.ex +++ b/lib/ex_vision/model/definition/parts/with_categories.ex @@ -5,22 +5,24 @@ defmodule ExVision.Model.Definition.Parts.WithCategories do defmacro __using__(options) do options = Keyword.validate!(options, [:name, :categories]) - categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories() - spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative() + unless is_nil(options |> Keyword.fetch!(:categories)) do + categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories() + spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative() - quote do - require Bunch.Typespec + quote do + require Bunch.Typespec - @typedoc """ - Type describing all categories recognised by #{unquote(options[:name])} - """ - @type category_t() :: unquote(spec) + @typedoc """ + Type describing all categories recognised by #{unquote(options[:name])} + """ + @type category_t() :: unquote(spec) - @doc """ - Returns a list of all categories recognised by #{unquote(options[:name])} - """ - @spec categories() :: [category_t()] - def categories(), do: unquote(categories) + @doc """ + Returns a list of all categories recognised by #{unquote(options[:name])} + """ + @spec categories() :: [category_t()] + def categories(), do: unquote(categories) + end end end end diff --git a/lib/ex_vision/style_transfer/style_transfer.ex b/lib/ex_vision/style_transfer/style_transfer.ex index 9688c66..84968a4 100644 --- a/lib/ex_vision/style_transfer/style_transfer.ex +++ b/lib/ex_vision/style_transfer/style_transfer.ex @@ -23,7 +23,7 @@ for {module, opts} <- Configuration.configuration() do require Logger @type output_t() :: [Nx.Tensor.t()] - use ExVision.Model.Definition.Ortex, model: unquote(opts[:model]), categories: "priv/categories/coco_categories.json" + use ExVision.Model.Definition.Ortex, model: unquote(opts[:model]) @impl true def load(options \\ []) do