diff --git a/test/ex_vision/style_transfer/style_transfer_test.exs b/test/ex_vision/style_transfer/style_transfer_test.exs index f0e4a33..33fce37 100644 --- a/test/ex_vision/style_transfer/style_transfer_test.exs +++ b/test/ex_vision/style_transfer/style_transfer_test.exs @@ -43,6 +43,8 @@ for {module, opts} <- TestConfiguration.configuration() do use ExVision.Model.Case, module: unquote(opts[:module]) use ExVision.TestUtils + require Logger + @impl true def test_inference_result(result) do expected_result = @@ -50,6 +52,14 @@ for {module, opts} <- TestConfiguration.configuration() do |> File.read!() |> Nx.deserialize() + diff_sum = + expected_result + |> Nx.subtract(result) + |> Nx.reduce(0, fn x, y -> Nx.add(x, y) end) + |> Nx.to_number() + + assert diff_sum == 1 + assert_tensors_equal(result, expected_result) end end diff --git a/test/support/exvision/test_utils.ex b/test/support/exvision/test_utils.ex index 206c296..c7b6d27 100644 --- a/test/support/exvision/test_utils.ex +++ b/test/support/exvision/test_utils.ex @@ -1,4 +1,5 @@ defmodule ExVision.TestUtils do + require Nx @moduledoc false import ExUnit.Assertions, only: :macros