diff --git a/lib/ex_vision/classification/generic_classifier.ex b/lib/ex_vision/classification/generic_classifier.ex index 64b8b0a..5e7b9d1 100644 --- a/lib/ex_vision/classification/generic_classifier.ex +++ b/lib/ex_vision/classification/generic_classifier.ex @@ -1,6 +1,16 @@ defmodule ExVision.Classification.GenericClassifier do + @moduledoc false + + # Contains a default implementation of post processing for TorchVision classifiers + # To use: `use ExVision.Classification.GenericClassifier` + alias ExVision.Utils + alias ExVision.Types.ImageMetadata + + @typep output_t() :: %{atom() => number()} + + @spec postprocessing(map(), ImageMetadata.t(), [atom()]) :: output_t() def postprocessing(%{"output" => scores}, _metadata, categories) do scores |> Nx.backend_transfer()