Skip to content

Commit

Permalink
Reimplement the cache management to use the process workflow
Browse files Browse the repository at this point in the history
This way, we avoid race conditions related to trying to download the same file twice, when we request it before the previous download has finished.
  • Loading branch information
daniel-jodlos committed May 27, 2024
1 parent e528ea4 commit e42d52b
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 106 deletions.
182 changes: 109 additions & 73 deletions lib/ex_vision/cache.ex
Original file line number Diff line number Diff line change
@@ -1,71 +1,129 @@
defmodule ExVision.Cache do
@moduledoc false
alias ExVision.Cache.PubSub

# Module responsible for handling model file caching

use GenServer
require Logger

alias ExVision.Cache.PubSub
@type lazy_get_option_t() :: {:force, boolean()}

@default_cache_path Application.compile_env(:ex_vision, :cache_path, "/tmp/ex_vision/cache")
defp get_cache_path() do
Application.get_env(:ex_vision, :cache_path, @default_cache_path)
@doc """
Lazily evaluate the path from the cache directory.
It will only download the file if it's missing or the `force: true` option is given.

The request is handled by the cache process `server`; the call waits without a
timeout (`:infinity`) until that process replies.

Returns `{:error, unknown_keys}` without contacting `server` when `options`
contains keys other than `:force`.
"""
@spec lazy_get(GenServer.server(), Path.t(), options :: [lazy_get_option_t()]) ::
        {:ok, Path.t()} | {:error, reason :: atom() | [atom()]}
def lazy_get(server, path, options \\ []) do
  # A failed validation falls through `with` and is returned to the caller as-is.
  with {:ok, options} <- Keyword.validate(options, force: false) do
    GenServer.call(server, {:download, path, options}, :infinity)
  end
end

@default_server_url Application.compile_env(
:ex_vision,
:server_url,
URI.new!("https://ai.swmansion.com/exvision/files")
)
defp get_server_url() do
Application.get_env(:ex_vision, :server_url, @default_server_url)
@spec start_link(keyword()) :: GenServer.on_start()
def start_link(opts) do
  # `:server_url` and `:cache_path` are passed to `init/1`; every other option
  # (for example `:name`) goes to `GenServer.start_link/3` as a process option.
  init_keys = [:server_url, :cache_path]
  init_args = Keyword.take(opts, init_keys)
  genserver_opts = Keyword.drop(opts, init_keys)

  GenServer.start_link(__MODULE__, init_args, genserver_opts)
end

@type lazy_get_option_t() ::
{:cache_path, Path.t()} | {:server_url, String.t() | URI.t()} | {:force, boolean()}
@impl true
def init(opts) do
  # Options the caller did not supply fall back to the application configuration.
  opts = Keyword.validate!(opts, cache_path: get_cache_path(), server_url: get_server_url())

  with {:ok, server_url} <- URI.new(opts[:server_url]),
       :ok <- File.mkdir_p(opts[:cache_path]) do
    {:ok,
     %{
       # requested path => set of callers waiting for that download
       downloads: %{},
       server_url: server_url,
       cache_path: opts[:cache_path],
       # download task reference => requested path
       refs: %{}
     }}
  else
    # `{:error, reason}` is not a valid `init/1` return value, so an invalid URL or
    # an uncreatable cache directory must stop the server explicitly.
    {:error, reason} -> {:stop, reason}
  end
end

@impl true
def handle_call({:download, cache_path, options}, from, state) do
  file_path = Path.join(state.cache_path, cache_path)

  # Callers waiting for this path, with the current caller added.
  waiting =
    Map.update(state.downloads, cache_path, MapSet.new([from]), &MapSet.put(&1, from))

  cond do
    # A download of this path is already running: queue the caller, who is answered
    # when that download finishes.
    Map.has_key?(state.downloads, cache_path) ->
      {:noreply, %{state | downloads: waiting}}

    # Serve the cached file only when it exists and no re-download was forced.
    File.exists?(file_path) and !options[:force] ->
      {:reply, {:ok, file_path}, state}

    # File is missing or `force: true` was given: start a download and reply later.
    true ->
      ref = do_create_download_job(cache_path, state)

      {:noreply, %{state | downloads: waiting, refs: Map.put(state.refs, ref, cache_path)}}
  end
end

@doc """
Lazily evaluate the path from the cache directory.
It will only download the file if it's missing or the `force: true` option is given.
"""
@spec lazy_get(Path.t(), options :: [lazy_get_option_t()]) ::
{:ok, Path.t()} | {:error, reason :: atom()}
def lazy_get(path, options \\ []) do
options =
Keyword.validate!(options,
cache_path: get_cache_path(),
server_url: get_server_url(),
force: false
)

cache_path = Path.join(options[:cache_path], path)
ok? = File.exists?(cache_path)

if ok? and not options[:force] do
Logger.debug("Found existing cache entry for #{path}. Loading.")
{:ok, cache_path}
else
with {:ok, server_url} <- URI.new(options[:server_url]) do
download_url = URI.append_path(server_url, ensure_backslash(path))
download_file(download_url, cache_path)
@impl true
# Result message of a download task. The guard keeps unrelated two-element tuples
# from being mistaken for task results.
def handle_info({ref, result}, state) when is_reference(ref) do
  # The task has delivered its result, so its monitor is no longer needed; `:flush`
  # also discards a `:DOWN` message that may already be in the mailbox.
  Process.demonitor(ref, [:flush])

  Logger.info("Task #{inspect(ref)} finished with #{inspect(result)}")
  state = emit(result, ref, state)
  {:noreply, state}
end

@impl true
# Monitor notification for a download task. A `:normal` exit needs no action here;
# any other reason is forwarded to the waiting callers as `{:error, reason}`.
def handle_info({:DOWN, ref, :process, _pid, reason}, state) do
  state =
    case reason do
      :normal ->
        state

      _abnormal ->
        Logger.error("Task #{inspect(ref)} has crashed due to #{inspect(reason)}")
        emit({:error, reason}, ref, state)
    end

  {:noreply, state}
end

@impl true
# Catch-all clause: logs any other message and leaves the state untouched.
def handle_info(msg, state) do
  description = inspect(msg)
  Logger.warning("Received an unknown message #{description}. Ignoring")

  {:noreply, state}
end

# Sends `message` to every caller waiting for the download identified by `ref`,
# then removes both the task reference and the path's waiting list from the state.
# An unknown `ref` resolves to no waiters and leaves the state unchanged.
defp emit(message, ref, state) do
  {path, refs} = Map.pop(state.refs, ref)
  {waiters, downloads} = Map.pop(state.downloads, path, [])

  Enum.each(waiters, fn from ->
    GenServer.reply(from, message)
  end)

  %{state | refs: refs, downloads: downloads}
end

# Starts a task that downloads `path` from the configured server into the cache
# directory and returns the task's reference.
defp do_create_download_job(path, %{server_url: server_url, cache_path: cache_path}) do
  destination = Path.join(cache_path, path)
  source_url = URI.append_path(server_url, ensure_backslash(path))

  task = Task.async(fn -> download_file(source_url, destination) end)
  task.ref
end

# Compile-time value of `:cache_path`, used as the fallback for the runtime lookup.
@default_cache_path Application.compile_env(:ex_vision, :cache_path, "/tmp/ex_vision/cache")

# Reads `:cache_path` from the `:ex_vision` application environment at runtime.
defp get_cache_path(), do: Application.get_env(:ex_vision, :cache_path, @default_cache_path)

# Compile-time value of `:server_url`, used as the fallback for the runtime lookup.
@default_server_url Application.compile_env(
                      :ex_vision,
                      :server_url,
                      URI.new!("https://ai.swmansion.com/exvision/files")
                    )

# Reads `:server_url` from the `:ex_vision` application environment at runtime.
defp get_server_url(), do: Application.get_env(:ex_vision, :server_url, @default_server_url)

@spec download_file(URI.t(), Path.t()) ::
Expand All @@ -81,6 +139,9 @@ defmodule ExVision.Cache do
end
end

# Returns the argument with a leading "/" — unchanged when it already has one.
defp ensure_backslash(i) do
  case i do
    "/" <> _rest -> i
    _other -> "/" <> i
  end
end

defp validate_download(path) do
if File.exists?(path),
do: :ok,
Expand Down Expand Up @@ -123,29 +184,4 @@ defmodule ExVision.Cache do
{:error, :connection_failed}
end
end

# Returns `path` with a leading "/" — unchanged when it already starts with one.
defp ensure_backslash("/" <> _rest = path), do: path
defp ensure_backslash(path), do: "/" <> path
end

defmodule ExVision.Cache.PubSub do
  @moduledoc false

  # Minimal publish/subscribe helper built on `Registry`: a subscriber registers
  # itself under a key and blocks until a notification for that registry arrives.

  # Registers the calling process under `key`, waits for a notification message
  # and returns the result it carries. There is no timeout. The registration is
  # removed before returning.
  @spec subscribe(Registry.registry(), key :: term()) :: term()
  def subscribe(registry, key) do
    # Fail loudly when the registration is rejected instead of waiting forever for
    # a notification that could never be dispatched to this process.
    {:ok, _owner} = Registry.register(registry, key, [])

    receive do
      {^registry, :notification, result} ->
        Registry.unregister(registry, key)
        result
    end
  end

  # Sends `result` to every process currently registered under `key`.
  @spec notify(Registry.registry(), key :: term(), result :: term()) :: :ok
  def notify(registry, key, result) do
    Registry.dispatch(registry, key, fn entries ->
      for {pid, _value} <- entries, do: send(pid, {registry, :notification, result})
    end)
  end
end
10 changes: 10 additions & 0 deletions lib/ex_vision/ex_vision.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
defmodule ExVision do
  @moduledoc false
  use Application

  alias ExVision.Cache

  # Application callback: starts a supervisor whose only child is the cache
  # process, registered under the name `ExVision.Cache`.
  @impl true
  def start(_type, _args) do
    Supervisor.start_link([{Cache, name: Cache}], strategy: :one_for_one)
  end
end
8 changes: 3 additions & 5 deletions lib/ex_vision/model/definition/ortex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,11 @@ defmodule ExVision.Model.Definition.Ortex do
{:ok, ExVision.Model.t()} | {:error, atom()}
def load_ortex_model(module, model_path, options) do
with {:ok, options} <-
Keyword.validate(options, [
:cache_path,
Keyword.validate(options,
batch_size: 1,
providers: [:cpu]
]),
cache_options = Keyword.take(options, [:cache_path, :file_path]),
{:ok, path} <- ExVision.Cache.lazy_get(model_path, cache_options),
),
{:ok, path} <- ExVision.Cache.lazy_get(ExVision.Cache, model_path),
{:ok, model} <- do_load_model(path, options[:providers]) do
output_names = ExVision.Utils.onnx_output_names(model)

Expand Down
2 changes: 2 additions & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ defmodule ExVision.Mixfile do

def application do
  [
    # `ExVision` is the application callback module. The application must not list
    # itself under `:included_applications`, which is reserved for other
    # applications embedded into this one.
    mod: {ExVision, []},
    extra_applications: []
  ]
end
Expand Down
44 changes: 21 additions & 23 deletions test/ex_vision/cache_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,14 @@ defmodule ExVision.CacheTest do

@moduletag :tmp_dir

setup %{tmp_dir: tmp_dir} do
app_env_override(:server_url, URI.new!("http://mock_server:8000"))
app_env_override(:cache_path, tmp_dir)
end

setup ctx do
files =
Map.get(ctx, :files, %{
"/test" => rand_string(256)
})

set_mimic_global()

stub(Req, :get, fn
%URI{host: "mock_server", port: 8000, path: path}, options ->
options = Keyword.validate!(options, [:raw, :into])
Expand All @@ -39,34 +36,35 @@ defmodule ExVision.CacheTest do
[files: files]
end

setup %{tmp_dir: tmp_dir} do
  # Start the cache under the test supervisor so it is shut down before the next
  # test begins; otherwise the `MyCache` name may still be taken when the next
  # `setup` runs.
  start_supervised!(
    {Cache,
     name: MyCache, server_url: URI.new!("http://mock_server:8000"), cache_path: tmp_dir}
  )

  :ok
end

test "Can download the file", ctx do
  [{path, expected_contents}] = Enum.to_list(ctx.files)
  expected_path = Path.join(ctx.tmp_dir, path)

  # The cache server is addressed explicitly by the name given in `setup`.
  assert {:ok, ^expected_path} = Cache.lazy_get(MyCache, path)
  verify_download(expected_path, expected_contents)
end

test "will fail if server is unreachable", %{tmp_dir: tmp_dir} do
  # A dedicated cache pointed at an unreachable server. It uses the per-test
  # directory so files left in the default cache location cannot satisfy the
  # request, and a distinct child id so it can run next to the cache from `setup`.
  cache =
    start_supervised!(
      {Cache, server_url: "http://localhost:9999", cache_path: tmp_dir},
      id: :unreachable_cache
    )

  # Requested twice: the second attempt is expected to fail in the same way.
  assert {:error, :connection_failed} = Cache.lazy_get(cache, "/test")
  assert {:error, :connection_failed} = Cache.lazy_get(cache, "/test")
end

test "will fail if we request file that doesn't exist" do
  # Requested twice: the second attempt is expected to fail in the same way.
  assert {:error, :doesnt_exist} = Cache.lazy_get(MyCache, "/idk")
  assert {:error, :doesnt_exist} = Cache.lazy_get(MyCache, "/idk")
end

defp verify_download(path, expected_contents) do
Expand Down
9 changes: 4 additions & 5 deletions test/support/exvision/model/case.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ defmodule ExVision.Model.Case do
@behaviour ExVision.Model.Case

setup_all do
{:ok, model} = unquote(opts[:module]).load(cache_path: "models")
{:ok, model} = unquote(opts[:module]).load()
[model: model]
end

Expand All @@ -34,7 +34,7 @@ defmodule ExVision.Model.Case do
end

test "child_spec/1" do
assert spec = unquote(opts[:module]).child_spec(cache_path: "models")
assert spec = unquote(opts[:module]).child_spec()
end

describe "stateful/process workflow" do
Expand All @@ -44,7 +44,7 @@ defmodule ExVision.Model.Case do

{:ok, _supervisor} =
Supervisor.start_link(
[unquote(opts[:module]).child_spec(name: name, cache_path: "models")],
[unquote(opts[:module]).child_spec(name: name)],
strategy: :one_for_one
)

Expand All @@ -69,8 +69,7 @@ defmodule ExVision.Model.Case do
name: __MODULE__.TestProcess1,
batch_size: 8,
batch_timeout: 10,
partitions: true,
cache_path: "models"
partitions: true
]

child_spec = {unquote(opts[:module]), options}
Expand Down

0 comments on commit e42d52b

Please sign in to comment.