Skip to content

Commit

Permalink
cleaned up styletransfer code
Browse files Browse the repository at this point in the history
  • Loading branch information
Mateusz Kopcinski authored and Mateusz Kopcinski committed Aug 14, 2024
1 parent 65be519 commit 57d8d39
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 81 deletions.
1 change: 0 additions & 1 deletion lib/ex_vision/model/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ defmodule ExVision.Model.Definition do
])

quote do
#todo fix
unless is_nil(unquote(options[:categories])) do
use ExVision.Model.Definition.Parts.WithCategories, unquote(options)
end
Expand Down
15 changes: 0 additions & 15 deletions lib/ex_vision/model/definition/ortex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,7 @@ defmodule ExVision.Model.Definition.Ortex do

defmacrop get_client_preprocessing(module) do
quote do
# input_preprocessing = fn input ->
fn input ->
Logger.info("IO.inspect(input)")
# images = case input do
# {_input, sth} -> ExVision.Utils.load_image(_input)
# _input -> ExVision.Utils.load_image(_input)
# end

images = ExVision.Utils.load_image(input)

metadata =
Expand All @@ -63,22 +56,14 @@ defmodule ExVision.Model.Definition.Ortex do
}
)

Logger.info(images)
batch =
images
|> Enum.zip(metadata)
|> Enum.map(fn {image, metadata} -> unquote(module).preprocessing(image, metadata) end)
|> Nx.Batch.stack()

# batch = batch |> Nx.Batch.stack()
Logger.info(batch)
{batch, metadata}
end

# fn
# {input, extra_fields} -> {unquote(input_preprocessing)(input), extra_fields}
# {input} -> unquote(input_preprocessing)(input)
# end
end
end

Expand Down
76 changes: 29 additions & 47 deletions lib/ex_vision/semantic_segmentation/deep_lab_v3_mobilenet_v3.ex
Original file line number Diff line number Diff line change
@@ -1,50 +1,32 @@
defmodule ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3 do
  @moduledoc """
  A DeepLabV3 semantic segmentation model with a MobileNetV3 backbone. Exported from torchvision.

  The model output is post-processed into a map from each category to a mask
  tensor that marks the pixels assigned to that category.
  """
  use ExVision.Model.Definition.Ortex,
    model: "deeplab_v3_mobilenetv3_segmentation.onnx",
    categories: "priv/categories/coco_with_voc_labels_categories.json"

  @type output_t() :: %{category_t() => Nx.Tensor.t()}

  # Resizes the input image to the 224x224 resolution fed to the model.
  # The metadata is not needed at this stage.
  @impl true
  def preprocessing(img, _metadata) do
    ExVision.Utils.resize(img, {224, 224})
  end

  # Turns the raw "output" tensor into one mask per category.
  #
  # The scores are scaled back to `metadata.original_size`, reduced to a single
  # class index per pixel, and then compared against every category index.
  # NOTE(review): assumes that after `Nx.squeeze/1` axis 0 is the class axis
  # (i.e. the batch dimension has size 1) - confirm against the exported model.
  @impl true
  def postprocessing(%{"output" => out}, metadata) do
    cls_per_pixel =
      out
      |> Nx.backend_transfer()
      |> NxImage.resize(metadata.original_size, channels: :first)
      |> Nx.squeeze()
      |> Axon.Activations.softmax(axis: [0])
      |> Nx.argmax(axis: 0)

    categories()
    |> Enum.with_index()
    |> Map.new(fn {category, index} ->
      # Element-wise comparison: 1 where the pixel belongs to `category`, 0 elsewhere.
      {category, Nx.equal(cls_per_pixel, index)}
    end)
  end
end
3 changes: 3 additions & 0 deletions lib/ex_vision/style_transfer/style_transfer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ end

for {module, opts} <- Configuration.configuration() do
defmodule module do
@moduledoc """
#{module} is a custom style transfer model optimised for devices with low computational capabilities and CPU inference.
"""
require Logger
@type output_t() :: [Nx.Tensor.t()]

Expand Down
4 changes: 0 additions & 4 deletions lib/ex_vision/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ defmodule ExVision.Utils do

require Nx
require Image
require Logger
alias ExVision.Types

@type channel_spec_t() :: :first | :last
Expand Down Expand Up @@ -150,13 +149,10 @@ defmodule ExVision.Utils do

# Runs `input` through the `Nx.Serving` process registered as `process_name`.
#
# A list is forwarded unchanged. Any other input is wrapped in a one-element
# list, run through the list clause, and the single result is unwrapped, so
# callers get back a value of the same "shape" (list vs. single) they passed in.
@spec batched_run(atom(), ExVision.Model.input_t()) :: ExVision.Model.output_t()
def batched_run(process_name, input) when is_list(input) do
  Nx.Serving.batched_run(process_name, input)
end

def batched_run(process_name, input) do
  # Delegate to the list clause and take the only element of its result.
  process_name |> batched_run([input]) |> hd()
end

Expand Down
14 changes: 0 additions & 14 deletions lib/publish_docs_command.ex

This file was deleted.

7 changes: 7 additions & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ defmodule ExVision.Mixfile do
ExVision.Classification.SqueezeNet1_1,
ExVision.SemanticSegmentation.DeepLabV3_MobileNetV3,
ExVision.StyleTransfer.Candy,
ExVision.StyleTransfer.CandyFast,
ExVision.StyleTransfer.Udnie,
ExVision.StyleTransfer.UdnieFast,
ExVision.StyleTransfer.Mosaic,
ExVision.StyleTransfer.MosaicFast,
ExVision.StyleTransfer.Princess,
ExVision.StyleTransfer.PrincessFast,
ExVision.InstanceSegmentation.MaskRCNN_ResNet50_FPN_V2,
ExVision.ObjectDetection.Ssdlite320_MobileNetv3,
ExVision.ObjectDetection.FasterRCNN_ResNet50_FPN,
Expand Down

0 comments on commit 57d8d39

Please sign in to comment.