diff --git a/lib/ex_vision/model/definition.ex b/lib/ex_vision/model/definition.ex index 4cda716..65f1b70 100644 --- a/lib/ex_vision/model/definition.ex +++ b/lib/ex_vision/model/definition.ex @@ -31,7 +31,6 @@ defmodule ExVision.Model.Definition do ]) quote do - #todo fix unless is_nil(unquote(options[:categories])) do use ExVision.Model.Definition.Parts.WithCategories, unquote(options) end diff --git a/lib/ex_vision/model/definition/ortex.ex b/lib/ex_vision/model/definition/ortex.ex index 532a934..56884c0 100644 --- a/lib/ex_vision/model/definition/ortex.ex +++ b/lib/ex_vision/model/definition/ortex.ex @@ -45,14 +45,7 @@ defmodule ExVision.Model.Definition.Ortex do defmacrop get_client_preprocessing(module) do quote do - # input_preprocessing = fn input -> fn input -> - Logger.info("IO.inspect(input)") - # images = case input do - # {_input, sth} -> ExVision.Utils.load_image(_input) - # _input -> ExVision.Utils.load_image(_input) - # end - images = ExVision.Utils.load_image(input) metadata = @@ -63,22 +56,14 @@ defmodule ExVision.Model.Definition.Ortex do } ) - Logger.info(images) batch = images |> Enum.zip(metadata) |> Enum.map(fn {image, metadata} -> unquote(module).preprocessing(image, metadata) end) |> Nx.Batch.stack() - # batch = batch |> Nx.Batch.stack() - Logger.info(batch) {batch, metadata} end - - # fn - # {input, extra_fields} -> {unquote(input_preprocessing)(input), extra_fields} - # {input} -> unquote(input_preprocessing)(input) - # end end end diff --git a/lib/ex_vision/semantic_segmentation/deep_lab_v3_mobilenet_v3.ex b/lib/ex_vision/semantic_segmentation/deep_lab_v3_mobilenet_v3.ex index dfb232b..4afe39b 100644 --- a/lib/ex_vision/semantic_segmentation/deep_lab_v3_mobilenet_v3.ex +++ b/lib/ex_vision/semantic_segmentation/deep_lab_v3_mobilenet_v3.ex @@ -1,50 +1,32 @@ defmodule ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3 do - @moduledoc """ - An instance segmentation model with a ResNet-50-FPN backbone. Exported from torchvision. - """ - use ExVision.Model.Definition.Ortex, - # model: "udnie.onnx", - model: "udnie.onnx", - categories: "priv/categories/coco_categories.json" - - import ExVision.Utils - - require Logger - - alias ExVision.Types.BBoxWithMask - - @type output_t() :: [BBoxWithMask.t()] - - @impl true - def load(options \\ []) do - if Keyword.has_key?(options, :batch_size) do - Logger.warning( - "`:max_batch_size` was given, but this model can only process batch of size 1. Overriding" - ) - end - - options - |> Keyword.put(:batch_size, 1) - |> default_model_load() - end - - @impl true - def preprocessing(img, _metdata) do - ExVision.Utils.resize(img, {640, 480}) |> Nx.divide(255.0) - end - - @impl true - def postprocessing( - stylized_frame, - metadata - ) do - categories = categories() - - {h, w} = metadata.original_size - scale_x = w / 640 - scale_y = h / 480 - - stylized_frame - end + @moduledoc """ + A semantic segmentation model for MobileNetV3 Backbone. Exported from torchvision. + """ + use ExVision.Model.Definition.Ortex, + model: "deeplab_v3_mobilenetv3_segmentation.onnx", + categories: "priv/categories/coco_with_voc_labels_categories.json" + + @type output_t() :: %{category_t() => Nx.Tensor.t()} + + @impl true + def preprocessing(img, _metdata) do + ExVision.Utils.resize(img, {224, 224}) + end + @impl true + def postprocessing(%{"output" => out}, metadata) do + cls_per_pixel = + out + |> Nx.backend_transfer() + |> NxImage.resize(metadata.original_size, channels: :first) + |> Nx.squeeze() + |> Axon.Activations.softmax(axis: [0]) + |> Nx.argmax(axis: 0) + + categories() + |> Enum.with_index() + |> Map.new(fn {category, i} -> + {category, cls_per_pixel |> Nx.equal(i)} + end) end +end diff --git a/lib/ex_vision/style_transfer/style_transfer.ex b/lib/ex_vision/style_transfer/style_transfer.ex index a37e2f8..ecf9623 100644 --- a/lib/ex_vision/style_transfer/style_transfer.ex +++ b/lib/ex_vision/style_transfer/style_transfer.ex @@ -17,6 +17,9 @@ end for {module, opts} <- Configuration.configuration() do defmodule module do + @moduledoc """ + #{module} is a custom style transfer model optimised for devices with low computational capabilities and CPU inference. + """ require Logger @type output_t() :: [Nx.Tensor.t()] diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index f51d626..f3cd541 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -3,7 +3,6 @@ defmodule ExVision.Utils do require Nx require Image - require Logger alias ExVision.Types @type channel_spec_t() :: :first | :last @@ -150,13 +149,10 @@ defmodule ExVision.Utils do @spec batched_run(atom(), ExVision.Model.input_t()) :: ExVision.Model.output_t() def batched_run(process_name, input) when is_list(input) do - Logger.info("batched_run(process_name, input) when is_list(input) do") - # Nx.Serving.batched_run(process_name, {input, Nx.tensor([1,1,1,1])}) Nx.Serving.batched_run(process_name, input) end def batched_run(process_name, input) do - Logger.info("batched_run(process_name, input) do") process_name |> batched_run([input]) |> hd() end diff --git a/lib/publish_docs_command.ex b/lib/publish_docs_command.ex deleted file mode 100644 index d5ee650..0000000 --- a/lib/publish_docs_command.ex +++ /dev/null @@ -1,14 +0,0 @@ -defmodule Mix.Tasks.PublishDocs do - @moduledoc "The hello mix task: `mix help hello`" - use Mix.Task - require Logger - - @shortdoc "Simply calls the Hello.say/0 function." - def run(_) do - "mix docs" |> String.to_charlist() |> :os.cmd - "find ~+ doc/dist -name sidebar* -print0 | xargs -0 sed -i -E 's/nested_title\":\"\.([a-zA-Z_0-9]*)\"/nested_title\":\"\\\1\"/g'" |> String.to_charlist() |> :os.cmd - Logger.info("replaced prefixes") - "mix hex.publish docs" |> String.to_charlist() |> :os.cmd - Logger.info("published") - end -end diff --git a/mix.exs b/mix.exs index 6932920..ed3fa08 100644 --- a/mix.exs +++ b/mix.exs @@ -100,6 +100,13 @@ defmodule ExVision.Mixfile do ExVision.Classification.SqueezeNet1_1, ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3, ExVision.StyleTransfer.Candy, + ExVision.StyleTransfer.CandyFast, + ExVision.StyleTransfer.Udnie, + ExVision.StyleTransfer.UdnieFast, + ExVision.StyleTransfer.Mosaic, + ExVision.StyleTransfer.MosaicFast, + ExVision.StyleTransfer.Princess, + ExVision.StyleTransfer.PrincessFast, ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2, ExVision.ObjectDetection.Ssdlite320_MobileNetv3, ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN,