From fa3062b5b3bcf69d7dea22eeced866f6f186dd08 Mon Sep 17 00:00:00 2001 From: Patrick Leary Date: Tue, 31 Dec 2024 14:43:05 -0500 Subject: [PATCH] returning embedding is optional with return_embedding parameter --- lib/inat_vision_api.py | 8 ++++++-- lib/inat_vision_api_responses.py | 16 ++++++++++------ lib/templates/home.html | 7 ++++++- lib/web_forms.py | 1 + 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/lib/inat_vision_api.py b/lib/inat_vision_api.py index 8150149..6021224 100644 --- a/lib/inat_vision_api.py +++ b/lib/inat_vision_api.py @@ -153,9 +153,11 @@ def score_image(self, form, file_path, lat, lng, iconic_taxon_id, geomodel): return InatVisionAPIResponses.aggregated_tree_response( aggregated_scores, self.inferrer ) + embedding = self.inferrer.signature_for_image(file_path) if \ + form.return_embedding.data == "true" else None return InatVisionAPIResponses.aggregated_object_response( leaf_scores, aggregated_scores, self.inferrer, - embedding=self.inferrer.signature_for_image(file_path) + embedding=embedding ) # legacy dict response @@ -163,9 +165,11 @@ def score_image(self, form, file_path, lat, lng, iconic_taxon_id, geomodel): return InatVisionAPIResponses.legacy_dictionary_response(leaf_scores, self.inferrer) if form.format.data == "object": + embedding = self.inferrer.signature_for_image(file_path) if \ + form.return_embedding.data == "true" else None return InatVisionAPIResponses.object_response( leaf_scores, self.inferrer, - embedding=self.inferrer.signature_for_image(file_path) + embedding=embedding ) return InatVisionAPIResponses.array_response(leaf_scores, self.inferrer) diff --git a/lib/inat_vision_api_responses.py b/lib/inat_vision_api_responses.py index e80c9ea..ad00ffe 100644 --- a/lib/inat_vision_api_responses.py +++ b/lib/inat_vision_api_responses.py @@ -20,7 +20,7 @@ def array_response(leaf_scores, inferrer): return InatVisionAPIResponses.array_response_columns(leaf_scores).to_dict(orient="records") @staticmethod - def object_response(leaf_scores, inferrer, embedding): + def object_response(leaf_scores, inferrer, embedding=None): leaf_scores = InatVisionAPIResponses.limit_leaf_scores_for_response(leaf_scores) leaf_scores = InatVisionAPIResponses.update_leaf_scores_scaling(leaf_scores) results = InatVisionAPIResponses.array_response_columns( @@ -39,11 +39,13 @@ def object_response(leaf_scores, inferrer, embedding): common_ancestor_frame ).to_dict(orient="records")[0] - return { + response = { "common_ancestor": common_ancestor, "results": results, - "embedding": embedding } + if embedding is not None: + response["embedding"] = embedding + return response @staticmethod def aggregated_tree_response(aggregated_scores, inferrer): @@ -74,7 +76,7 @@ def aggregated_tree_response(aggregated_scores, inferrer): return "
" + "
".join(printable_tree) + "
" @staticmethod - def aggregated_object_response(leaf_scores, aggregated_scores, inferrer, embedding): + def aggregated_object_response(leaf_scores, aggregated_scores, inferrer, embedding=None): top_leaf_combined_score = aggregated_scores.query( "leaf_class_id.notnull()" ).sort_values( @@ -117,11 +119,13 @@ def aggregated_object_response(leaf_scores, aggregated_scores, inferrer, embeddi common_ancestor_frame ).to_dict(orient="records")[0] - return { + response = { "common_ancestor": common_ancestor, "results": final_results.to_dict(orient="records"), - "embedding": embedding } + if embedding is not None: + response["embedding"] = embedding + return response @staticmethod def limit_leaf_scores_for_response(leaf_scores): diff --git a/lib/templates/home.html b/lib/templates/home.html index 41854b3..1bb7463 100644 --- a/lib/templates/home.html +++ b/lib/templates/home.html @@ -23,9 +23,9 @@

Slim vs Legacy Model

Lng:


+ +

diff --git a/lib/web_forms.py b/lib/web_forms.py index 2c42768..e0e0e67 100644 --- a/lib/web_forms.py +++ b/lib/web_forms.py @@ -13,4 +13,5 @@ class ImageForm(FlaskForm): taxon_id = StringField("taxon_id") geomodel = StringField("geomodel") aggregated = StringField("aggregated") + return_embedding = StringField("return_embedding") format = StringField("format")