diff --git a/lib/ex_vision/utils.ex b/lib/ex_vision/utils.ex index 1003257..e8813c3 100644 --- a/lib/ex_vision/utils.ex +++ b/lib/ex_vision/utils.ex @@ -86,8 +86,7 @@ defmodule ExVision.Utils do defp ensure_grad_3(tensor) do tensor - - Nx.rank() + |> Nx.rank() |> case do 3 -> [tensor] 4 -> tensor |> Nx.to_batched(1) |> Stream.map(&Nx.squeeze(&1, axes: [0])) |> Enum.to_list()