From 97a0659ae00d4f22ddfb6e6016f1e9abb9315166 Mon Sep 17 00:00:00 2001 From: Mateusz Kopcinski Date: Wed, 14 Aug 2024 10:24:41 +0200 Subject: [PATCH] added style transfer models --- lib/ex_vision/style_transfer/candy.ex | 57 ------------------- .../style_transfer/style_transfer.ex | 53 +++++++++++++++++ 2 files changed, 53 insertions(+), 57 deletions(-) delete mode 100644 lib/ex_vision/style_transfer/candy.ex create mode 100644 lib/ex_vision/style_transfer/style_transfer.ex diff --git a/lib/ex_vision/style_transfer/candy.ex b/lib/ex_vision/style_transfer/candy.ex deleted file mode 100644 index 71baf7d..0000000 --- a/lib/ex_vision/style_transfer/candy.ex +++ /dev/null @@ -1,57 +0,0 @@ -defmodule ExVision.StyleTransfer.Candy do - @moduledoc """ - An instance segmentation model with a ResNet-50-FPN backbone. Exported from torchvision. - """ - use ExVision.Model.Definition.Ortex, - # model: "udnie.onnx", - model: "udnie.onnx", - categories: "priv/categories/coco_categories.json" - - import ExVision.Utils - - require Logger - - alias ExVision.Types.BBoxWithMask - - @type output_t() :: [BBoxWithMask.t()] - - @impl true - def load(options \\ []) do - if Keyword.has_key?(options, :batch_size) do - Logger.warning( - "`:max_batch_size` was given, but this model can only process batch of size 1. Overriding" - ) - end - - options - |> Keyword.put(:batch_size, 1) - |> default_model_load() - end - - @impl true - def preprocessing(img, _metdata) do - # {ExVision.Utils.resize(img, {640, 480}), Nx.tensor([1.0,1.0,1.0,1.0], type: :f32)} - Logger.info(ExVision.Utils.resize(img, {640, 480})) - - ExVision.Utils.resize(img, {640, 480}) |> Nx.divide(255.0) - end - - @impl true - def postprocessing( - stylized_frame, - metadata - ) do - categories = categories() - - {h, w} = metadata.original_size - scale_x = w / 640 - scale_y = h / 480 - - stylized_frame - end - defp clamp(tensor) do - tensor - |> Nx.max(0.0) - |> Nx.min(255.0) - end - end diff --git a/lib/ex_vision/style_transfer/style_transfer.ex b/lib/ex_vision/style_transfer/style_transfer.ex new file mode 100644 index 0000000..51e6105 --- /dev/null +++ b/lib/ex_vision/style_transfer/style_transfer.ex @@ -0,0 +1,53 @@ +defmodule Configuration do + @low_resolution {400,300} + @high_resolution {400,300} + def configuration do + %{ + ExVision.StyleTransfer.Candy => [model: "candy.onnx", resolution: @high_resolution, categories: "priv/categories/coco_categories.json"], + ExVision.StyleTransfer.CandyFast => [model: "candy_fast.onnx", resolution: @low_resolution, categories: "priv/categories/coco_categories.json"], + ExVision.StyleTransfer.Princess => [model: "princess.onnx", resolution: @high_resolution, categories: "priv/categories/coco_categories.json"], + ExVision.StyleTransfer.PrincessFast => [model: "princess_fast.onnx", resolution: @low_resolution, categories: "priv/categories/coco_categories.json"], + ExVision.StyleTransfer.Udnie => [model: "udnie.onnx", resolution: @high_resolution, categories: "priv/categories/coco_categories.json"], + ExVision.StyleTransfer.UdnieFast => [model: "udnie_fast.onnx", resolution: @low_resolution, categories: "priv/categories/coco_categories.json"], + ExVision.StyleTransfer.Mosaic => [model: "mosaic.onnx", resolution: @high_resolution, categories: "priv/categories/coco_categories.json"], + ExVision.StyleTransfer.MosaicFast => [model: "mosaic_fast.onnx", resolution: @low_resolution, categories: "priv/categories/coco_categories.json"], + } + end +end + +for {module, opts} <- Configuration.configuration() do + defmodule module do + require Logger + @type output_t() :: [Nx.Tensor.t()] + + use ExVision.Model.Definition.Ortex, model: unquote(opts[:model]), categories: "priv/categories/coco_categories.json" + + @impl true + def load(options \\ []) do + if Keyword.has_key?(options, :batch_size) do + Logger.warning( + "`:max_batch_size` was given, but this model can only process batch of size 1. Overriding" + ) + end + + options + |> Keyword.put(:batch_size, 1) + |> default_model_load() + end + + @impl true + def preprocessing(img, _metdata) do + ExVision.Utils.resize(img, unquote(opts[:resolution])) |> Nx.divide(255.0) + end + + @impl true + def postprocessing( + stylized_frame, + metadata + ) do + + stylized_frame = stylized_frame["55"] + NxImage.resize(stylized_frame, metadata.original_size, channels: :first) + end + end +end