diff --git a/.github/workflows/elixir.yml b/.github/workflows/elixir.yml index 05ba6bf..d1d07b8 100644 --- a/.github/workflows/elixir.yml +++ b/.github/workflows/elixir.yml @@ -9,6 +9,9 @@ on: permissions: contents: read +env: + MIX_ENV: test + jobs: build: name: Build and test @@ -29,8 +32,6 @@ jobs: path: deps key: ${{ runner.os }}-mix-${{ hashFiles('**/mix.lock') }} restore-keys: ${{ runner.os }}-mix- - - name: Checkout LFS - uses: nschloe/action-cached-lfs-checkout@v1.1.2 - name: Install dependencies run: mix deps.get && mix deps.compile - name: Checks if compiles without warning diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml deleted file mode 100644 index 3d4a34e..0000000 --- a/.github/workflows/release.yml +++ /dev/null @@ -1,21 +0,0 @@ -name: Release package - -on: - push: - tags: - - "v*" - -jobs: - release: - permissions: "write-all" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: nschloe/action-cached-lfs-checkout@v1.1.2 - - uses: "marvinpinto/action-automatic-releases@latest" - with: - repo_token: "${{ secrets.GITHUB_TOKEN }}" - automatic_release_tag: "latest" - prerelease: true - title: "Release ${{ github.ref }}" - files: models/* diff --git a/.gitignore b/.gitignore index 324f6fa..70bb34c 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,4 @@ $RECYCLE.BIN/ *.lnk # End of https://www.gitignore.io/api/c,vim,linux,macos,elixir,windows,visualstudiocode +models/ diff --git a/models/coco_categories.json b/assets/categories/coco_categories.json similarity index 100% rename from models/coco_categories.json rename to assets/categories/coco_categories.json diff --git a/models/coco_with_voc_labels_categories.json b/assets/categories/coco_with_voc_labels_categories.json similarity index 100% rename from models/coco_with_voc_labels_categories.json rename to assets/categories/coco_with_voc_labels_categories.json diff --git a/models/imagenet_v2_categories.json b/assets/categories/imagenet_v2_categories.json similarity index 100% rename from models/imagenet_v2_categories.json rename to assets/categories/imagenet_v2_categories.json diff --git a/config/config.exs b/config/config.exs index b5143cd..a92ca44 100644 --- a/config/config.exs +++ b/config/config.exs @@ -4,7 +4,6 @@ config :nx, default_backend: EXLA.Backend config :logger, level: :debug config :ex_vision, - server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(), - cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache") + server_url: URI.new!("https://ai.swmansion.com/exvision/files") import_config "#{config_env()}.exs" diff --git a/config/dev.exs b/config/dev.exs index b59c744..88b40c5 100644 --- a/config/dev.exs +++ b/config/dev.exs @@ -2,6 +2,4 @@ import Config config :ortex, Ortex.Native, features: ["coreml"] -config :ex_vision, - server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(), - cache_path: System.get_env("EX_VISION_CACHE_DIR", "models") +config :ex_vision, cache_path: "models" diff --git a/config/prod.exs b/config/prod.exs index 239090c..477c907 100644 --- a/config/prod.exs +++ b/config/prod.exs @@ -1,7 +1,3 @@ import Config config :logger, level: :info - -config :ex_vision, - server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(), - cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache") diff --git a/config/runtime.exs b/config/runtime.exs index ec7bb76..b91ec17 100644 --- a/config/runtime.exs +++ b/config/runtime.exs @@ -1,5 +1,8 @@ import Config config :ex_vision, - server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(), + server_url: + "EX_VISION_HOSTING_URI" + |> System.get_env("https://ai.swmansion.com/exvision/files") + |> URI.new!(), cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache") diff --git a/examples/3-membrane.livemd b/examples/3-membrane.livemd index dc92698..4430804 100644 --- a/examples/3-membrane.livemd +++ b/examples/3-membrane.livemd @@ -93,7 +93,6 @@ defmodule Membrane.ExVision.Detector do |> then(&"#{&1}") |> :base64.encode() |> String.to_atom() - |> dbg() {:ok, pid} = Model.start_link(name: name) diff --git a/lib/ex_vision/cache.ex b/lib/ex_vision/cache.ex index 44b9e94..eff2f41 100644 --- a/lib/ex_vision/cache.ex +++ b/lib/ex_vision/cache.ex @@ -1,51 +1,130 @@ defmodule ExVision.Cache do @moduledoc false - # Module responsible for handling model file caching + use GenServer require Logger - @default_cache_path Application.compile_env(:ex_vision, :cache_path, "/tmp/ex_vision/cache") - defp get_cache_path() do - Application.get_env(:ex_vision, :cache_path, @default_cache_path) - end - - @default_server_url Application.compile_env(:ex_vision, :server_url, "http://localhost:8000") - defp get_server_url() do - Application.get_env(:ex_vision, :server_url, @default_server_url) - end - - @type lazy_get_option_t() :: - {:cache_path, Path.t()} | {:server_url, String.t() | URI.t()} | {:force, true} + @type lazy_get_option_t() :: {:force, boolean()} @doc """ Lazily evaluate the path from the cache directory. It will only download the file if it's missing or the `force: true` option is given. """ - @spec lazy_get(Path.t(), options :: [lazy_get_option_t()]) :: + @spec lazy_get(term() | pid(), Path.t(), options :: [lazy_get_option_t()]) :: {:ok, Path.t()} | {:error, reason :: atom()} - def lazy_get(path, options \\ []) do - options = - Keyword.validate!(options, - cache_path: get_cache_path(), - server_url: get_server_url(), - force: false - ) - - cache_path = Path.join(options[:cache_path], path) - ok? = File.exists?(cache_path) - - if ok? and not options[:force] do - Logger.debug("Found existing cache entry for #{path}. Loading.") - {:ok, cache_path} - else - with {:ok, server_url} <- URI.new(options[:server_url]) do - download_url = URI.append_path(server_url, ensure_backslash(path)) - download_file(download_url, cache_path) - end + def lazy_get(server, path, options \\ []) do + with {:ok, options} <- Keyword.validate(options, force: false), + do: GenServer.call(server, {:download, path, options}, :infinity) + end + + @spec start_link(keyword()) :: GenServer.on_start() + def start_link(opts) do + {init_args, opts} = Keyword.split(opts, [:server_url, :cache_path]) + GenServer.start_link(__MODULE__, init_args, opts) + end + + @impl true + def init(opts) do + opts = Keyword.validate!(opts, cache_path: get_cache_path(), server_url: get_server_url()) + + with {:ok, server_url} <- URI.new(opts[:server_url]), + :ok <- File.mkdir_p(opts[:cache_path]) do + {:ok, + %{ + downloads: %{}, + server_url: server_url, + cache_path: opts[:cache_path], + refs: %{} + }} end end + @impl true + def handle_call({:download, cache_path, options}, from, state) do + file_path = Path.join(state.cache_path, cache_path) + + updated_downloads = + Map.update(state.downloads, cache_path, MapSet.new([from]), &MapSet.put(&1, from)) + + cond do + Map.has_key?(state.downloads, cache_path) -> + {:noreply, %{state | downloads: updated_downloads}} + + File.exists?(file_path) or options[:force] -> + {:reply, {:ok, file_path}, state} + + true -> + ref = do_create_download_job(cache_path, state) + + {:noreply, + %{state | downloads: updated_downloads, refs: Map.put(state.refs, ref, cache_path)}} + end + end + + @impl true + def handle_info({ref, result}, state) do + state = emit(result, ref, state) + {:noreply, state} + end + + @impl true + def handle_info({:DOWN, ref, :process, _pid, reason}, state) do + state = + if reason != :normal do + Logger.error("Task #{inspect(ref)} has crashed due to #{inspect(reason)}") + emit({:error, reason}, ref, state) + else + state + end + + {:noreply, state} + end + + @impl true + def handle_info(msg, state) do + Logger.warning("Received an unknown message #{inspect(msg)}. Ignoring") + {:noreply, state} + end + + defp emit(message, ref, state) do + path = state.refs[ref] + + state.downloads + |> Map.get(path, []) + |> Enum.each(fn from -> + GenServer.reply(from, message) + end) + + %{state | refs: Map.delete(state.refs, ref), downloads: Map.delete(state.downloads, path)} + end + + defp do_create_download_job(path, %{server_url: server_url, cache_path: cache_path}) do + target_file_path = Path.join(cache_path, path) + download_url = URI.append_path(server_url, ensure_backslash(path)) + + %Task{ref: ref} = + Task.async(fn -> + download_file(download_url, target_file_path) + end) + + ref + end + + @default_cache_path Application.compile_env(:ex_vision, :cache_path, "/tmp/ex_vision/cache") + defp get_cache_path() do + Application.get_env(:ex_vision, :cache_path, @default_cache_path) + end + + @default_server_url Application.compile_env( + :ex_vision, + :server_url, + URI.new!("https://ai.swmansion.com/exvision/files") + ) + defp get_server_url() do + Application.get_env(:ex_vision, :server_url, @default_server_url) + end + @spec download_file(URI.t(), Path.t()) :: {:ok, Path.t()} | {:error, reason :: any()} defp download_file(url, cache_path) do @@ -59,6 +138,9 @@ defmodule ExVision.Cache do end end + defp ensure_backslash("/" <> _rest = i), do: i + defp ensure_backslash(i), do: "/" <> i + defp validate_download(path) do if File.exists?(path), do: :ok, @@ -73,7 +155,8 @@ defmodule ExVision.Cache do {:ok, _resp} -> :ok - {:error, _reason} = error -> + {:error, reason} = error -> + Logger.error("Failed to download the file due to #{inspect(reason)}") File.rm(target_file_path) error end @@ -100,7 +183,4 @@ defmodule ExVision.Cache do {:error, :connection_failed} end end - - defp ensure_backslash("/" <> _rest = path), do: path - defp ensure_backslash(path), do: "/" <> path end diff --git a/lib/ex_vision/classification/mobilenet_v3_small.ex b/lib/ex_vision/classification/mobilenet_v3_small.ex index 3375f37..5f4358c 100644 --- a/lib/ex_vision/classification/mobilenet_v3_small.ex +++ b/lib/ex_vision/classification/mobilenet_v3_small.ex @@ -6,7 +6,7 @@ defmodule ExVision.Classification.MobileNetV3Small do """ use ExVision.Model.Definition.Ortex, model: "mobilenetv3small-classifier.onnx", - categories: "imagenet_v2_categories.json" + categories: "assets/categories/imagenet_v2_categories.json" require Bunch.Typespec alias ExVision.Utils diff --git a/lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex b/lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex index 162b8e2..0476b15 100644 --- a/lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex +++ b/lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex @@ -4,7 +4,7 @@ defmodule ExVision.Detection.FasterRCNN_ResNet50_FPN do """ use ExVision.Model.Definition.Ortex, model: "fasterrcnn_resnet50_fpn_detector.onnx", - categories: "coco_categories.json" + categories: "assets/categories/coco_categories.json" use ExVision.Detection.GenericDetector diff --git a/lib/ex_vision/detection/ssdlite320_mobilenetv3.ex b/lib/ex_vision/detection/ssdlite320_mobilenetv3.ex index 492522a..043ddb1 100644 --- a/lib/ex_vision/detection/ssdlite320_mobilenetv3.ex +++ b/lib/ex_vision/detection/ssdlite320_mobilenetv3.ex @@ -4,7 +4,7 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3 do """ use ExVision.Model.Definition.Ortex, model: "ssdlite320_mobilenetv3_detector.onnx", - categories: "coco_categories.json" + categories: "assets/categories/coco_categories.json" use ExVision.Detection.GenericDetector diff --git a/lib/ex_vision/ex_vision.ex b/lib/ex_vision/ex_vision.ex new file mode 100644 index 0000000..1340a02 --- /dev/null +++ b/lib/ex_vision/ex_vision.ex @@ -0,0 +1,10 @@ +defmodule ExVision do + @moduledoc false + use Application + + @impl true + def start(_type, _args) do + children = [{ExVision.Cache, name: ExVision.Cache}] + Supervisor.start_link(children, strategy: :one_for_one) + end +end diff --git a/lib/ex_vision/model/definition/ortex.ex b/lib/ex_vision/model/definition/ortex.ex index c42882c..56884c0 100644 --- a/lib/ex_vision/model/definition/ortex.ex +++ b/lib/ex_vision/model/definition/ortex.ex @@ -36,7 +36,7 @@ defmodule ExVision.Model.Definition.Ortex do - `:cache_path` - specifies a caching directory for this model. - `:providers` - a list of desired providers, sorted by preference. Onnx will attempt to use the first available provider. If none of the provided is available, onnx will fallback to `:cpu`. Default: `[:cpu]` - - `:batch_size` - specifies a default batch size for this instance. Default: `1` + - `:batch_size` - specifies a default batch size for this instance. Default: `1`. """ @type load_option_t() :: {:cache_path, Path.t()} @@ -85,13 +85,11 @@ defmodule ExVision.Model.Definition.Ortex do {:ok, ExVision.Model.t()} | {:error, atom()} def load_ortex_model(module, model_path, options) do with {:ok, options} <- - Keyword.validate(options, [ - :cache_path, + Keyword.validate(options, batch_size: 1, providers: [:cpu] - ]), - cache_options = Keyword.take(options, [:cache_path, :file_path]), - {:ok, path} <- ExVision.Cache.lazy_get(model_path, cache_options), + ), + {:ok, path} <- ExVision.Cache.lazy_get(ExVision.Cache, model_path), {:ok, model} <- do_load_model(path, options[:providers]) do output_names = ExVision.Utils.onnx_output_names(model) diff --git a/lib/ex_vision/model/definition/parts/with_categories.ex b/lib/ex_vision/model/definition/parts/with_categories.ex index ae91809..81d552b 100644 --- a/lib/ex_vision/model/definition/parts/with_categories.ex +++ b/lib/ex_vision/model/definition/parts/with_categories.ex @@ -1,24 +1,11 @@ defmodule ExVision.Model.Definition.Parts.WithCategories do @moduledoc false require Logger - alias ExVision.{Cache, Utils} - - defp get_categories(file) do - file - |> Cache.lazy_get() - |> case do - {:ok, file} -> - Utils.load_categories(file) - - error -> - Logger.error("Failed to load categories from #{file} due to #{inspect(error)}") - raise "Failed to load categories from #{file}" - end - end + alias ExVision.Utils defmacro __using__(options) do options = Keyword.validate!(options, [:name, :categories]) - categories = options |> Keyword.fetch!(:categories) |> get_categories() + categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories() spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative() quote do diff --git a/lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex b/lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex index 1bc0a07..c2a6735 100644 --- a/lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex +++ b/lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex @@ -4,7 +4,7 @@ defmodule ExVision.Segmentation.DeepLabV3_MobileNetV3 do """ use ExVision.Model.Definition.Ortex, model: "deeplab_v3_mobilenetv3_segmentation.onnx", - categories: "coco_with_voc_labels_categories.json" + categories: "assets/categories/coco_with_voc_labels_categories.json" @type output_t() :: %{category_t() => Nx.Tensor.t()} diff --git a/mix.exs b/mix.exs index b83db47..d7be525 100644 --- a/mix.exs +++ b/mix.exs @@ -30,6 +30,8 @@ defmodule ExVision.Mixfile do def application do [ + included_applications: [:ex_vision], + mod: {ExVision, []}, extra_applications: [] ] end diff --git a/models/deeplab_v3_mobilenetv3_segmentation.onnx b/models/deeplab_v3_mobilenetv3_segmentation.onnx deleted file mode 100644 index 3b699d3..0000000 --- a/models/deeplab_v3_mobilenetv3_segmentation.onnx +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d6043d375e83b91793acb86af7ef37783405814e4c03fc8e9d02dcb44beb75ea -size 44111008 diff --git a/models/fasterrcnn_resnet50_fpn_detector.onnx b/models/fasterrcnn_resnet50_fpn_detector.onnx deleted file mode 100644 index d817afc..0000000 --- a/models/fasterrcnn_resnet50_fpn_detector.onnx +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:358a3a72bb702a9540cd8482063b899937c5d11da34f45f7272e67b41fc558b3 -size 167514670 diff --git a/models/mobilenetv3small-classifier.onnx b/models/mobilenetv3small-classifier.onnx deleted file mode 100644 index 7cf136c..0000000 --- a/models/mobilenetv3small-classifier.onnx +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ca63401f87d2e29f8c8d9f81942b51c94424bd37465b960e6552c9e8f068abf2 -size 10181063 diff --git a/models/ssdlite320_mobilenetv3_detector.onnx b/models/ssdlite320_mobilenetv3_detector.onnx deleted file mode 100644 index c3b265c..0000000 --- a/models/ssdlite320_mobilenetv3_detector.onnx +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c31e2e989eb12fb631db58e3e6ed77e5a98dc18c201b67923773ff0699c7dea2 -size 14116834 diff --git a/test/ex_vision/cache_test.exs b/test/ex_vision/cache_test.exs index 76530b4..e4a1523 100644 --- a/test/ex_vision/cache_test.exs +++ b/test/ex_vision/cache_test.exs @@ -6,17 +6,14 @@ defmodule ExVision.CacheTest do @moduletag :tmp_dir - setup %{tmp_dir: tmp_dir} do - app_env_override(:server_url, URI.new!("http://mock_server:8000")) - app_env_override(:cache_path, tmp_dir) - end - setup ctx do files = Map.get(ctx, :files, %{ "/test" => rand_string(256) }) + set_mimic_global() + stub(Req, :get, fn %URI{host: "mock_server", port: 8000, path: path}, options -> options = Keyword.validate!(options, [:raw, :into]) @@ -39,34 +36,35 @@ defmodule ExVision.CacheTest do [files: files] end + setup %{tmp_dir: tmp_dir} do + {:ok, _cache} = + Cache.start_link( + name: MyCache, + server_url: URI.new!("http://mock_server:8000"), + cache_path: tmp_dir + ) + + :ok + end + test "Can download the file", ctx do [{path, expected_contents}] = Enum.to_list(ctx.files) expected_path = Path.join(ctx.tmp_dir, path) - assert {:ok, ^expected_path} = Cache.lazy_get(path) + assert {:ok, ^expected_path} = Cache.lazy_get(MyCache, path) verify_download(expected_path, expected_contents) end test "will fail if server is unreachable" do - app_env_override(:server_url, URI.new!("http://localhost:9999")) - assert {:error, :connection_failed} = Cache.lazy_get("/test") - assert {:error, :connection_failed} = Cache.lazy_get("/test") - end + url = "http://localhost:9999" + {:ok, c} = Cache.start_link(server_url: url, name: nil) - test "will fail if we request file that doesn't exist" do - assert {:error, :doesnt_exist} = Cache.lazy_get("/idk") - assert {:error, :doesnt_exist} = Cache.lazy_get("/idk") + assert {:error, :connection_failed} = Cache.lazy_get(c, "/test") + assert {:error, :connection_failed} = Cache.lazy_get(c, "/test") end - defp app_env_override(key, new_value) do - original = Application.fetch_env(:ex_vision, key) - Application.put_env(:ex_vision, key, new_value) - - on_exit(fn -> - case original do - {:ok, value} -> Application.put_env(:ex_vision, key, value) - :error -> Application.delete_env(:ex_vision, key) - end - end) + test "will fail if we request file that doesn't exist" do + assert {:error, :doesnt_exist} = Cache.lazy_get(MyCache, "/idk") + assert {:error, :doesnt_exist} = Cache.lazy_get(MyCache, "/idk") end defp verify_download(path, expected_contents) do diff --git a/test/support/exvision/model/case.ex b/test/support/exvision/model/case.ex index 8d0f837..dbeb13e 100644 --- a/test/support/exvision/model/case.ex +++ b/test/support/exvision/model/case.ex @@ -9,11 +9,11 @@ defmodule ExVision.Model.Case do quote do use ExUnit.Case, async: true - use ExVision.TestUtils.MockCacheServer + # use ExVision.TestUtils.MockCacheServer @behaviour ExVision.Model.Case setup_all do - {:ok, model} = unquote(opts[:module]).load(cache_path: "models") + {:ok, model} = unquote(opts[:module]).load() [model: model] end @@ -34,7 +34,7 @@ defmodule ExVision.Model.Case do end test "child_spec/1" do - assert spec = unquote(opts[:module]).child_spec(cache_path: "models") + assert spec = unquote(opts[:module]).child_spec() end describe "stateful/process workflow" do @@ -44,7 +44,7 @@ defmodule ExVision.Model.Case do {:ok, _supervisor} = Supervisor.start_link( - [unquote(opts[:module]).child_spec(name: name, cache_path: "models")], + [unquote(opts[:module]).child_spec(name: name)], strategy: :one_for_one ) @@ -69,8 +69,7 @@ defmodule ExVision.Model.Case do name: __MODULE__.TestProcess1, batch_size: 8, batch_timeout: 10, - partitions: true, - cache_path: "models" + partitions: true ] child_spec = {unquote(opts[:module]), options} diff --git a/test/support/exvision/test_utils/mock_cache_server.ex b/test/support/exvision/test_utils/mock_cache_server.ex deleted file mode 100644 index 14ee1f1..0000000 --- a/test/support/exvision/test_utils/mock_cache_server.ex +++ /dev/null @@ -1,24 +0,0 @@ -defmodule ExVision.TestUtils.MockCacheServer do - @moduledoc false - - # It will add a setup step that will mock all calls to Req, eliminating the need to host the files during testing - - defmacro __using__(_opts) do - quote do - use Mimic - - setup_all do - stub(Req, :get, fn - %URI{path: path}, _options -> - file = Path.join("models", path) - - if File.exists?(file), - do: {:ok, %Req.Response{status: 200, body: File.read!(path)}}, - else: {:ok, %Req.Response{status: 404}} - end) - - :ok - end - end - end -end