Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Style transfer #16

Merged
merged 15 commits into from
Aug 21, 2024
8 changes: 4 additions & 4 deletions lib/ex_vision/model/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ defmodule ExVision.Model.Definition do

options =
Keyword.validate!(options, [
:categories,
categories: nil,
name: module_to_name(__CALLER__.module)
])

quote do
unless is_nil(unquote(options[:categories])) do
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)
end
# conditional defintion based on whether `categories` option is present has to be moved inside __using__ macro
# here is explenation https://cocoa-research.works/2022/10/conditional-compliation-with-if-and-use-in-elixir/
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)

@behaviour ExVision.Model.Definition

Expand Down
28 changes: 15 additions & 13 deletions lib/ex_vision/model/definition/parts/with_categories.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@ defmodule ExVision.Model.Definition.Parts.WithCategories do

defmacro __using__(options) do
options = Keyword.validate!(options, [:name, :categories])
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()
unless is_nil(options |> Keyword.fetch!(:categories)) do
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
require Bunch.Typespec
quote do
require Bunch.Typespec

@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)
@typedoc """
Type describing all categories recognised by #{unquote(options[:name])}
"""
@type category_t() :: unquote(spec)

@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
@doc """
Returns a list of all categories recognised by #{unquote(options[:name])}
"""
@spec categories() :: [category_t()]
def categories(), do: unquote(categories)
end
end
end
end
56 changes: 56 additions & 0 deletions lib/ex_vision/style_transfer/style_transfer.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
defmodule Configuration do

Check warning on line 1 in lib/ex_vision/style_transfer/style_transfer.ex

View workflow job for this annotation

GitHub Actions / Build and test

Modules should have a @moduledoc tag.
mkopcins marked this conversation as resolved.
Show resolved Hide resolved
@low_resolution {400,300}

Check warning on line 2 in lib/ex_vision/style_transfer/style_transfer.ex

View workflow job for this annotation

GitHub Actions / Build and test

Space missing after comma.
@high_resolution {640,480}

Check warning on line 3 in lib/ex_vision/style_transfer/style_transfer.ex

View workflow job for this annotation

GitHub Actions / Build and test

Space missing after comma.
def configuration do

Check warning on line 4 in lib/ex_vision/style_transfer/style_transfer.ex

View workflow job for this annotation

GitHub Actions / Build and test

Functions should have a @SPEC type specification.
%{
ExVision.StyleTransfer.Candy => [model: "candy.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.CandyFast => [model: "candy_fast.onnx", resolution: @low_resolution],
ExVision.StyleTransfer.Princess => [model: "princess.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.PrincessFast => [model: "princess_fast.onnx", resolution: @low_resolution],
ExVision.StyleTransfer.Udnie => [model: "udnie.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.UdnieFast => [model: "udnie_fast.onnx", resolution: @low_resolution],
ExVision.StyleTransfer.Mosaic => [model: "mosaic.onnx", resolution: @high_resolution],
ExVision.StyleTransfer.MosaicFast => [model: "mosaic_fast.onnx", resolution: @low_resolution],
}
end
end

for {module, opts} <- Configuration.configuration() do
defmodule module do
@moduledoc """
#{module} is a custom style transfer model optimised for devices with low computational capabilities and CPU inference.
"""
require Logger
@type output_t() :: [Nx.Tensor.t()]
mkopcins marked this conversation as resolved.
Show resolved Hide resolved

use ExVision.Model.Definition.Ortex, model: unquote(opts[:model])

Check warning on line 26 in lib/ex_vision/style_transfer/style_transfer.ex

View workflow job for this annotation

GitHub Actions / Build and test

use must appear before type

@impl true
def load(options \\ []) do
if Keyword.has_key?(options, :batch_size) do
Logger.warning(
"`:max_batch_size` was given, but this model can only process batch of size 1. Overriding"
)
end

options
|> Keyword.put(:batch_size, 1)
|> default_model_load()
end

@impl true
def preprocessing(img, _metdata) do
ExVision.Utils.resize(img, unquote(opts[:resolution])) |> Nx.divide(255.0)
mkopcins marked this conversation as resolved.
Show resolved Hide resolved
end

@impl true
def postprocessing(
stylized_frame,
metadata
) do

stylized_frame = stylized_frame["55"]
NxImage.resize(stylized_frame, metadata.original_size, channels: :first)
end
end
end
9 changes: 9 additions & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ defmodule ExVision.Mixfile do
ExVision.Classification.EfficientNet_V2_L,
ExVision.Classification.SqueezeNet1_1,
ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3,
ExVision.StyleTransfer.Candy,
ExVision.StyleTransfer.CandyFast,
ExVision.StyleTransfer.Udnie,
ExVision.StyleTransfer.UdnieFast,
ExVision.StyleTransfer.Mosaic,
ExVision.StyleTransfer.MosaicFast,
ExVision.StyleTransfer.Princess,
ExVision.StyleTransfer.PrincessFast,
ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2,
ExVision.ObjectDetection.Ssdlite320_MobileNetv3,
ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN,
Expand All @@ -123,6 +131,7 @@ defmodule ExVision.Mixfile do
ExVision.Types,
ExVision.Classification,
ExVision.SemanticSegmentation,
ExVision.StyleTransfer,
ExVision.InstanceSegmentation,
ExVision.ObjectDetection,
ExVision.KeypointDetection
Expand Down
Loading