diff --git a/lib/ex_vision/style_transfer/style_transfer.ex b/lib/ex_vision/style_transfer/style_transfer.ex index 07ba722..297dc9c 100644 --- a/lib/ex_vision/style_transfer/style_transfer.ex +++ b/lib/ex_vision/style_transfer/style_transfer.ex @@ -66,9 +66,8 @@ for {module, opts} <- Configuration.configuration() do stylized_frame["55"] |> Nx.reshape({3, h, w}, names: [:channel, :height, :width]) - |> NxImage.resize(metadata.original_size, channels: :first) - |> Nx.max(0.0) - |> Nx.min(255.0) + |> NxImage.resize(metadata.original_size, channels: :first, method: :bilinear) + |> Nx.clip(0.0, 255.0) |> Nx.as_type(:u8) |> Nx.transpose(axes: [1, 2, 0]) end diff --git a/test/ex_vision/style_transfer/style_transfer_test.exs b/test/ex_vision/style_transfer/style_transfer_test.exs index 21e7c9b..e91d20a 100644 --- a/test/ex_vision/style_transfer/style_transfer_test.exs +++ b/test/ex_vision/style_transfer/style_transfer_test.exs @@ -52,24 +52,6 @@ for {module, opts} <- TestConfiguration.configuration() do |> File.read!() |> Nx.deserialize() - diff_sum = - expected_result - |> Nx.equal(result) - |> Nx.as_type(:u64) - |> Nx.reduce(0, fn x, y -> Nx.add(x, y) end) - # |> Nx.reduce_max() - |> Nx.to_number() - - diff_sum2 = - expected_result - |> Nx.subtract(result) - |> Nx.abs() - |> Nx.as_type(:u64) - |> Nx.reduce(0, fn x, y -> Nx.add(x, y) end) - # |> Nx.reduce_max() - |> Nx.to_number() - assert diff_sum == diff_sum2 - assert_tensors_equal(result, expected_result, 5, 0.05) end end diff --git a/test/support/exvision/test_utils.ex b/test/support/exvision/test_utils.ex index 2b8c81d..8e95244 100644 --- a/test/support/exvision/test_utils.ex +++ b/test/support/exvision/test_utils.ex @@ -44,10 +44,23 @@ defmodule ExVision.TestUtils do defmacro assert_tensors_equal(a, b, delta \\ @default_delta, relative_delta \\ 0.0) do quote do - assert unquote(a) - |> Nx.all_close(unquote(b), atol: unquote(delta), rtol: unquote(relative_delta)) - |> Nx.reduce_min() - |> Nx.to_number() == 1 + value_condition = + unquote(a) + |> Nx.all_close(unquote(b), atol: unquote(delta), rtol: unquote(relative_delta)) + |> Nx.reduce_min() + |> Nx.to_number() == 1 + + equal_on_count = + unquote(a) + |> Nx.equal(unquote(b)) + |> Nx.as_type(:u64) + |> Nx.reduce(0, fn x, y -> Nx.add(x, y) end) + |> Nx.to_number() + + number_count = unquote(a) |> Nx.shape() |> Tuple.product() + proportional_condition = equal_on_count / number_count > 0.99 + + assert value_condition or proportional_condition end end