Skip to content

Commit

Permalink
Sort Apps on app listing page based on findings (#417)
Browse files Browse the repository at this point in the history
* remove topics from prompt data

* Sort retrievals data

* Ruff Fixes

* UT fix.
  • Loading branch information
shreyas-damle authored Jul 22, 2024
1 parent 7919208 commit 37eba43
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 15 deletions.
14 changes: 11 additions & 3 deletions pebblo/app/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ class AiDataModel(BaseModel):
data: Optional[Union[list, str]]
entityCount: int
entities: dict
topicCount: int
topics: dict
topicCount: Optional[int] = None
topics: Optional[dict] = None

def dict(self, **kwargs):
kwargs["exclude_none"] = True
return super().dict(**kwargs)


class AiDocs(BaseModel):
Expand Down Expand Up @@ -100,7 +104,11 @@ class RetrievalContext(BaseModel):

class AiClassificationData(BaseModel):
entities: dict
topics: dict
topics: Optional[dict] = None

def dict(self, **kwargs):
kwargs["exclude_none"] = True
return super().dict(**kwargs)


class RetrievalData(BaseModel):
Expand Down
19 changes: 19 additions & 0 deletions pebblo/app/service/local_ui_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def prepare_retrieval_response(self, app_dir, app_json):
logger.warning(f"Skipping app '{app_dir}' due to missing or invalid file")
return

# Sort retrievals data
retrievals = self.sort_retrievals_data(
app_metadata_content.get("retrievals", [])
)
app_metadata_content["retrievals"] = retrievals

# fetch total retrievals
for retrieval in app_metadata_content.get("retrievals", []):
retrieval_data = {"name": app_json.get("name")}
Expand Down Expand Up @@ -336,6 +342,8 @@ def get_latest_load_id(load_ids, app_dir):
def get_retrieval_app_details(self, app_content):
retrieval_data = app_content.get("retrievals", [])

retrieval_data = self.sort_retrievals_data(retrieval_data)

active_users = self.get_active_users(retrieval_data)
documents = self.get_all_documents(retrieval_data)
vector_dbs = self.get_all_vector_dbs(retrieval_data)
Expand Down Expand Up @@ -474,6 +482,17 @@ def sort_retrievals(retrieval_data: list, search_key: str) -> dict:

return sorted_resp

@staticmethod
def _calculate_total_count(item: dict):
prompt_count = item.get("prompt", {}).get("entityCount") or 0
response_count = item.get("prompt", {}).get("entityCount") or 0
return prompt_count + response_count

def sort_retrievals_data(self, retrieval):
# Sort the list based on the total count in descending order
sorted_data = sorted(retrieval, key=self._calculate_total_count, reverse=True)
return sorted_data

def get_all_documents(self, retrieval_data: list) -> dict:
"""
This function returns documents per app with its metadata in following format:
Expand Down
22 changes: 11 additions & 11 deletions pebblo/app/service/prompt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, data: dict):
self.entity_classifier_obj = EntityClassifier()
self.topic_classifier_obj = TopicClassifier()

def _fetch_classified_data(self, input_data):
def _fetch_classified_data(self, input_data, input_type=""):
"""
Retrieve input data and return its corresponding model object with classification.
"""
Expand All @@ -47,15 +47,15 @@ def _fetch_classified_data(self, input_data):
) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
input_data
)
topics, topic_count = self.topic_classifier_obj.predict(input_data)

data = {
"data": input_data,
"entityCount": entity_count,
"entities": entities,
"topicCount": topic_count,
"topics": topics,
}

data = {"data": input_data, "entityCount": entity_count, "entities": entities}

# Topic classification is performed only for the response.
if input_type == "response":
topics, topic_count = self.topic_classifier_obj.predict(input_data)
data["topicCount"] = topic_count
data["topics"] = topics

logger.debug(f"AI_APPS [{self.application_name}]:Classified Details: {data}")
return data

Expand Down Expand Up @@ -154,7 +154,7 @@ def process_request(self):

# getting response data
response_data = self._fetch_classified_data(
self.data.get("response", {}).get("data")
self.data.get("response", {}).get("data"), input_type="response"
)

# getting retrieval context data
Expand Down
1 change: 0 additions & 1 deletion tests/app/test_prompt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def test_app_prompt_success(mock_write_json_to_file):
assert response.json()["message"] == "AiApp prompt request completed successfully"
assert response.json()["retrieval_data"]["prompt"] == {
"entities": {},
"topics": {},
}
assert response.json()["retrieval_data"]["response"] == {
"entities": {"us-ssn": 1},
Expand Down

0 comments on commit 37eba43

Please sign in to comment.