Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Host the files #3

Merged
merged 7 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/elixir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ on:
permissions:
contents: read

env:
MIX_ENV: test

jobs:
build:
name: Build and test
Expand All @@ -29,8 +32,6 @@ jobs:
path: deps
key: ${{ runner.os }}-mix-${{ hashFiles('**/mix.lock') }}
restore-keys: ${{ runner.os }}-mix-
- name: Checkout LFS
uses: nschloe/[email protected]
- name: Install dependencies
run: mix deps.get && mix deps.compile
- name: Checks if compiles without warning
Expand Down
21 changes: 0 additions & 21 deletions .github/workflows/release.yml

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,4 @@ $RECYCLE.BIN/
*.lnk

# End of https://www.gitignore.io/api/c,vim,linux,macos,elixir,windows,visualstudiocode
models/
File renamed without changes.
3 changes: 1 addition & 2 deletions config/config.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ config :nx, default_backend: EXLA.Backend
config :logger, level: :debug

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache")
server_url: URI.new!("https://ai.swmansion.com/exvision/files")

import_config "#{config_env()}.exs"
4 changes: 1 addition & 3 deletions config/dev.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,4 @@ import Config

config :ortex, Ortex.Native, features: ["coreml"]

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "models")
config :ex_vision, cache_path: "models"
4 changes: 0 additions & 4 deletions config/prod.exs
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import Config

config :logger, level: :info

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache")
5 changes: 4 additions & 1 deletion config/runtime.exs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import Config

config :ex_vision,
server_url: "EX_VISION_HOSTING_URI" |> System.get_env("http://localhost:8000") |> URI.new!(),
server_url:
"EX_VISION_HOSTING_URI"
|> System.get_env("https://ai.swmansion.com/exvision/files")
|> URI.new!(),
cache_path: System.get_env("EX_VISION_CACHE_DIR", "/tmp/ex_vision/cache")
1 change: 0 additions & 1 deletion examples/3-membrane.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ defmodule Membrane.ExVision.Detector do
|> then(&"#{&1}")
|> :base64.encode()
|> String.to_atom()
|> dbg()

{:ok, pid} = Model.start_link(name: name)

Expand Down
154 changes: 117 additions & 37 deletions lib/ex_vision/cache.ex
Original file line number Diff line number Diff line change
@@ -1,51 +1,130 @@
defmodule ExVision.Cache do
@moduledoc false

# Module responsible for handling model file caching

use GenServer
require Logger

@default_cache_path Application.compile_env(:ex_vision, :cache_path, "/tmp/ex_vision/cache")
defp get_cache_path() do
Application.get_env(:ex_vision, :cache_path, @default_cache_path)
end

@default_server_url Application.compile_env(:ex_vision, :server_url, "http://localhost:8000")
defp get_server_url() do
Application.get_env(:ex_vision, :server_url, @default_server_url)
end

@type lazy_get_option_t() ::
{:cache_path, Path.t()} | {:server_url, String.t() | URI.t()} | {:force, true}
@type lazy_get_option_t() :: {:force, boolean()}

@doc """
Lazily evaluate the path from the cache directory.
It will only download the file if it's missing or the `force: true` option is given.
"""
@spec lazy_get(Path.t(), options :: [lazy_get_option_t()]) ::
@spec lazy_get(term() | pid(), Path.t(), options :: [lazy_get_option_t()]) ::
{:ok, Path.t()} | {:error, reason :: atom()}
def lazy_get(path, options \\ []) do
options =
Keyword.validate!(options,
cache_path: get_cache_path(),
server_url: get_server_url(),
force: false
)

cache_path = Path.join(options[:cache_path], path)
ok? = File.exists?(cache_path)

if ok? and not options[:force] do
Logger.debug("Found existing cache entry for #{path}. Loading.")
{:ok, cache_path}
else
with {:ok, server_url} <- URI.new(options[:server_url]) do
download_url = URI.append_path(server_url, ensure_backslash(path))
download_file(download_url, cache_path)
end
def lazy_get(server, path, options \\ []) do
with {:ok, options} <- Keyword.validate(options, force: false),
do: GenServer.call(server, {:download, path, options}, :infinity)
end

@spec start_link(keyword()) :: GenServer.on_start()
def start_link(opts) do
{init_args, opts} = Keyword.split(opts, [:server_url, :cache_path])
GenServer.start_link(__MODULE__, init_args, opts)
end

@impl true
def init(opts) do
opts = Keyword.validate!(opts, cache_path: get_cache_path(), server_url: get_server_url())

with {:ok, server_url} <- URI.new(opts[:server_url]),
:ok <- File.mkdir_p(opts[:cache_path]) do
{:ok,
%{
downloads: %{},
server_url: server_url,
cache_path: opts[:cache_path],
refs: %{}
}}
end
end

@impl true
def handle_call({:download, cache_path, options}, from, state) do
file_path = Path.join(state.cache_path, cache_path)

updated_downloads =
Map.update(state.downloads, cache_path, MapSet.new([from]), &MapSet.put(&1, from))

cond do
Map.has_key?(state.downloads, cache_path) ->
{:noreply, %{state | downloads: updated_downloads}}

File.exists?(file_path) or options[:force] ->
{:reply, {:ok, file_path}, state}

true ->
ref = do_create_download_job(cache_path, state)

{:noreply,
%{state | downloads: updated_downloads, refs: Map.put(state.refs, ref, cache_path)}}
end
end

@impl true
def handle_info({ref, result}, state) do
state = emit(result, ref, state)
{:noreply, state}
end

@impl true
def handle_info({:DOWN, ref, :process, _pid, reason}, state) do
state =
if reason != :normal do
Logger.error("Task #{inspect(ref)} has crashed due to #{inspect(reason)}")
emit({:error, reason}, ref, state)
else
state
end

{:noreply, state}
end

@impl true
def handle_info(msg, state) do
Logger.warning("Received an unknown message #{inspect(msg)}. Ignoring")
{:noreply, state}
end

defp emit(message, ref, state) do
path = state.refs[ref]

state.downloads
|> Map.get(path, [])
|> Enum.each(fn from ->
GenServer.reply(from, message)
end)

%{state | refs: Map.delete(state.refs, ref), downloads: Map.delete(state.downloads, path)}
end

defp do_create_download_job(path, %{server_url: server_url, cache_path: cache_path}) do
target_file_path = Path.join(cache_path, path)
download_url = URI.append_path(server_url, ensure_backslash(path))

%Task{ref: ref} =
Task.async(fn ->
download_file(download_url, target_file_path)
end)

ref
end

@default_cache_path Application.compile_env(:ex_vision, :cache_path, "/tmp/ex_vision/cache")
defp get_cache_path() do
Application.get_env(:ex_vision, :cache_path, @default_cache_path)
end

@default_server_url Application.compile_env(
:ex_vision,
:server_url,
URI.new!("https://ai.swmansion.com/exvision/files")
)
defp get_server_url() do
Application.get_env(:ex_vision, :server_url, @default_server_url)
end

@spec download_file(URI.t(), Path.t()) ::
{:ok, Path.t()} | {:error, reason :: any()}
defp download_file(url, cache_path) do
Expand All @@ -59,6 +138,9 @@ defmodule ExVision.Cache do
end
end

defp ensure_backslash("/" <> _rest = i), do: i
defp ensure_backslash(i), do: "/" <> i

defp validate_download(path) do
if File.exists?(path),
do: :ok,
Expand All @@ -73,7 +155,8 @@ defmodule ExVision.Cache do
{:ok, _resp} ->
:ok

{:error, _reason} = error ->
{:error, reason} = error ->
Logger.error("Failed to download the file due to #{inspect(reason)}")
File.rm(target_file_path)
error
end
Expand All @@ -100,7 +183,4 @@ defmodule ExVision.Cache do
{:error, :connection_failed}
end
end

defp ensure_backslash("/" <> _rest = path), do: path
defp ensure_backslash(path), do: "/" <> path
end
2 changes: 1 addition & 1 deletion lib/ex_vision/classification/mobilenet_v3_small.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ defmodule ExVision.Classification.MobileNetV3Small do
"""
use ExVision.Model.Definition.Ortex,
model: "mobilenetv3small-classifier.onnx",
categories: "imagenet_v2_categories.json"
categories: "assets/categories/imagenet_v2_categories.json"

require Bunch.Typespec
alias ExVision.Utils
Expand Down
2 changes: 1 addition & 1 deletion lib/ex_vision/detection/fasterrcnn_resnet50_fpn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule ExVision.Detection.FasterRCNN_ResNet50_FPN do
"""
use ExVision.Model.Definition.Ortex,
model: "fasterrcnn_resnet50_fpn_detector.onnx",
categories: "coco_categories.json"
categories: "assets/categories/coco_categories.json"

use ExVision.Detection.GenericDetector

Expand Down
2 changes: 1 addition & 1 deletion lib/ex_vision/detection/ssdlite320_mobilenetv3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule ExVision.Detection.Ssdlite320_MobileNetv3 do
"""
use ExVision.Model.Definition.Ortex,
model: "ssdlite320_mobilenetv3_detector.onnx",
categories: "coco_categories.json"
categories: "assets/categories/coco_categories.json"

use ExVision.Detection.GenericDetector

Expand Down
10 changes: 10 additions & 0 deletions lib/ex_vision/ex_vision.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
defmodule ExVision do
@moduledoc false
use Application

@impl true
def start(_type, _args) do
children = [{ExVision.Cache, name: ExVision.Cache}]
Supervisor.start_link(children, strategy: :one_for_one)
end
end
10 changes: 4 additions & 6 deletions lib/ex_vision/model/definition/ortex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ defmodule ExVision.Model.Definition.Ortex do

- `:cache_path` - specifies a caching directory for this model.
- `:providers` - a list of desired providers, sorted by preference. Onnx will attempt to use the first available provider. If none of the provided is available, onnx will fallback to `:cpu`. Default: `[:cpu]`
- `:batch_size` - specifies a default batch size for this instance. Default: `1`
- `:batch_size` - specifies a default batch size for this instance. Default: `1`.
"""
@type load_option_t() ::
{:cache_path, Path.t()}
Expand Down Expand Up @@ -85,13 +85,11 @@ defmodule ExVision.Model.Definition.Ortex do
{:ok, ExVision.Model.t()} | {:error, atom()}
def load_ortex_model(module, model_path, options) do
with {:ok, options} <-
Keyword.validate(options, [
:cache_path,
Keyword.validate(options,
batch_size: 1,
providers: [:cpu]
]),
cache_options = Keyword.take(options, [:cache_path, :file_path]),
{:ok, path} <- ExVision.Cache.lazy_get(model_path, cache_options),
),
{:ok, path} <- ExVision.Cache.lazy_get(ExVision.Cache, model_path),
{:ok, model} <- do_load_model(path, options[:providers]) do
output_names = ExVision.Utils.onnx_output_names(model)

Expand Down
17 changes: 2 additions & 15 deletions lib/ex_vision/model/definition/parts/with_categories.ex
Original file line number Diff line number Diff line change
@@ -1,24 +1,11 @@
defmodule ExVision.Model.Definition.Parts.WithCategories do
@moduledoc false
require Logger
alias ExVision.{Cache, Utils}

defp get_categories(file) do
file
|> Cache.lazy_get()
|> case do
{:ok, file} ->
Utils.load_categories(file)

error ->
Logger.error("Failed to load categories from #{file} due to #{inspect(error)}")
raise "Failed to load categories from #{file}"
end
end
alias ExVision.Utils

defmacro __using__(options) do
options = Keyword.validate!(options, [:name, :categories])
categories = options |> Keyword.fetch!(:categories) |> get_categories()
categories = options |> Keyword.fetch!(:categories) |> Utils.load_categories()
spec = categories |> Enum.uniq() |> Bunch.Typespec.enum_to_alternative()

quote do
Expand Down
2 changes: 1 addition & 1 deletion lib/ex_vision/segmentation/deep_lab_v3_mobilenet_v3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule ExVision.Segmentation.DeepLabV3_MobileNetV3 do
"""
use ExVision.Model.Definition.Ortex,
model: "deeplab_v3_mobilenetv3_segmentation.onnx",
categories: "coco_with_voc_labels_categories.json"
categories: "assets/categories/coco_with_voc_labels_categories.json"

@type output_t() :: %{category_t() => Nx.Tensor.t()}

Expand Down
2 changes: 2 additions & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ defmodule ExVision.Mixfile do

def application do
[
included_applications: [:ex_vision],
mod: {ExVision, []},
extra_applications: []
]
end
Expand Down
3 changes: 0 additions & 3 deletions models/deeplab_v3_mobilenetv3_segmentation.onnx

This file was deleted.

3 changes: 0 additions & 3 deletions models/fasterrcnn_resnet50_fpn_detector.onnx

This file was deleted.

Loading
Loading