Skip to content

Commit

Permalink
Clean up code
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak committed Jul 3, 2024
1 parent c1a32ee commit 418def9
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 17 deletions.
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_l.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_L do
image
|> ExVision.Utils.resize({480, 480})
|> NxImage.normalize(
Nx.tensor([0.5, 0.5, 0.5]),
Nx.tensor([0.5, 0.5, 0.5]),
Nx.f32([0.5, 0.5, 0.5]),
Nx.f32([0.5, 0.5, 0.5]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_m.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_M do
image
|> ExVision.Utils.resize({480, 480})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/efficientnet_v2_s.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.EfficientNet_V2_S do
image
|> ExVision.Utils.resize({384, 384})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
2 changes: 1 addition & 1 deletion lib/ex_vision/classification/generic_classifier.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ defmodule ExVision.Classification.GenericClassifier do
scores
|> Nx.backend_transfer()
|> Nx.flatten()
|> Utils.softmax()
|> Axon.Activations.softmax(axis: [0])
|> Nx.to_flat_list()
|> then(&Enum.zip(categories, &1))
|> Map.new()
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/mobilenet_v3_small.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.MobileNetV3Small do
image
|> ExVision.Utils.resize({224, 224})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
4 changes: 2 additions & 2 deletions lib/ex_vision/classification/squeezenet1_1.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule ExVision.Classification.SqueezeNet1_1 do
image
|> ExVision.Utils.resize({224, 224})
|> NxImage.normalize(
Nx.tensor([0.485, 0.456, 0.406]),
Nx.tensor([0.229, 0.224, 0.225]),
Nx.f32([0.485, 0.456, 0.406]),
Nx.f32([0.229, 0.224, 0.225]),
channels: :first
)
end
Expand Down
8 changes: 2 additions & 6 deletions lib/ex_vision/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ defmodule ExVision.Utils do

defp ensure_grad_3(tensor) do
tensor
|> Nx.shape()
|> tuple_size()

Nx.rank()
|> case do
3 -> [tensor]
4 -> tensor |> Nx.to_batched(1) |> Stream.map(&Nx.squeeze(&1, axes: [0])) |> Enum.to_list()
Expand Down Expand Up @@ -149,10 +149,6 @@ defmodule ExVision.Utils do
Enum.map(outputs, fn {name, _type, _shape} -> name end)
end

defn softmax(x) do
Nx.divide(Nx.exp(x), Nx.sum(Nx.exp(x)))
end

@spec batched_run(atom(), ExVision.Model.input_t()) :: ExVision.Model.output_t()
def batched_run(process_name, input) when is_list(input) do
Nx.Serving.batched_run(process_name, input)
Expand Down

0 comments on commit 418def9

Please sign in to comment.