diff --git a/.credo.exs b/.credo.exs index 5edf2a5..a33db58 100644 --- a/.credo.exs +++ b/.credo.exs @@ -161,9 +161,9 @@ priority: :normal, order: ~w/shortdoc moduledoc behaviour use import require alias/a}, {Credo.Check.Consistency.MultiAliasImportRequireUse, false}, {Credo.Check.Consistency.UnusedVariableNames, force: :meaningful}, - {Credo.Check.Design.DuplicatedCode, false}, + {Credo.Check.Design.DuplicatedCode, []}, {Credo.Check.Readability.AliasAs, false}, - {Credo.Check.Readability.MultiAlias, false}, + {Credo.Check.Readability.MultiAlias, []}, {Credo.Check.Readability.Specs, []}, {Credo.Check.Readability.SinglePipe, false}, {Credo.Check.Readability.WithCustomTaggedTuple, false}, @@ -172,10 +172,10 @@ {Credo.Check.Refactor.DoubleBooleanNegation, false}, {Credo.Check.Refactor.ModuleDependencies, false}, {Credo.Check.Refactor.NegatedIsNil, false}, - {Credo.Check.Refactor.PipeChainStart, false}, + {Credo.Check.Refactor.PipeChainStart, []}, {Credo.Check.Refactor.VariableRebinding, false}, - {Credo.Check.Warning.LeakyEnvironment, false}, - {Credo.Check.Warning.MapGetUnsafePass, false}, + {Credo.Check.Warning.LeakyEnvironment, []}, + {Credo.Check.Warning.MapGetUnsafePass, []}, {Credo.Check.Warning.UnsafeToAtom, false} # diff --git a/README.md b/README.md index b8582f6..5e5679d 100644 --- a/README.md +++ b/README.md @@ -25,11 +25,38 @@ ExVision will take care of all necessary input transformations and covert output MobileNetV3.run(model, "example/files/cat.jpg") #=> %{cat: 0.98, dog: 0.01, car: 0.00, ...} ``` -ExVision is also capable of accepting tensors on input: +ExVision is also capable of accepting tensors and images on input: ```elixir -cat = "example/files/cat.jpg" |> StbImage.read_file!() |> StbImage.to_nx() +cat = Image.open!("example/files/cat.jpg") +{:ok, cat_tensor} = Image.to_nx(cat) MobileNetV3.run(model, cat) #=> %{cat: 0.98, dog: 0.01, car: 0.00, ...} +MobileNetV3.run(model, cat_tensor) #=> %{cat: 0.98, dog: 0.01, car: 0.00, ...} +``` + +### Usage in process workflow + +All ExVision models are implemented using `Nx.Serving`. +They are therefore compatible with process workflow. + +You can start a model's process: + +```elixir +{:ok, pid} = MobileNetV3.start_link(name: MyModel) +``` + +or start it under the supervision tree + +```elixir +{:ok, _supervisor_pid} = Supervisor.start_link([ + {MobileNetV3, name: MyModel} +], strategy: :one_for_one) +``` + +After starting, it's immediatelly available for inference using `batched_run/2` function. + +```elixir +MobileNetV3.batched_run(MyModel, cat) #=> %{cat: 0.98, dog: 0.01, car: 0.00, ...} ``` ## Installation @@ -55,12 +82,13 @@ If the model that you would like to use is missing, feel free to open the issue, - [x] MobileNetV3 Small - [ ] EfficientNetV2 - [ ] SqueezeNet -- [ ] Object detection +- [x] Object detection - [x] SSDLite320 - MobileNetV3 Large backbone + - [x] FasterRCNN ResNet50 FPN - [x] Semantic segmentation - [x] DeepLabV3 - MobileNetV3 - [ ] Instance segmentation - - [x] Mask R-CNN + - [ ] Mask R-CNN - [ ] Keypoint Detection - [ ] Keypoint R-CNN diff --git a/examples/3-membrane.livemd b/examples/3-membrane.livemd index 91995c2..dc92698 100644 --- a/examples/3-membrane.livemd +++ b/examples/3-membrane.livemd @@ -19,8 +19,7 @@ Mix.install( ], config: [ nx: [default_backend: EXLA.Backend] - ], - force: true + ] ) ``` @@ -55,7 +54,7 @@ But before we dive into the code, here are a few tips that will make it both eas defmodule Membrane.ExVision.Detector do use Membrane.Filter - alias ExVision.Detection.Ssdlite320_MobileNetv3, as: Model + alias ExVision.Detection.FasterRCNN_ResNet50_FPN, as: Model alias ExVision.Types.BBox # Define both input and output pads diff --git a/lib/ex_vision/cache.ex b/lib/ex_vision/cache.ex index 005a3b8..44b9e94 100644 --- a/lib/ex_vision/cache.ex +++ b/lib/ex_vision/cache.ex @@ -50,14 +50,21 @@ defmodule ExVision.Cache do {:ok, Path.t()} | {:error, reason :: any()} defp download_file(url, cache_path) do with :ok <- cache_path |> Path.dirname() |> File.mkdir_p(), - target_file = File.stream!(cache_path), - :ok <- do_download_file(url, target_file) do - if File.exists?(cache_path), - do: {:ok, cache_path}, - else: {:error, :download_failed} + tmp_file_path = cache_path <> ".unconfirmed", + tmp_file = File.stream!(tmp_file_path), + :ok <- do_download_file(url, tmp_file), + :ok <- validate_download(tmp_file_path), + :ok <- File.rename(tmp_file_path, cache_path) do + {:ok, cache_path} end end + defp validate_download(path) do + if File.exists?(path), + do: :ok, + else: {:error, :download_failed} + end + @spec do_download_file(URI.t(), File.Stream.t()) :: :ok | {:error, reason :: any()} defp do_download_file(%URI{} = url, %File.Stream{path: target_file_path} = target_file) do Logger.debug("Downloading file from `#{url}` and saving to `#{target_file_path}`") diff --git a/lib/ex_vision/classification/mobilenet_v3_small.ex b/lib/ex_vision/classification/mobilenet_v3_small.ex index a715290..7cd8cba 100644 --- a/lib/ex_vision/classification/mobilenet_v3_small.ex +++ b/lib/ex_vision/classification/mobilenet_v3_small.ex @@ -26,7 +26,7 @@ defmodule ExVision.Classification.MobileNetV3Small do end @impl true - def postprocessing({scores}, _metadata) do + def postprocessing(%{"output" => scores}, _metadata) do scores |> Nx.backend_transfer() |> Nx.flatten() diff --git a/lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex b/lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex new file mode 100644 index 0000000..1819884 --- /dev/null +++ b/lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex @@ -0,0 +1,22 @@ +defmodule ExVision.Detection.FasterRCNN_ResNet50_FPN do + @moduledoc """ + FasterRCNN object detector with ResNet50 backbone and FPN detection head, exported from torchvision. + """ + use ExVision.Model.Definition.Ortex, base_dir: "detection/fasterrcnn_resnet50_fpn" + use ExVision.Detection.GenericDetector + + require Logger + + @impl true + def load(options \\ []) do + if Keyword.has_key?(options, :batch_size) do + Logger.warning( + "`:max_batch_size` was given, but this model can only process batch of size 1. Overriding" + ) + end + + options + |> Keyword.put(:batch_size, 1) + |> default_model_load() + end +end diff --git a/lib/ex_vision/detection/generic_detector.ex b/lib/ex_vision/detection/generic_detector.ex new file mode 100644 index 0000000..8676ecb --- /dev/null +++ b/lib/ex_vision/detection/generic_detector.ex @@ -0,0 +1,77 @@ +defmodule ExVision.Detection.GenericDetector do + @moduledoc false + + # Contains a default implementation of pre and post processing for TorchVision detectors + # To use: `use ExVision.Detection.GenericDetector` + + require Logger + + alias ExVision.Types.{BBox, ImageMetadata} + + @typep output_t() :: [BBox.t()] + + @spec preprocessing(Nx.Tensor.t(), ImageMetadata.t()) :: Nx.Tensor.t() + def preprocessing(img, _metadata) do + ExVision.Utils.resize(img, {224, 224}) + end + + @spec postprocessing({Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()}, ImageMetadata.t(), [atom()]) :: + output_t() + def postprocessing( + %{"boxes" => bboxes, "scores" => scores, "labels" => labels}, + metadata, + categories + ) do + {h, w} = metadata.original_size + scale_x = w / 224 + scale_y = h / 224 + + bboxes = + bboxes + |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y])) + |> Nx.round() + |> Nx.as_type(:s64) + |> Nx.to_list() + + scores = scores |> Nx.to_list() + labels = labels |> Nx.to_list() + + [bboxes, scores, labels] + |> Enum.zip() + |> Enum.filter(fn {_bbox, score, _label} -> score > 0.1 end) + |> Enum.map(fn {[x1, y1, x2, y2], score, label} -> + %BBox{ + x1: x1, + x2: x2, + y1: y1, + y2: y2, + score: score, + label: Enum.at(categories, label) + } + end) + end + + defmacro __using__(_opts) do + quote do + @typedoc """ + A type describing output of `run/2` as a list of a bounding boxes. + + Each bounding box describes the location of the object indicated by the `label`. + It also provides the `score` field marking the probability of the prediction. + Bounding boxes with very low scores should most likely be ignored. + """ + @type output_t() :: [BBox.t()] + + @impl true + defdelegate preprocessing(image, metadata), to: ExVision.Detection.GenericDetector + + @impl true + @spec postprocessing(tuple(), ExVision.Types.ImageMetadata.t()) :: output_t() + def postprocessing(output, metadata) do + ExVision.Detection.GenericDetector.postprocessing(output, metadata, categories()) + end + + defoverridable preprocessing: 2, postprocessing: 2 + end + end +end diff --git a/lib/ex_vision/detection/ssdlite320_mobilenetv3.ex b/lib/ex_vision/detection/ssdlite320_mobilenetv3.ex index 6d6c1d9..0bfdc43 100644 --- a/lib/ex_vision/detection/ssdlite320_mobilenetv3.ex +++ b/lib/ex_vision/detection/ssdlite320_mobilenetv3.ex @@ -3,20 +3,10 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3 do SSDLite320 object detector with MobileNetV3 Large architecture, exported from torchvision. """ use ExVision.Model.Definition.Ortex, base_dir: "detection/ssdlite320_mobilenetv3" + use ExVision.Detection.GenericDetector require Logger - alias ExVision.Types.BBox - - @typedoc """ - A type describing output of `run/2` as a list of a bounding boxes. - - Each bounding box describes the location of the object indicated by the `label`. - It also provides the `score` field marking the probability of the prediction. - Bounding boxes with very low scores should most likely be ignored. - """ - @type output_t() :: [BBox.t(category_t())] - @impl true def load(options \\ []) do if Keyword.has_key?(options, :batch_size) do @@ -29,40 +19,4 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3 do |> Keyword.put(:batch_size, 1) |> default_model_load() end - - @impl true - def preprocessing(img, _metadata) do - ExVision.Utils.resize(img, {224, 224}) - end - - @impl true - def postprocessing({bboxes, scores, labels}, metadata) do - {h, w} = metadata.original_size - scale_x = w / 224 - scale_y = h / 224 - - bboxes = - bboxes - |> Nx.multiply(Nx.tensor([scale_x, scale_y, scale_x, scale_y])) - |> Nx.round() - |> Nx.as_type(:s64) - |> Nx.to_list() - - scores = scores |> Nx.to_list() - labels = labels |> Nx.to_list() - - [bboxes, scores, labels] - |> Enum.zip() - |> Enum.filter(fn {_bbox, score, _label} -> score > 0.1 end) - |> Enum.map(fn {[x1, y1, x2, y2], score, label} -> - %BBox{ - x1: x1, - x2: x2, - y1: y1, - y2: y2, - score: score, - label: Enum.at(categories(), label) - } - end) - end end diff --git a/lib/ex_vision/model/definition/ortex.ex b/lib/ex_vision/model/definition/ortex.ex index 8258845..7d36cc5 100644 --- a/lib/ex_vision/model/definition/ortex.ex +++ b/lib/ex_vision/model/definition/ortex.ex @@ -67,11 +67,11 @@ defmodule ExVision.Model.Definition.Ortex do end end - defmacrop get_client_postprocessing(module) do + defmacrop get_client_postprocessing(module, output_names) do quote do fn {result, _server_metadata}, metadata -> result - |> split_onnx_result() + |> split_onnx_result(unquote(output_names)) |> Enum.zip(metadata) |> Enum.map(fn {result, metadata} -> unquote(module).postprocessing(result, metadata) end) end @@ -93,11 +93,13 @@ defmodule ExVision.Model.Definition.Ortex do cache_options = Keyword.take(options, [:cache_path, :file_path]), {:ok, path} <- ExVision.Cache.lazy_get(model_path, cache_options), {:ok, model} <- do_load_model(path, options[:providers]) do + output_names = ExVision.Utils.onnx_output_names(model) + model |> then(&Nx.Serving.new(Ortex.Serving, &1)) |> Nx.Serving.batch_size(options[:batch_size]) |> Nx.Serving.client_preprocessing(get_client_preprocessing(module)) - |> Nx.Serving.client_postprocessing(get_client_postprocessing(module)) + |> Nx.Serving.client_postprocessing(get_client_postprocessing(module, output_names)) |> then(&{:ok, struct!(module, serving: &1)}) end end @@ -113,13 +115,17 @@ defmodule ExVision.Model.Definition.Ortex do end end - defp split_onnx_result(tuple) do + defp split_onnx_result(tuple, outputs) do tuple |> Tuple.to_list() |> Enum.map(fn x -> + # Do a backend transfer and also return a list of batches here x |> Nx.backend_transfer() |> Nx.to_batched(1) end) |> Enum.zip() + |> Enum.map(fn parts -> + parts |> Tuple.to_list() |> then(&Enum.zip(outputs, &1)) |> Enum.into(%{}) + end) end @type using_option_t() :: {:base_dir, Path.t()} | {:name, String.t()} diff --git a/lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex b/lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex index 339fe0d..4f250dd 100644 --- a/lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex +++ b/lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex @@ -12,7 +12,7 @@ defmodule ExVision.Segmentation.DeepLabV3_MobileNetV3 do end @impl true - def postprocessing({out, _aux}, metadata) do + def postprocessing(%{"output" => out}, metadata) do cls_per_pixel = out |> Nx.backend_transfer() diff --git a/lib/ex_vision/types/bbox.ex b/lib/ex_vision/types/bbox.ex index 6d09e13..498a835 100644 --- a/lib/ex_vision/types/bbox.ex +++ b/lib/ex_vision/types/bbox.ex @@ -28,7 +28,10 @@ defmodule ExVision.Types.BBox do score: number() } - @typep t() :: t(term()) + @typedoc """ + Exactly like `t:t/1`, but doesn't put any constraints on the `label` field: + """ + @type t() :: t(term()) @doc """ Return the width of the bounding box diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index b1e0df4..15e3482 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -131,16 +131,24 @@ defmodule ExVision.Utils do def onnx_result_backend_transfer(tuple), do: tuple |> Tuple.to_list() |> Enum.map(&Nx.backend_transfer/1) |> List.to_tuple() - @spec onnx_input_shape(struct()) :: tuple() + @spec onnx_input_shape(Ortex.Model.t()) :: tuple() def onnx_input_shape(%Ortex.Model{reference: r}) do ["input", "Float32", shape] = - Ortex.Native.show_session(r) + r + |> Ortex.Native.show_session() |> Enum.find(fn [name, _type, _shape] -> name == "input" end) |> hd() List.to_tuple(shape) end + @spec onnx_output_names(Ortex.Model.t()) :: [String.t()] + def onnx_output_names(%Ortex.Model{reference: r}) do + {_inputs, outputs} = Ortex.Native.show_session(r) + + Enum.map(outputs, fn {name, _type, _shape} -> name end) + end + defn softmax(x) do Nx.divide(Nx.exp(x), Nx.sum(Nx.exp(x))) end diff --git a/mix.exs b/mix.exs index 972e15b..b83db47 100644 --- a/mix.exs +++ b/mix.exs @@ -95,7 +95,8 @@ defmodule ExVision.Mixfile do Models: [ ExVision.Classification.MobileNetV3Small, ExVision.Segmentation.DeepLabV3_MobileNetV3, - ExVision.Detection.Ssdlite320_MobileNetv3 + ExVision.Detection.Ssdlite320_MobileNetv3, + ExVision.Detection.FasterRCNN_ResNet50_FPN ], Types: [ ExVision.Types, diff --git a/models/detection/fasterrcnn_resnet50_fpn/categories.json b/models/detection/fasterrcnn_resnet50_fpn/categories.json new file mode 100644 index 0000000..de76af3 --- /dev/null +++ b/models/detection/fasterrcnn_resnet50_fpn/categories.json @@ -0,0 +1 @@ +["__background__", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "N/A", "backpack", "umbrella", "N/A", "N/A", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "N/A", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "N/A", "dining table", "N/A", "N/A", "toilet", "N/A", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "N/A", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"] \ No newline at end of file diff --git a/models/detection/fasterrcnn_resnet50_fpn/model.onnx b/models/detection/fasterrcnn_resnet50_fpn/model.onnx new file mode 100644 index 0000000..d817afc --- /dev/null +++ b/models/detection/fasterrcnn_resnet50_fpn/model.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:358a3a72bb702a9540cd8482063b899937c5d11da34f45f7272e67b41fc558b3 +size 167514670 diff --git a/models/detection/ssdlite320_mobilenetv3/model.onnx b/models/detection/ssdlite320_mobilenetv3/model.onnx index 887cc8e..c3b265c 100644 --- a/models/detection/ssdlite320_mobilenetv3/model.onnx +++ b/models/detection/ssdlite320_mobilenetv3/model.onnx @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:07b36beb258ada9c3d6bf97080fda6abadad5fa2cdc04c1681c577acf94d2358 -size 14116829 +oid sha256:c31e2e989eb12fb631db58e3e6ed77e5a98dc18c201b67923773ff0699c7dea2 +size 14116834 diff --git a/models/index.json b/models/index.json deleted file mode 100644 index 13ba29c..0000000 --- a/models/index.json +++ /dev/null @@ -1,8 +0,0 @@ -[ - { - "slug": "segmentation_deeplab_v3", - "version": 0, - "name": "DeepLab V3", - "url": "/segmentation/deeplab_v3" - } -] diff --git a/python/exports/ssdlite320_mobilenetv3.py b/python/exports/detection/fasterrcnn_resnet50_fpn.py similarity index 61% rename from python/exports/ssdlite320_mobilenetv3.py rename to python/exports/detection/fasterrcnn_resnet50_fpn.py index aa8e6d6..65fa5e9 100644 --- a/python/exports/ssdlite320_mobilenetv3.py +++ b/python/exports/detection/fasterrcnn_resnet50_fpn.py @@ -1,19 +1,19 @@ from torchvision.models.detection import ( - fasterrcnn_mobilenet_v3_large_fpn, - FasterRCNN_MobileNet_V3_Large_FPN_Weights, + fasterrcnn_resnet50_fpn, + FasterRCNN_ResNet50_FPN_Weights, ) import torch import json from pathlib import Path -base_dir = Path("models/detection/ssdlite320_mobilenetv3") +base_dir = Path("models/detection/fasterrcnn_resnet50_fpn") base_dir.mkdir(parents=True, exist_ok=True) model_file = base_dir / "model.onnx" categories_file = base_dir / "categories.json" -weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT -model = fasterrcnn_mobilenet_v3_large_fpn(weights=weights) +weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT +model = fasterrcnn_resnet50_fpn(weights=weights) model.eval() categories = weights.meta["categories"] @@ -23,6 +23,7 @@ json.dump(categories, f) onnx_input = torch.rand(1, 3, 224, 224) +onnx_input = transforms(onnx_input) torch.onnx.export( model, @@ -30,20 +31,13 @@ str(model_file), verbose=False, input_names=["input"], - output_names=["output", "scores", "labels"], - dynamic_axes={ - "input": {0: "batch_size"}, - "output": {0: "batch_size"}, - "scores": {0: "batch_size"}, - "labels": {0: "batch_size"}, - }, + output_names=["boxes", "labels", "scores"], + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, export_params=True, ) import onnxruntime as onnxrt -onnx_input = torch.rand(1, 3, 224, 224) - sesh = onnxrt.InferenceSession(str(model_file)) inputs = {sesh.get_inputs()[0].name: onnx_input.numpy()} outputs = [x.name for x in sesh.get_outputs()] @@ -51,3 +45,6 @@ output = sesh.run(outputs, inputs) print(len(output)) print(len(output[0])) +print(output) + +print(model(onnx_input)) diff --git a/python/exports/detection/ssdlite320_mobilenetv3.py b/python/exports/detection/ssdlite320_mobilenetv3.py new file mode 100644 index 0000000..4ec86ce --- /dev/null +++ b/python/exports/detection/ssdlite320_mobilenetv3.py @@ -0,0 +1,37 @@ +from torchvision.models.detection import ( + ssdlite320_mobilenet_v3_large, + SSDLite320_MobileNet_V3_Large_Weights, +) +import torch +import json +from pathlib import Path + +base_dir = Path("models/detection/ssdlite320_mobilenetv3") +base_dir.mkdir(parents=True, exist_ok=True) + +model_file = base_dir / "model.onnx" +categories_file = base_dir / "categories.json" + +weights = SSDLite320_MobileNet_V3_Large_Weights.DEFAULT +model = ssdlite320_mobilenet_v3_large(weights=weights) +model.eval() + +categories = weights.meta["categories"] +transforms = weights.transforms() + +with open(categories_file, "w") as f: + json.dump(categories, f) + +onnx_input = torch.rand(1, 3, 224, 224) +onnx_input = transforms(onnx_input) + +torch.onnx.export( + model, + onnx_input, + str(model_file), + verbose=False, + input_names=["input"], + output_names=["boxes", "scores", "labels"], + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + export_params=True, +) diff --git a/test/ex_vision/detection/fasterrcnn_resnet50_fpn_test.exs b/test/ex_vision/detection/fasterrcnn_resnet50_fpn_test.exs new file mode 100644 index 0000000..86e09fe --- /dev/null +++ b/test/ex_vision/detection/fasterrcnn_resnet50_fpn_test.exs @@ -0,0 +1,11 @@ +defmodule ExVision.Detection.FasterRCNN_ResNet50_FPN_Test do + use ExVision.Model.Case, module: ExVision.Detection.FasterRCNN_ResNet50_FPN + use ExVision.TestUtils + alias ExVision.Types.BBox + + @impl true + def test_inference_result(result) do + assert [%BBox{x1: 135, y1: 22, label: :cat, score: score}] = result + assert_floats_equal(score, 1.0) + end +end diff --git a/test/ex_vision/detection/ssdlite320_mobilenetv3_test.exs b/test/ex_vision/detection/ssdlite320_mobilenetv3_test.exs index 140f60e..3af51d3 100644 --- a/test/ex_vision/detection/ssdlite320_mobilenetv3_test.exs +++ b/test/ex_vision/detection/ssdlite320_mobilenetv3_test.exs @@ -1,11 +1,12 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3Test do use ExVision.Model.Case, module: ExVision.Detection.Ssdlite320_MobileNetv3 + use ExVision.TestUtils alias ExVision.Types.BBox @impl true def test_inference_result(result) do assert [%BBox{x1: 132, y1: 12, label: :cat, score: score}] = result - assert score > 0.95 + assert_floats_equal(score, 1.0) end end diff --git a/test/ex_vision/utils_test.exs b/test/ex_vision/utils_test.exs index 01eb544..b1c5c65 100644 --- a/test/ex_vision/utils_test.exs +++ b/test/ex_vision/utils_test.exs @@ -219,12 +219,12 @@ defmodule ExVision.UtilsTest do describe "convert_channel_spec/2" do test "converts :last to :first" do input = Nx.iota({1, 2, 3}) - assert Utils.convert_channel_spec(input, :first) |> Nx.shape() == {3, 1, 2} + assert input |> Utils.convert_channel_spec(:first) |> Nx.shape() == {3, 1, 2} end test "converts :first to :last" do input = Nx.iota({3, 1, 2}) - assert Utils.convert_channel_spec(input, :last) |> Nx.shape() == {1, 2, 3} + assert input |> Utils.convert_channel_spec(:last) |> Nx.shape() == {1, 2, 3} end end end diff --git a/test/support/exvision/model/case.ex b/test/support/exvision/model/case.ex index cc1fd22..8d0f837 100644 --- a/test/support/exvision/model/case.ex +++ b/test/support/exvision/model/case.ex @@ -43,7 +43,8 @@ defmodule ExVision.Model.Case do model = ctx[:model] {:ok, _supervisor} = - Supervisor.start_link([unquote(opts[:module]).child_spec(name: name, cache_path: "models")], + Supervisor.start_link( + [unquote(opts[:module]).child_spec(name: name, cache_path: "models")], strategy: :one_for_one )