diff --git a/lib/ex_vision/model/definition.ex b/lib/ex_vision/model/definition.ex index 2481db3..79d4469 100644 --- a/lib/ex_vision/model/definition.ex +++ b/lib/ex_vision/model/definition.ex @@ -25,10 +25,10 @@ defmodule ExVision.Model.Definition do Application.ensure_all_started(:req) options = - Keyword.validate!(options, [ + Keyword.validate!(options, categories: nil, name: module_to_name(__CALLER__.module) - ]) + ) quote do # conditional defintion based on whether `categories` option is present has to be moved inside __using__ macro diff --git a/lib/ex_vision/model/definition/parts/with_categories.ex b/lib/ex_vision/model/definition/parts/with_categories.ex index 237d606..2b4481a 100644 --- a/lib/ex_vision/model/definition/parts/with_categories.ex +++ b/lib/ex_vision/model/definition/parts/with_categories.ex @@ -5,6 +5,7 @@ defmodule ExVision.Model.Definition.Parts.WithCategories do defmacro __using__(options) do options = Keyword.validate!(options, [:name, :categories]) + unless is_nil(options |> Keyword.fetch!(:categories)) do categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories() spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative() diff --git a/lib/ex_vision/style_transfer/style_transfer.ex b/lib/ex_vision/style_transfer/style_transfer.ex index b74f3f0..07ba722 100644 --- a/lib/ex_vision/style_transfer/style_transfer.ex +++ b/lib/ex_vision/style_transfer/style_transfer.ex @@ -1,18 +1,26 @@ defmodule Configuration do @moduledoc false - @low_resolution {400,300} - @high_resolution {640,480} + @low_resolution {400, 300} + @high_resolution {640, 480} + + @spec configuration() :: %{} def configuration do %{ - ExVision.StyleTransfer.Candy => [model: "candy.onnx", resolution: @high_resolution], - ExVision.StyleTransfer.CandyFast => [model: "candy_fast.onnx", resolution: @low_resolution], - ExVision.StyleTransfer.Princess => [model: "princess.onnx", resolution: @high_resolution], - ExVision.StyleTransfer.PrincessFast => [model: "princess_fast.onnx", resolution: @low_resolution], - ExVision.StyleTransfer.Udnie => [model: "udnie.onnx", resolution: @high_resolution], - ExVision.StyleTransfer.UdnieFast => [model: "udnie_fast.onnx", resolution: @low_resolution], - ExVision.StyleTransfer.Mosaic => [model: "mosaic.onnx", resolution: @high_resolution], - ExVision.StyleTransfer.MosaicFast => [model: "mosaic_fast.onnx", resolution: @low_resolution], + ExVision.StyleTransfer.Candy => [model: "candy.onnx", resolution: @high_resolution], + ExVision.StyleTransfer.CandyFast => [model: "candy_fast.onnx", resolution: @low_resolution], + ExVision.StyleTransfer.Princess => [model: "princess.onnx", resolution: @high_resolution], + ExVision.StyleTransfer.PrincessFast => [ + model: "princess_fast.onnx", + resolution: @low_resolution + ], + ExVision.StyleTransfer.Udnie => [model: "udnie.onnx", resolution: @high_resolution], + ExVision.StyleTransfer.UdnieFast => [model: "udnie_fast.onnx", resolution: @low_resolution], + ExVision.StyleTransfer.Mosaic => [model: "mosaic.onnx", resolution: @high_resolution], + ExVision.StyleTransfer.MosaicFast => [ + model: "mosaic_fast.onnx", + resolution: @low_resolution + ] } end end @@ -22,6 +30,8 @@ for {module, opts} <- Configuration.configuration() do @moduledoc """ #{module} is a custom style transfer model optimised for devices with low computational capabilities and CPU inference. """ + use ExVision.Model.Definition.Ortex, model: unquote(opts[:model]) + require Logger @typedoc """ @@ -29,8 +39,6 @@ for {module, opts} <- Configuration.configuration() do """ @type output_t() :: Nx.Tensor.t() - use ExVision.Model.Definition.Ortex, model: unquote(opts[:model]) - @impl true def load(options \\ []) do if Keyword.has_key?(options, :batch_size) do @@ -54,14 +62,15 @@ for {module, opts} <- Configuration.configuration() do stylized_frame, metadata ) do - {h,w} = unquote(opts[:resolution]) + {h, w} = unquote(opts[:resolution]) + stylized_frame["55"] - |> Nx.reshape({3, h, w}, names: [:channel, :height, :width]) - |> NxImage.resize(metadata.original_size, channels: :first) - |> Nx.max(0.0) - |> Nx.min(255.0) - |> Nx.as_type(:u8) - |> Nx.transpose(axes: [1, 2, 0]) + |> Nx.reshape({3, h, w}, names: [:channel, :height, :width]) + |> NxImage.resize(metadata.original_size, channels: :first) + |> Nx.max(0.0) + |> Nx.min(255.0) + |> Nx.as_type(:u8) + |> Nx.transpose(axes: [1, 2, 0]) end end end diff --git a/mix.lock b/mix.lock index 9931294..c45bba9 100644 --- a/mix.lock +++ b/mix.lock @@ -6,7 +6,7 @@ "cc_precompiler": {:hex, :cc_precompiler, "0.1.10", "47c9c08d8869cf09b41da36538f62bc1abd3e19e41701c2cea2675b53c704258", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f6e046254e53cd6b41c6bacd70ae728011aa82b2742a80d6e2214855c6e06b22"}, "coerce": {:hex, :coerce, "1.0.1", "211c27386315dc2894ac11bc1f413a0e38505d808153367bd5c6e75a4003d096", [:mix], [], "hexpm", "b44a691700f7a1a15b4b7e2ff1fa30bebd669929ac8aa43cffe9e2f8bf051cf1"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, - "credo": {:hex, :credo, "1.7.5", "643213503b1c766ec0496d828c90c424471ea54da77c8a168c725686377b9545", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "f799e9b5cd1891577d8c773d245668aa74a2fcd15eb277f51a0131690ebfb3fd"}, + "credo": {:hex, :credo, "1.7.7", "771445037228f763f9b2afd612b6aa2fd8e28432a95dbbc60d8e03ce71ba4446", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "8bc87496c9aaacdc3f90f01b7b0582467b69b4bd2441fe8aae3109d843cc2f2e"}, "dialyxir": {:hex, :dialyxir, "1.4.3", "edd0124f358f0b9e95bfe53a9fcf806d615d8f838e2202a9f430d59566b6b53b", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "bf2cfb75cd5c5006bec30141b131663299c661a864ec7fbbc72dfa557487a986"}, "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"}, @@ -14,11 +14,11 @@ "evision": {:hex, :evision, "0.1.38", "f8b23ad685c3ebd70969a3457027b5c74b5bc8dc51588661c516098c3240b92d", [:make, :mix, :rebar3], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}, {:progress_bar, "~> 2.0 or ~> 3.0", [hex: :progress_bar, repo: "hexpm", optional: true]}], "hexpm", "f9302547d76c5e4ad7022ffdc76be13e33c990fdd67ad2af203f24ab5d3aee20"}, "ex_doc": {:hex, :ex_doc, "0.32.1", "21e40f939515373bcdc9cffe65f3b3543f05015ac6c3d01d991874129d173420", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "5142c9db521f106d61ff33250f779807ed2a88620e472ac95dc7d59c380113da"}, "exla": {:hex, :exla, "0.7.2", "8ac573093df8e5e6b36845beeb3f5a0ea92b05082bf2fa4678f80170cfc887f6", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d061ea87858415e5585cbd4b7bdae5489000339519a2c6a7f51eb0defd73b588"}, - "file_system": {:hex, :file_system, "1.0.0", "b689cc7dcee665f774de94b5a832e578bd7963c8e637ef940cd44327db7de2cd", [:mix], [], "hexpm", "6752092d66aec5a10e662aefeed8ddb9531d79db0bc145bb8c40325ca1d8536d"}, + "file_system": {:hex, :file_system, "1.0.1", "79e8ceaddb0416f8b8cd02a0127bdbababe7bf4a23d2a395b983c1f8b3f73edd", [:mix], [], "hexpm", "4414d1f38863ddf9120720cd976fce5bdde8e91d8283353f0e31850fa89feb9e"}, "finch": {:hex, :finch, "0.18.0", "944ac7d34d0bd2ac8998f79f7a811b21d87d911e77a786bc5810adb75632ada4", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "69f5045b042e531e53edc2574f15e25e735b522c37e2ddb766e15b979e03aa65"}, "hpax": {:hex, :hpax, "0.2.0", "5a58219adcb75977b2edce5eb22051de9362f08236220c9e859a47111c194ff5", [:mix], [], "hexpm", "bea06558cdae85bed075e6c036993d43cd54d447f76d8190a8db0dc5893fa2f1"}, "image": {:hex, :image, "0.44.0", "e8eea9398abbed12b7784e786f26a5c839a00bcddd8f2f8ba12adf7e227beb9f", [:mix], [{:bumblebee, "~> 0.3", [hex: :bumblebee, repo: "hexpm", optional: true]}, {:evision, "~> 0.1.33", [hex: :evision, repo: "hexpm", optional: true]}, {:exla, "~> 0.5", [hex: :exla, repo: "hexpm", optional: true]}, {:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: true]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: true]}, {:nx_image, "~> 0.1", [hex: :nx_image, repo: "hexpm", optional: true]}, {:phoenix_html, "~> 2.1 or ~> 3.2 or ~> 4.0", [hex: :phoenix_html, repo: "hexpm", optional: false]}, {:plug, "~> 1.13", [hex: :plug, repo: "hexpm", optional: true]}, {:req, "~> 0.4", [hex: :req, repo: "hexpm", optional: true]}, {:rustler, "> 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:sweet_xml, "~> 0.7", [hex: :sweet_xml, repo: "hexpm", optional: false]}, {:vix, "~> 0.23", [hex: :vix, repo: "hexpm", optional: false]}], "hexpm", "cd00a3de4d7a40a2cb1ca72b9852b0d81701793414af8babf4d33dbeb6de0f6f"}, - "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, + "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.5", "e0ff5a7c708dda34311f7522a8758e23bfcd7d8d8068dc312b5eb41c6fd76eba", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "94d2e986428585a21516d7d7149781480013c56e30c6a233534bedf38867a59a"}, diff --git a/test/ex_vision/style_transfer/candy_test.exs b/test/ex_vision/style_transfer/candy_test.exs deleted file mode 100644 index 4766dc1..0000000 --- a/test/ex_vision/style_transfer/candy_test.exs +++ /dev/null @@ -1,30 +0,0 @@ -defmodule Configuration do - def configuration do - %{ - ExVision.StyleTransfer.CandyTest => [module: ExVision.StyleTransfer.Candy, gt_file: "cat_candy.gt"], - ExVision.StyleTransfer.CandyFastTest => [module: ExVision.StyleTransfer.CandyFast, gt_file: "cat_candy_fast.gt"], - ExVision.StyleTransfer.PrincessTest => [module: ExVision.StyleTransfer.Princess, gt_file: "cat_princess.gt"], - ExVision.StyleTransfer.PrincessFastTest => [module: ExVision.StyleTransfer.PrincessFast, gt_file: "cat_princess_fast.gt"], - ExVision.StyleTransfer.UdnieTest => [module: ExVision.StyleTransfer.Udnie, gt_file: "cat_udnie.gt"], - ExVision.StyleTransfer.UdnieFastTest => [module: ExVision.StyleTransfer.UdnieFast, gt_file: "cat_udnie_fast.gt"], - ExVision.StyleTransfer.MosaicTest => [module: ExVision.StyleTransfer.Mosaic, gt_file: "cat_mosaic.gt"], - ExVision.StyleTransfer.MosaicFastTest => [module: ExVision.StyleTransfer.MosaicFast, gt_file: "cat_mosaic_fast.gt"], - } - end -end - -for {module, opts} <- Configuration.configuration() do - - defmodule module do - use ExVision.Model.Case, module: unquote(opts[:module]) - use ExVision.TestUtils - - @impl true - def test_inference_result(result) do - expected_result = "test/assets/results/style_transfer/#{unquote(opts[:gt_file])}" - |> File.read!() - |> Nx.deserialize() - assert_tensors_equal(result, expected_result) - end - end -end diff --git a/test/ex_vision/style_transfer/style_transfer_test.exs b/test/ex_vision/style_transfer/style_transfer_test.exs new file mode 100644 index 0000000..f0e4a33 --- /dev/null +++ b/test/ex_vision/style_transfer/style_transfer_test.exs @@ -0,0 +1,56 @@ +defmodule TestConfiguration do + @spec configuration() :: %{} + def configuration do + %{ + ExVision.StyleTransfer.CandyTest => [ + module: ExVision.StyleTransfer.Candy, + gt_file: "cat_candy.gt" + ], + ExVision.StyleTransfer.CandyFastTest => [ + module: ExVision.StyleTransfer.CandyFast, + gt_file: "cat_candy_fast.gt" + ], + ExVision.StyleTransfer.PrincessTest => [ + module: ExVision.StyleTransfer.Princess, + gt_file: "cat_princess.gt" + ], + ExVision.StyleTransfer.PrincessFastTest => [ + module: ExVision.StyleTransfer.PrincessFast, + gt_file: "cat_princess_fast.gt" + ], + ExVision.StyleTransfer.UdnieTest => [ + module: ExVision.StyleTransfer.Udnie, + gt_file: "cat_udnie.gt" + ], + ExVision.StyleTransfer.UdnieFastTest => [ + module: ExVision.StyleTransfer.UdnieFast, + gt_file: "cat_udnie_fast.gt" + ], + ExVision.StyleTransfer.MosaicTest => [ + module: ExVision.StyleTransfer.Mosaic, + gt_file: "cat_mosaic.gt" + ], + ExVision.StyleTransfer.MosaicFastTest => [ + module: ExVision.StyleTransfer.MosaicFast, + gt_file: "cat_mosaic_fast.gt" + ] + } + end +end + +for {module, opts} <- TestConfiguration.configuration() do + defmodule module do + use ExVision.Model.Case, module: unquote(opts[:module]) + use ExVision.TestUtils + + @impl true + def test_inference_result(result) do + expected_result = + "test/assets/results/style_transfer/#{unquote(opts[:gt_file])}" + |> File.read!() + |> Nx.deserialize() + + assert_tensors_equal(result, expected_result) + end + end +end diff --git a/test/support/exvision/test_utils.ex b/test/support/exvision/test_utils.ex index cac871a..206c296 100644 --- a/test/support/exvision/test_utils.ex +++ b/test/support/exvision/test_utils.ex @@ -44,7 +44,10 @@ defmodule ExVision.TestUtils do defmacro assert_tensors_equal(a, b, delta \\ @default_delta) do quote do - assert Nx.all_close(unquote(a), unquote(b), atol: unquote(delta)) |> Nx.reduce_min() |> Nx.to_number() == 1 + assert unquote(a) + |> Nx.all_close(unquote(b), atol: unquote(delta)) + |> Nx.reduce_min() + |> Nx.to_number() == 1 end end