Skip to content

Commit

Permalink
Style transfer (#16)
Browse files Browse the repository at this point in the history
Co-authored-by: Mateusz Kopcinski <[email protected]>
  • Loading branch information
mkopcins and Mateusz Kopcinski authored Aug 21, 2024
1 parent ef8f3d4 commit 8b73a2b
Show file tree
Hide file tree
Showing 15 changed files with 187 additions and 22 deletions.
12 changes: 6 additions & 6 deletions lib/ex_vision/model/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ defmodule ExVision.Model.Definition do
Application.ensure_all_started(:req)

options =
Keyword.validate!(options, [
:categories,
Keyword.validate!(options,
categories: nil,
name: module_to_name(__CALLER__.module)
])
)

quote do
unless is_nil(unquote(options[:categories])) do
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)
end
# The conditional definition based on whether the `categories` option is present has to be moved inside the `__using__` macro.
# Here is an explanation: https://cocoa-research.works/2022/10/conditional-compliation-with-if-and-use-in-elixir/
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)

@behaviour ExVision.Model.Definition

Expand Down
29 changes: 16 additions & 13 deletions lib/ex_vision/model/definition/parts/with_categories.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@ defmodule ExVision.Model.Definition.Parts.WithCategories do

defmacro __using__(options) do
options = Keyword.validate!(options, [:name, :categories])
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
require Bunch.Typespec
unless is_nil(options |> Keyword.fetch!(:categories)) do
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)
quote do
require Bunch.Typespec

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
end
end
end
end
75 changes: 75 additions & 0 deletions lib/ex_vision/style_transfer/style_transfer.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
defmodule Configuration do
  @moduledoc false

  # Resolutions used by the `*Fast` variants and by the regular variants respectively.
  @low_resolution {400, 300}
  @high_resolution {640, 480}

  @doc """
  Returns the compile-time configuration of every style transfer model.

  Keys are the modules to be generated; each value is a keyword list holding
  the `:model` file name and the `:resolution` the model operates at.
  """
  # `%{}` in a typespec denotes the *empty* map, so the mapping is spelled out explicitly.
  @spec configuration() :: %{
          module() => [model: String.t(), resolution: {pos_integer(), pos_integer()}]
        }
  def configuration do
    %{
      ExVision.StyleTransfer.Candy => [model: "candy.onnx", resolution: @high_resolution],
      ExVision.StyleTransfer.CandyFast => [model: "candy_fast.onnx", resolution: @low_resolution],
      ExVision.StyleTransfer.Princess => [model: "princess.onnx", resolution: @high_resolution],
      ExVision.StyleTransfer.PrincessFast => [
        model: "princess_fast.onnx",
        resolution: @low_resolution
      ],
      ExVision.StyleTransfer.Udnie => [model: "udnie.onnx", resolution: @high_resolution],
      ExVision.StyleTransfer.UdnieFast => [model: "udnie_fast.onnx", resolution: @low_resolution],
      ExVision.StyleTransfer.Mosaic => [model: "mosaic.onnx", resolution: @high_resolution],
      ExVision.StyleTransfer.MosaicFast => [
        model: "mosaic_fast.onnx",
        resolution: @low_resolution
      ]
    }
  end
end

# Generates one style transfer model module per entry of `Configuration.configuration/0`.
for {module, opts} <- Configuration.configuration() do
  defmodule module do
    # `inspect/1` renders the module name without the internal `Elixir.` prefix.
    @moduledoc """
    #{inspect(module)} is a custom style transfer model optimised for devices with low computational capabilities and CPU inference.
    """
    use ExVision.Model.Definition.Ortex, model: unquote(opts[:model])

    require Logger

    # NOTE(review): `postprocessing/2` resizes the result to `metadata.original_size`,
    # so the final shape may differ from the one stated below — confirm.
    @typedoc """
    A type consisting of output tensor (stylized image tensor) from style transfer models of shape {#{Enum.join(Tuple.to_list(opts[:resolution]) ++ [3], ", ")}}.
    """
    @type output_t() :: Nx.Tensor.t()

    # Loads the model, always forcing `batch_size: 1`.
    # A caller-supplied `:batch_size` is overridden and a warning is logged.
    @impl true
    def load(options \\ []) do
      if Keyword.has_key?(options, :batch_size) do
        # The message names the option that is actually checked and overridden.
        Logger.warning(
          "`:batch_size` was given, but this model can only process batch of size 1. Overriding"
        )
      end

      options
      |> Keyword.put(:batch_size, 1)
      |> default_model_load()
    end

    # Resizes the image to the configured resolution and divides every value by 255.0.
    @impl true
    def preprocessing(img, _metadata) do
      img |> ExVision.Utils.resize(unquote(opts[:resolution])) |> Nx.divide(255.0)
    end

    # Turns the raw model output into a `u8` image in channels-last layout,
    # resized to `metadata.original_size`.
    @impl true
    def postprocessing(
          stylized_frame,
          metadata
        ) do
      {h, w} = unquote(opts[:resolution])

      # "55" is presumably the name of the model's output node — TODO confirm against the model files.
      stylized_frame["55"]
      |> Nx.reshape({3, h, w}, names: [:channel, :height, :width])
      |> NxImage.resize(metadata.original_size, channels: :first, method: :bilinear)
      # Clamp before the cast so the values fit the `u8` range.
      |> Nx.clip(0.0, 255.0)
      |> Nx.as_type(:u8)
      # channels-first -> channels-last
      |> Nx.transpose(axes: [1, 2, 0])
    end
  end
end
9 changes: 9 additions & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ defmodule ExVision.Mixfile do
ExVision.Classification.EfficientNet_V2_L,
ExVision.Classification.SqueezeNet1_1,
ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3,
ExVision.StyleTransfer.Candy,
ExVision.StyleTransfer.CandyFast,
ExVision.StyleTransfer.Udnie,
ExVision.StyleTransfer.UdnieFast,
ExVision.StyleTransfer.Mosaic,
ExVision.StyleTransfer.MosaicFast,
ExVision.StyleTransfer.Princess,
ExVision.StyleTransfer.PrincessFast,
ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2,
ExVision.ObjectDetection.Ssdlite320_MobileNetv3,
ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN,
Expand All @@ -123,6 +131,7 @@ defmodule ExVision.Mixfile do
ExVision.Types,
ExVision.Classification,
ExVision.SemanticSegmentation,
ExVision.StyleTransfer,
ExVision.InstanceSegmentation,
ExVision.ObjectDetection,
ExVision.KeypointDetection
Expand Down
6 changes: 3 additions & 3 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
"cc_precompiler": {:hex, :cc_precompiler, "0.1.10", "47c9c08d8869cf09b41da36538f62bc1abd3e19e41701c2cea2675b53c704258", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f6e046254e53cd6b41c6bacd70ae728011aa82b2742a80d6e2214855c6e06b22"},
"coerce": {:hex, :coerce, "1.0.1", "211c27386315dc2894ac11bc1f413a0e38505d808153367bd5c6e75a4003d096", [:mix], [], "hexpm", "b44a691700f7a1a15b4b7e2ff1fa30bebd669929ac8aa43cffe9e2f8bf051cf1"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"credo": {:hex, :credo, "1.7.5", "643213503b1c766ec0496d828c90c424471ea54da77c8a168c725686377b9545", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "f799e9b5cd1891577d8c773d245668aa74a2fcd15eb277f51a0131690ebfb3fd"},
"credo": {:hex, :credo, "1.7.7", "771445037228f763f9b2afd612b6aa2fd8e28432a95dbbc60d8e03ce71ba4446", [:mix], [{:bunt, "~> 0.2.1 or ~> 1.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2 or ~> 1.0", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "8bc87496c9aaacdc3f90f01b7b0582467b69b4bd2441fe8aae3109d843cc2f2e"},
"dialyxir": {:hex, :dialyxir, "1.4.3", "edd0124f358f0b9e95bfe53a9fcf806d615d8f838e2202a9f430d59566b6b53b", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "bf2cfb75cd5c5006bec30141b131663299c661a864ec7fbbc72dfa557487a986"},
"earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
"erlex": {:hex, :erlex, "0.2.6", "c7987d15e899c7a2f34f5420d2a2ea0d659682c06ac607572df55a43753aa12e", [:mix], [], "hexpm", "2ed2e25711feb44d52b17d2780eabf998452f6efda104877a3881c2f8c0c0c75"},
"evision": {:hex, :evision, "0.1.38", "f8b23ad685c3ebd70969a3457027b5c74b5bc8dc51588661c516098c3240b92d", [:make, :mix, :rebar3], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}, {:progress_bar, "~> 2.0 or ~> 3.0", [hex: :progress_bar, repo: "hexpm", optional: true]}], "hexpm", "f9302547d76c5e4ad7022ffdc76be13e33c990fdd67ad2af203f24ab5d3aee20"},
"ex_doc": {:hex, :ex_doc, "0.32.1", "21e40f939515373bcdc9cffe65f3b3543f05015ac6c3d01d991874129d173420", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "5142c9db521f106d61ff33250f779807ed2a88620e472ac95dc7d59c380113da"},
"exla": {:hex, :exla, "0.7.2", "8ac573093df8e5e6b36845beeb3f5a0ea92b05082bf2fa4678f80170cfc887f6", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d061ea87858415e5585cbd4b7bdae5489000339519a2c6a7f51eb0defd73b588"},
"file_system": {:hex, :file_system, "1.0.0", "b689cc7dcee665f774de94b5a832e578bd7963c8e637ef940cd44327db7de2cd", [:mix], [], "hexpm", "6752092d66aec5a10e662aefeed8ddb9531d79db0bc145bb8c40325ca1d8536d"},
"file_system": {:hex, :file_system, "1.0.1", "79e8ceaddb0416f8b8cd02a0127bdbababe7bf4a23d2a395b983c1f8b3f73edd", [:mix], [], "hexpm", "4414d1f38863ddf9120720cd976fce5bdde8e91d8283353f0e31850fa89feb9e"},
"finch": {:hex, :finch, "0.18.0", "944ac7d34d0bd2ac8998f79f7a811b21d87d911e77a786bc5810adb75632ada4", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "69f5045b042e531e53edc2574f15e25e735b522c37e2ddb766e15b979e03aa65"},
"hpax": {:hex, :hpax, "0.2.0", "5a58219adcb75977b2edce5eb22051de9362f08236220c9e859a47111c194ff5", [:mix], [], "hexpm", "bea06558cdae85bed075e6c036993d43cd54d447f76d8190a8db0dc5893fa2f1"},
"image": {:hex, :image, "0.44.0", "e8eea9398abbed12b7784e786f26a5c839a00bcddd8f2f8ba12adf7e227beb9f", [:mix], [{:bumblebee, "~> 0.3", [hex: :bumblebee, repo: "hexpm", optional: true]}, {:evision, "~> 0.1.33", [hex: :evision, repo: "hexpm", optional: true]}, {:exla, "~> 0.5", [hex: :exla, repo: "hexpm", optional: true]}, {:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: true]}, {:kino, "~> 0.11", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: true]}, {:nx_image, "~> 0.1", [hex: :nx_image, repo: "hexpm", optional: true]}, {:phoenix_html, "~> 2.1 or ~> 3.2 or ~> 4.0", [hex: :phoenix_html, repo: "hexpm", optional: false]}, {:plug, "~> 1.13", [hex: :plug, repo: "hexpm", optional: true]}, {:req, "~> 0.4", [hex: :req, repo: "hexpm", optional: true]}, {:rustler, "> 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:sweet_xml, "~> 0.7", [hex: :sweet_xml, repo: "hexpm", optional: false]}, {:vix, "~> 0.23", [hex: :vix, repo: "hexpm", optional: false]}], "hexpm", "cd00a3de4d7a40a2cb1ca72b9852b0d81701793414af8babf4d33dbeb6de0f6f"},
"jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"},
"jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"},
"makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.5", "e0ff5a7c708dda34311f7522a8758e23bfcd7d8d8068dc312b5eb41c6fd76eba", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "94d2e986428585a21516d7d7149781480013c56e30c6a233534bedf38867a59a"},
Expand Down
Binary file added test/assets/results/style_transfer/cat_candy.gt
Binary file not shown.
Binary file not shown.
Binary file added test/assets/results/style_transfer/cat_mosaic.gt
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added test/assets/results/style_transfer/cat_udnie.gt
Binary file not shown.
Binary file not shown.
56 changes: 56 additions & 0 deletions test/ex_vision/style_transfer/style_transfer_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
defmodule TestConfiguration do
  @moduledoc false

  @doc """
  Returns the configuration of the generated style transfer test modules.

  Keys are the test modules to be generated; each value is a keyword list with
  the `:module` under test and the `:gt_file` holding the expected result.
  """
  # `%{}` in a typespec denotes the *empty* map, so the mapping is spelled out explicitly.
  @spec configuration() :: %{module() => [module: module(), gt_file: String.t()]}
  def configuration do
    %{
      ExVision.StyleTransfer.CandyTest => [
        module: ExVision.StyleTransfer.Candy,
        gt_file: "cat_candy.gt"
      ],
      ExVision.StyleTransfer.CandyFastTest => [
        module: ExVision.StyleTransfer.CandyFast,
        gt_file: "cat_candy_fast.gt"
      ],
      ExVision.StyleTransfer.PrincessTest => [
        module: ExVision.StyleTransfer.Princess,
        gt_file: "cat_princess.gt"
      ],
      ExVision.StyleTransfer.PrincessFastTest => [
        module: ExVision.StyleTransfer.PrincessFast,
        gt_file: "cat_princess_fast.gt"
      ],
      ExVision.StyleTransfer.UdnieTest => [
        module: ExVision.StyleTransfer.Udnie,
        gt_file: "cat_udnie.gt"
      ],
      ExVision.StyleTransfer.UdnieFastTest => [
        module: ExVision.StyleTransfer.UdnieFast,
        gt_file: "cat_udnie_fast.gt"
      ],
      ExVision.StyleTransfer.MosaicTest => [
        module: ExVision.StyleTransfer.Mosaic,
        gt_file: "cat_mosaic.gt"
      ],
      ExVision.StyleTransfer.MosaicFastTest => [
        module: ExVision.StyleTransfer.MosaicFast,
        gt_file: "cat_mosaic_fast.gt"
      ]
    }
  end
end

# Generates one test module per entry of `TestConfiguration.configuration/0`.
for {module, opts} <- TestConfiguration.configuration() do
  defmodule module do
    use ExVision.Model.Case, module: unquote(opts[:module])
    use ExVision.TestUtils

    # Compares the inference result with the serialized tensor stored in the ground-truth file.
    @impl true
    def test_inference_result(result) do
      ground_truth_path = "test/assets/results/style_transfer/#{unquote(opts[:gt_file])}"
      serialized = File.read!(ground_truth_path)
      ground_truth = Nx.deserialize(serialized)

      assert_tensors_equal(result, ground_truth, 5, 0.05)
    end
  end
end
22 changes: 22 additions & 0 deletions test/support/exvision/test_utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ defmodule ExVision.TestUtils do
end
end

@doc """
Asserts that tensors `a` and `b` are approximately equal.

The assertion passes when either:

  * every element of `a` is close to the matching element of `b`, using `delta`
    as the absolute and `relative_delta` as the relative tolerance, or
  * more than 99% of the elements are exactly equal.
"""
defmacro assert_tensors_equal(a, b, delta \\ @default_delta, relative_delta \\ 0.0) do
  quote do
    # Bind each argument once: the expressions passed by the caller would
    # otherwise be re-evaluated at every place they are spliced in.
    actual = unquote(a)
    expected = unquote(b)

    # `Nx.all_close/3` already reduces to a single 0/1 value.
    value_condition =
      actual
      |> Nx.all_close(expected, atol: unquote(delta), rtol: unquote(relative_delta))
      |> Nx.to_number() == 1

    # Number of positions at which both tensors hold exactly the same value.
    equal_on_count =
      actual
      |> Nx.equal(expected)
      |> Nx.as_type(:u64)
      |> Nx.sum()
      |> Nx.to_number()

    number_count = Nx.size(actual)
    proportional_condition = equal_on_count / number_count > 0.99

    assert value_condition or proportional_condition,
           "tensors are not equal: #{equal_on_count} of #{number_count} elements match exactly " <>
             "and not all elements are within the given tolerance"
  end
end

defmacro __using__(_opts) do
quote do
import ExVision.TestUtils, only: :macros
Expand Down

0 comments on commit 8b73a2b

Please sign in to comment.