Skip to content

Commit

Permalink
fix credo, fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
Mateusz Kopcinski authored and Mateusz Kopcinski committed Aug 20, 2024
1 parent 87972af commit fdb204a
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 55 deletions.
4 changes: 2 additions & 2 deletions lib/ex_vision/model/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ defmodule ExVision.Model.Definition do
Application.ensure_all_started(:req)

options =
Keyword.validate!(options, [
Keyword.validate!(options,
categories: nil,
name: module_to_name(__CALLER__.module)
])
)

quote do
# conditional defintion based on whether `categories` option is present has to be moved inside __using__ macro
Expand Down
1 change: 1 addition & 0 deletions lib/ex_vision/model/definition/parts/with_categories.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ defmodule ExVision.Model.Definition.Parts.WithCategories do

defmacro __using__(options) do
options = Keyword.validate!(options, [:name, :categories])

unless is_nil(options |> Keyword.fetch!(:categories)) do
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()
Expand Down
47 changes: 28 additions & 19 deletions lib/ex_vision/style_transfer/style_transfer.ex
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
defmodule Configuration do
@moduledoc false

@low_resolution {400,300}
@high_resolution {640,480}
@low_resolution {400, 300}
@high_resolution {640, 480}

@spec configuration() :: %{}
def configuration do
%{
ExVision.StyleTransfer.Candy => [model: "candy.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.CandyFast => [model: "candy_fast.onnx", resolution: @low_resolution],
ExVision.StyleTransfer.Princess => [model: "princess.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.PrincessFast => [model: "princess_fast.onnx", resolution: @low_resolution],
ExVision.StyleTransfer.Udnie => [model: "udnie.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.UdnieFast => [model: "udnie_fast.onnx", resolution: @low_resolution],
ExVision.StyleTransfer.Mosaic => [model: "mosaic.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.MosaicFast => [model: "mosaic_fast.onnx", resolution: @low_resolution],
ExVision.StyleTransfer.Candy => [model: "candy.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.CandyFast => [model: "candy_fast.onnx", resolution: @low_resolution],
ExVision.StyleTransfer.Princess => [model: "princess.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.PrincessFast => [
model: "princess_fast.onnx",
resolution: @low_resolution
],
ExVision.StyleTransfer.Udnie => [model: "udnie.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.UdnieFast => [model: "udnie_fast.onnx", resolution: @low_resolution],
ExVision.StyleTransfer.Mosaic => [model: "mosaic.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.MosaicFast => [
model: "mosaic_fast.onnx",
resolution: @low_resolution
]
}
end
end
Expand All @@ -22,15 +30,15 @@ for {module, opts} <- Configuration.configuration() do
@moduledoc """
#{module} is a custom style transfer model optimised for devices with low computational capabilities and CPU inference.
"""
use ExVision.Model.Definition.Ortex, model: unquote(opts[:model])

require Logger

@typedoc """
A type consisting of output tesnor (stylized image tensor) from style transfer models of shape {#{Enum.join(Tuple.to_list(opts[:resolution]) ++ [3], ", ")}}.
"""
@type output_t() :: Nx.Tensor.t()

use ExVision.Model.Definition.Ortex, model: unquote(opts[:model])

@impl true
def load(options \\ []) do
if Keyword.has_key?(options, :batch_size) do
Expand All @@ -54,14 +62,15 @@ for {module, opts} <- Configuration.configuration() do
stylized_frame,
metadata
) do
{h,w} = unquote(opts[:resolution])
{h, w} = unquote(opts[:resolution])

stylized_frame["55"]
|> Nx.reshape({3, h, w}, names: [:channel, :height, :width])
|> NxImage.resize(metadata.original_size, channels: :first)
|> Nx.max(0.0)
|> Nx.min(255.0)
|> Nx.as_type(:u8)
|> Nx.transpose(axes: [1, 2, 0])
|> Nx.reshape({3, h, w}, names: [:channel, :height, :width])
|> NxImage.resize(metadata.original_size, channels: :first)
|> Nx.max(0.0)
|> Nx.min(255.0)
|> Nx.as_type(:u8)
|> Nx.transpose(axes: [1, 2, 0])
end
end
end
6 changes: 3 additions & 3 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
"cc_precompiler": {:hex, :cc_precompiler, "0.1.10", "47c9c08d8869cf09b41da36538f62bc1abd3e19e41701c2cea2675b53c704258", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f6e046254e53cd6b41c6bacd70ae728011aa82b2742a80d6e2214855c6e06b22"},
"coerce": {:hex, :coerce, "1.0.1", "211c27386315dc2894ac11bc1f413a0e38505d808153367bd5c6e75a4003d096", [:mix], [], "hexpm", "b44a691700f7a1a15b4b7e2ff1fa30bebd669929ac8aa43cffe9e2f8bf051cf1"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"credo": {:hex, :credo, "1.7.5", "643213503b1c766ec0496d828c90c424471ea54da77c8a168c725686377b9545", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "f799e9b5cd1891577d8c773d245668aa74a2fcd15eb277f51a0131690ebfb3fd"},
"credo": {:hex, :credo, "1.7.7", "771445037228f763f9b2afd612b6aa2fd8e28432a95dbbc60d8e03ce71ba4446", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "8bc87496c9aaacdc3f90f01b7b0582467b69b4bd2441fe8aae3109d843cc2f2e"},
"dialyxir": {:hex, :dialyxir, "1.4.3", "edd0124f358f0b9e95bfe53a9fcf806d615d8f838e2202a9f430d59566b6b53b", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "bf2cfb75cd5c5006bec30141b131663299c661a864ec7fbbc72dfa557487a986"},
"earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
"erlex": {:hex, :erlex, "0.2.6", "c7987d15e899c7a2f34f5420d2a2ea0d659682c06ac607572df55a43753aa12e", [:mix], [], "hexpm", "2ed2e25711feb44d52b17d2780eabf998452f6efda104877a3881c2f8c0c0c75"},
"evision": {:hex, :evision, "0.1.38", "f8b23ad685c3ebd70969a3457027b5c74b5bc8dc51588661c516098c3240b92d", [:make, :mix, :rebar3], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}, {:progress_bar, "~> 2.0 or ~> 3.0", [hex: :progress_bar, repo: "hexpm", optional: true]}], "hexpm", "f9302547d76c5e4ad7022ffdc76be13e33c990fdd67ad2af203f24ab5d3aee20"},
"ex_doc": {:hex, :ex_doc, "0.32.1", "21e40f939515373bcdc9cffe65f3b3543f05015ac6c3d01d991874129d173420", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "5142c9db521f106d61ff33250f779807ed2a88620e472ac95dc7d59c380113da"},
"exla": {:hex, :exla, "0.7.2", "8ac573093df8e5e6b36845beeb3f5a0ea92b05082bf2fa4678f80170cfc887f6", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d061ea87858415e5585cbd4b7bdae5489000339519a2c6a7f51eb0defd73b588"},
"file_system": {:hex, :file_system, "1.0.0", "b689cc7dcee665f774de94b5a832e578bd7963c8e637ef940cd44327db7de2cd", [:mix], [], "hexpm", "6752092d66aec5a10e662aefeed8ddb9531d79db0bc145bb8c40325ca1d8536d"},
"file_system": {:hex, :file_system, "1.0.1", "79e8ceaddb0416f8b8cd02a0127bdbababe7bf4a23d2a395b983c1f8b3f73edd", [:mix], [], "hexpm", "4414d1f38863ddf9120720cd976fce5bdde8e91d8283353f0e31850fa89feb9e"},
"finch": {:hex, :finch, "0.18.0", "944ac7d34d0bd2ac8998f79f7a811b21d87d911e77a786bc5810adb75632ada4", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "69f5045b042e531e53edc2574f15e25e735b522c37e2ddb766e15b979e03aa65"},
"hpax": {:hex, :hpax, "0.2.0", "5a58219adcb75977b2edce5eb22051de9362f08236220c9e859a47111c194ff5", [:mix], [], "hexpm", "bea06558cdae85bed075e6c036993d43cd54d447f76d8190a8db0dc5893fa2f1"},
"image": {:hex, :image, "0.44.0", "e8eea9398abbed12b7784e786f26a5c839a00bcddd8f2f8ba12adf7e227beb9f", [:mix], [{:bumblebee, "~> 0.3", [hex: :bumblebee, repo: "hexpm", optional: true]}, {:evision, "~> 0.1.33", [hex: :evision, repo: "hexpm", optional: true]}, {:exla, "~> 0.5", [hex: :exla, repo: "hexpm", optional: true]}, {:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: true]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: true]}, {:nx_image, "~> 0.1", [hex: :nx_image, repo: "hexpm", optional: true]}, {:phoenix_html, "~> 2.1 or ~> 3.2 or ~> 4.0", [hex: :phoenix_html, repo: "hexpm", optional: false]}, {:plug, "~> 1.13", [hex: :plug, repo: "hexpm", optional: true]}, {:req, "~> 0.4", [hex: :req, repo: "hexpm", optional: true]}, {:rustler, "> 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:sweet_xml, "~> 0.7", [hex: :sweet_xml, repo: "hexpm", optional: false]}, {:vix, "~> 0.23", [hex: :vix, repo: "hexpm", optional: false]}], "hexpm", "cd00a3de4d7a40a2cb1ca72b9852b0d81701793414af8babf4d33dbeb6de0f6f"},
"jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"},
"jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"},
"makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.5", "e0ff5a7c708dda34311f7522a8758e23bfcd7d8d8068dc312b5eb41c6fd76eba", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "94d2e986428585a21516d7d7149781480013c56e30c6a233534bedf38867a59a"},
Expand Down
30 changes: 0 additions & 30 deletions test/ex_vision/style_transfer/candy_test.exs

This file was deleted.

56 changes: 56 additions & 0 deletions test/ex_vision/style_transfer/style_transfer_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
defmodule TestConfiguration do
@spec configuration() :: %{}
def configuration do
%{
ExVision.StyleTransfer.CandyTest => [
module: ExVision.StyleTransfer.Candy,
gt_file: "cat_candy.gt"
],
ExVision.StyleTransfer.CandyFastTest => [
module: ExVision.StyleTransfer.CandyFast,
gt_file: "cat_candy_fast.gt"
],
ExVision.StyleTransfer.PrincessTest => [
module: ExVision.StyleTransfer.Princess,
gt_file: "cat_princess.gt"
],
ExVision.StyleTransfer.PrincessFastTest => [
module: ExVision.StyleTransfer.PrincessFast,
gt_file: "cat_princess_fast.gt"
],
ExVision.StyleTransfer.UdnieTest => [
module: ExVision.StyleTransfer.Udnie,
gt_file: "cat_udnie.gt"
],
ExVision.StyleTransfer.UdnieFastTest => [
module: ExVision.StyleTransfer.UdnieFast,
gt_file: "cat_udnie_fast.gt"
],
ExVision.StyleTransfer.MosaicTest => [
module: ExVision.StyleTransfer.Mosaic,
gt_file: "cat_mosaic.gt"
],
ExVision.StyleTransfer.MosaicFastTest => [
module: ExVision.StyleTransfer.MosaicFast,
gt_file: "cat_mosaic_fast.gt"
]
}
end
end

for {module, opts} <- TestConfiguration.configuration() do
defmodule module do
use ExVision.Model.Case, module: unquote(opts[:module])

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test inference (ExVision.StyleTransfer.PrincessTest)

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test inference (ExVision.StyleTransfer.CandyFastTest)

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test inference (ExVision.StyleTransfer.CandyTest)

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test inference for batch (ExVision.StyleTransfer.CandyFastTest)

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test inference for batch (ExVision.StyleTransfer.PrincessTest)

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test stateful/process workflow inference (ExVision.StyleTransfer.CandyFastTest)

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test inference for batch (ExVision.StyleTransfer.CandyTest)

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test stateful/process workflow inference (ExVision.StyleTransfer.PrincessTest)

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test stateful/process workflow inference for batch (ExVision.StyleTransfer.CandyFastTest)

Check failure on line 43 in test/ex_vision/style_transfer/style_transfer_test.exs

View workflow job for this annotation

GitHub Actions / Build and test

test stateful/process workflow inference (ExVision.StyleTransfer.CandyTest)
use ExVision.TestUtils

@impl true
def test_inference_result(result) do
expected_result =
"test/assets/results/style_transfer/#{unquote(opts[:gt_file])}"
|> File.read!()
|> Nx.deserialize()

assert_tensors_equal(result, expected_result)
end
end
end
5 changes: 4 additions & 1 deletion test/support/exvision/test_utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ defmodule ExVision.TestUtils do

defmacro assert_tensors_equal(a, b, delta \\ @default_delta) do
quote do
assert Nx.all_close(unquote(a), unquote(b), atol: unquote(delta)) |> Nx.reduce_min() |> Nx.to_number() == 1
assert unquote(a)
|> Nx.all_close(unquote(b), atol: unquote(delta))
|> Nx.reduce_min()
|> Nx.to_number() == 1
end
end

Expand Down

0 comments on commit fdb204a

Please sign in to comment.