From 37eba4389f856a36101d6125a3f0720e3babb0b1 Mon Sep 17 00:00:00 2001 From: Shreyas Damle <57351398+shreyas-damle@users.noreply.github.com> Date: Mon, 22 Jul 2024 11:56:24 +0530 Subject: [PATCH] Sort Apps on app listing page based on findings (#417) * remove topics from prompt data * Sort retrievals data * Ruff Fixes * UT fix. --- pebblo/app/models/models.py | 14 +++++++++++--- pebblo/app/service/local_ui_service.py | 19 +++++++++++++++++++ pebblo/app/service/prompt_service.py | 22 +++++++++++----------- tests/app/test_prompt_api.py | 1 - 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/pebblo/app/models/models.py b/pebblo/app/models/models.py index c65f1ae5..d9d5af71 100644 --- a/pebblo/app/models/models.py +++ b/pebblo/app/models/models.py @@ -26,8 +26,12 @@ class AiDataModel(BaseModel): data: Optional[Union[list, str]] entityCount: int entities: dict - topicCount: int - topics: dict + topicCount: Optional[int] = None + topics: Optional[dict] = None + + def dict(self, **kwargs): + kwargs["exclude_none"] = True + return super().dict(**kwargs) class AiDocs(BaseModel): @@ -100,7 +104,11 @@ class RetrievalContext(BaseModel): class AiClassificationData(BaseModel): entities: dict - topics: dict + topics: Optional[dict] = None + + def dict(self, **kwargs): + kwargs["exclude_none"] = True + return super().dict(**kwargs) class RetrievalData(BaseModel): diff --git a/pebblo/app/service/local_ui_service.py b/pebblo/app/service/local_ui_service.py index 37b021a7..0f88b8bb 100644 --- a/pebblo/app/service/local_ui_service.py +++ b/pebblo/app/service/local_ui_service.py @@ -134,6 +134,12 @@ def prepare_retrieval_response(self, app_dir, app_json): logger.warning(f"Skipping app '{app_dir}' due to missing or invalid file") return + # Sort retrievals data + retrievals = self.sort_retrievals_data( + app_metadata_content.get("retrievals", []) + ) + app_metadata_content["retrievals"] = retrievals + # fetch total retrievals for retrieval in app_metadata_content.get("retrievals", []): retrieval_data = {"name": app_json.get("name")} @@ -336,6 +342,8 @@ def get_latest_load_id(load_ids, app_dir): def get_retrieval_app_details(self, app_content): retrieval_data = app_content.get("retrievals", []) + retrieval_data = self.sort_retrievals_data(retrieval_data) + active_users = self.get_active_users(retrieval_data) documents = self.get_all_documents(retrieval_data) vector_dbs = self.get_all_vector_dbs(retrieval_data) @@ -474,6 +482,17 @@ def sort_retrievals(retrieval_data: list, search_key: str) -> dict: return sorted_resp + @staticmethod + def _calculate_total_count(item: dict): + prompt_count = item.get("prompt", {}).get("entityCount") or 0 + response_count = item.get("prompt", {}).get("entityCount") or 0 + return prompt_count + response_count + + def sort_retrievals_data(self, retrieval): + # Sort the list based on the total count in descending order + sorted_data = sorted(retrieval, key=self._calculate_total_count, reverse=True) + return sorted_data + def get_all_documents(self, retrieval_data: list) -> dict: """ This function returns documents per app with its metadata in following format: diff --git a/pebblo/app/service/prompt_service.py b/pebblo/app/service/prompt_service.py index 45064b4c..b374866b 100644 --- a/pebblo/app/service/prompt_service.py +++ b/pebblo/app/service/prompt_service.py @@ -33,7 +33,7 @@ def __init__(self, data: dict): self.entity_classifier_obj = EntityClassifier() self.topic_classifier_obj = TopicClassifier() - def _fetch_classified_data(self, input_data): + def _fetch_classified_data(self, input_data, input_type=""): """ Retrieve input data and return its corresponding model object with classification. """ @@ -47,15 +47,15 @@ def _fetch_classified_data(self, input_data): ) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer( input_data ) - topics, topic_count = self.topic_classifier_obj.predict(input_data) - - data = { - "data": input_data, - "entityCount": entity_count, - "entities": entities, - "topicCount": topic_count, - "topics": topics, - } + + data = {"data": input_data, "entityCount": entity_count, "entities": entities} + + # Topic classification is performed only for the response. + if input_type == "response": + topics, topic_count = self.topic_classifier_obj.predict(input_data) + data["topicCount"] = topic_count + data["topics"] = topics + logger.debug(f"AI_APPS [{self.application_name}]:Classified Details: {data}") return data @@ -154,7 +154,7 @@ def process_request(self): # getting response data response_data = self._fetch_classified_data( - self.data.get("response", {}).get("data") + self.data.get("response", {}).get("data"), input_type="response" ) # getting retrieval context data diff --git a/tests/app/test_prompt_api.py b/tests/app/test_prompt_api.py index 08c36f15..0cf95e04 100644 --- a/tests/app/test_prompt_api.py +++ b/tests/app/test_prompt_api.py @@ -52,7 +52,6 @@ def test_app_prompt_success(mock_write_json_to_file): assert response.json()["message"] == "AiApp prompt request completed successfully" assert response.json()["retrieval_data"]["prompt"] == { "entities": {}, - "topics": {}, } assert response.json()["retrieval_data"]["response"] == { "entities": {"us-ssn": 1},