diff --git a/test/ex_vision/style_transfer/style_transfer_test.exs b/test/ex_vision/style_transfer/style_transfer_test.exs index f0e4a33..e91d20a 100644 --- a/test/ex_vision/style_transfer/style_transfer_test.exs +++ b/test/ex_vision/style_transfer/style_transfer_test.exs @@ -43,6 +43,8 @@ for {module, opts} <- TestConfiguration.configuration() do use ExVision.Model.Case, module: unquote(opts[:module]) use ExVision.TestUtils + require Logger + @impl true def test_inference_result(result) do expected_result = @@ -50,7 +52,7 @@ for {module, opts} <- TestConfiguration.configuration() do |> File.read!() |> Nx.deserialize() - assert_tensors_equal(result, expected_result) + assert_tensors_equal(result, expected_result, 5, 0.05) end end end diff --git a/test/support/exvision/test_utils.ex b/test/support/exvision/test_utils.ex index 206c296..2b8c81d 100644 --- a/test/support/exvision/test_utils.ex +++ b/test/support/exvision/test_utils.ex @@ -42,10 +42,10 @@ defmodule ExVision.TestUtils do end end - defmacro assert_tensors_equal(a, b, delta \\ @default_delta) do + defmacro assert_tensors_equal(a, b, delta \\ @default_delta, relative_delta \\ 0.0) do quote do assert unquote(a) - |> Nx.all_close(unquote(b), atol: unquote(delta)) + |> Nx.all_close(unquote(b), atol: unquote(delta), rtol: unquote(relative_delta)) |> Nx.reduce_min() |> Nx.to_number() == 1 end