From f815b624d28f605ee52e237fd4c9c2407c2245cd Mon Sep 17 00:00:00 2001
From: Prithvi Kannan
Date: Wed, 18 Dec 2024 14:00:19 -0800
Subject: [PATCH] fix

Signed-off-by: Prithvi Kannan
---
 .../src/databricks_langchain/genie.py         |   5 +-
 .../langchain/tests/unit_tests/test_genie.py  |   5 +-
 src/databricks_ai_bridge/genie.py             | 105 ++++++++----------
 tests/databricks_ai_bridge/test_genie.py      |  28 ++---
 4 files changed, 69 insertions(+), 74 deletions(-)

diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py
index 1b7c1f1..67b4cc6 100644
--- a/integrations/langchain/src/databricks_langchain/genie.py
+++ b/integrations/langchain/src/databricks_langchain/genie.py
@@ -14,6 +14,7 @@ def _concat_messages_array(messages):
     )
     return concatenated_message
 
+
 @mlflow.trace()
 def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
     from langchain_core.messages import AIMessage
@@ -28,8 +29,8 @@ def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
     # Send the message and wait for a response
     genie_response = genie.ask_question(message)
 
-    if genie_response:
-        return {"messages": [AIMessage(content=genie_response)]}
+    if query_result := genie_response.result:
+        return {"messages": [AIMessage(content=query_result)]}
     else:
         return {"messages": [AIMessage(content="")]}
 
diff --git a/integrations/langchain/tests/unit_tests/test_genie.py b/integrations/langchain/tests/unit_tests/test_genie.py
index 70c6c28..024ca3d 100644
--- a/integrations/langchain/tests/unit_tests/test_genie.py
+++ b/integrations/langchain/tests/unit_tests/test_genie.py
@@ -1,5 +1,6 @@
 from unittest.mock import patch
 
+from databricks_ai_bridge.genie import GenieResponse
 from langchain_core.messages import AIMessage
 
 from databricks_langchain.genie import (
@@ -44,7 +45,7 @@ def __init__(self, role, content):
 def test_query_genie_as_agent(MockGenie):
     # Mock the Genie class and its response
     mock_genie = MockGenie.return_value
-    mock_genie.ask_question.return_value = "It is sunny."
+    mock_genie.ask_question.return_value = GenieResponse(result="It is sunny.")
 
     input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]}
     result = _query_genie_as_agent(input_data, "space-id", "Genie")
@@ -53,7 +54,7 @@ def test_query_genie_as_agent(MockGenie):
     assert result == expected_message
 
     # Test the case when genie_response is empty
-    mock_genie.ask_question.return_value = None
+    mock_genie.ask_question.return_value = GenieResponse(result=None)
     result = _query_genie_as_agent(input_data, "space-id", "Genie")
 
     expected_message = {"messages": [AIMessage(content="")]}
diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py
index 8715b9e..7bbbdc5 100644
--- a/src/databricks_ai_bridge/genie.py
+++ b/src/databricks_ai_bridge/genie.py
@@ -1,15 +1,15 @@
 import logging
 import time
 from datetime import datetime
-from typing import Union
+from typing import Optional, Union
 
 import mlflow
 import pandas as pd
 import tiktoken
 from databricks.sdk import WorkspaceClient
 
-MAX_TOKENS_OF_DATA = 20000  # max tokens of data in markdown format
-MAX_ITERATIONS = 50  # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS
+MAX_TOKENS_OF_DATA = 20000
+MAX_ITERATIONS = 50
 
 
 # Define a function to count tokens
@@ -17,7 +17,20 @@ def _count_tokens(text):
     encoding = tiktoken.encoding_for_model("gpt-4o")
     return len(encoding.encode(text))
 
-@mlflow.trace()
+
+class GenieResponse:
+    def __init__(
+        self,
+        result: Union[str, pd.DataFrame],
+        query: Optional[str] = "",
+        description: Optional[str] = "",
+    ):
+        self.result = result
+        self.query = query
+        self.description = description
+
+
+@mlflow.trace()
 def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
     columns = resp["manifest"]["schema"]["columns"]
     header = [str(col["name"]) for col in columns]
@@ -41,9 +54,7 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
                 row.append(float(str_value))
             elif type_name == "BOOLEAN":
                 row.append(str_value.lower() == "true")
-            elif type_name == "DATE":
-                row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
-            elif type_name == "TIMESTAMP":
+            elif type_name == "DATE" or type_name == "TIMESTAMP":
                 row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
             elif type_name == "BINARY":
                 row.append(bytes(str_value, "utf-8"))
@@ -54,7 +65,6 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
 
     query_result = pd.DataFrame(rows, columns=header).to_markdown()
 
-    # trim down from the total rows until we get under the token limit
     tokens_used = _count_tokens(query_result)
     while tokens_used > MAX_TOKENS_OF_DATA:
         rows.pop()
@@ -97,73 +107,56 @@ def create_message(self, conversation_id, content):
     @mlflow.trace()
     def poll_for_result(self, conversation_id, message_id):
         @mlflow.trace()
-        def poll_result():
+        def poll_query_results(query, description):
             iteration_count = 0
             while iteration_count < MAX_ITERATIONS:
                 iteration_count += 1
                 resp = self.genie._api.do(
                     "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
+                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result",
                     headers=self.headers,
-                )
-                if resp["status"] == "EXECUTING_QUERY":
-                    query = next(r for r in resp["attachments"] if "query" in r)["query"]
-                    description = query.get("description", "")
-                    sql = query.get("query", "")
-                    logging.debug(f"Description: {description}")
-                    logging.debug(f"SQL: {sql}")
-                    return poll_query_results()
-                elif resp["status"] == "COMPLETED":
-                    # Check if there is a query object in the attachments for the COMPLETED status
-                    query_attachment = next((r for r in resp["attachments"] if "query" in r), None)
-                    if query_attachment:
-                        query = query_attachment["query"]
-                        description = query.get("description", "")
-                        sql = query.get("query", "")
-                        logging.debug(f"Description: {description}")
-                        logging.debug(f"SQL: {sql}")
-                        return poll_query_results()
-                    else:
-                        # Handle the text object in the COMPLETED status
-                        return next(r for r in resp["attachments"] if "text" in r)["text"][
-                            "content"
-                        ]
-                elif resp["status"] == "FAILED":
-                    logging.debug("Genie failed to execute the query")
-                    return None
-                elif resp["status"] == "CANCELLED":
-                    logging.debug("Genie query cancelled")
-                    return None
-                elif resp["status"] == "QUERY_RESULT_EXPIRED":
-                    logging.debug("Genie query result expired")
-                    return None
-                else:
-                    logging.debug(f"Waiting...: {resp['status']}")
+                )["statement_response"]
+                state = resp["status"]["state"]
+                if state == "SUCCEEDED":
+                    result = _parse_query_result(resp)
+                    return GenieResponse(result, query, description)
+                elif state in ["RUNNING", "PENDING"]:
+                    logging.debug("Waiting for query result...")
                     time.sleep(5)
+                else:
+                    logging.debug(f"No query result: {state}")
+                    return GenieResponse(None, query, description)
 
         @mlflow.trace()
-        def poll_query_results():
+        def poll_result():
             iteration_count = 0
             while iteration_count < MAX_ITERATIONS:
                 iteration_count += 1
                 resp = self.genie._api.do(
                     "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result",
+                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
                     headers=self.headers,
-                )["statement_response"]
-                state = resp["status"]["state"]
-                if state == "SUCCEEDED":
-                    return _parse_query_result(resp)
-                elif state == "RUNNING" or state == "PENDING":
-                    logging.debug("Waiting for query result...")
-                    time.sleep(5)
+                )
+                if resp["status"] == "EXECUTING_QUERY" or resp["status"] == "COMPLETED":
+                    query_attachment = next((r for r in resp["attachments"] if "query" in r), None)
+                    if query_attachment:
+                        query = query_attachment["query"]["query"]
+                        description = query_attachment["query"].get("description", "")
+                        return poll_query_results(query, description)
+                if resp["status"] == "COMPLETED":
+                    text_content = next(r for r in resp["attachments"] if "text" in r)["text"][
+                        "content"
+                    ]
+                    return GenieResponse(result=text_content)
+                elif resp["status"] in ["FAILED", "CANCELLED", "QUERY_RESULT_EXPIRED"]:
+                    logging.debug(f"Genie query {resp['status'].lower()}.")
+                    return GenieResponse(result=None)
                 else:
-                    logging.debug(f"No query result: {resp['state']}")
-                    return None
+                    logging.debug(f"Waiting...: {resp['status']}")
+                    time.sleep(5)
 
         return poll_result()
 
     def ask_question(self, question):
         resp = self.start_conversation(question)
-        # TODO (prithvi): return the query and the result
         return self.poll_for_result(resp["conversation_id"], resp["message_id"])
diff --git a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py
index 0156a50..1e313db 100644
--- a/tests/databricks_ai_bridge/test_genie.py
+++ b/tests/databricks_ai_bridge/test_genie.py
@@ -47,8 +47,8 @@ def test_poll_for_result_completed_with_text(genie, mock_workspace_client):
     mock_workspace_client.genie._api.do.side_effect = [
         {"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]},
     ]
-    result = genie.poll_for_result("123", "456")
-    assert result == "Result"
"Result" + genie_result = genie.poll_for_result("123", "456") + assert genie_result.result == "Result" def test_poll_for_result_completed_with_query(genie, mock_workspace_client): @@ -64,8 +64,8 @@ def test_poll_for_result_completed_with_query(genie, mock_workspace_client): } }, ] - result = genie.poll_for_result("123", "456") - assert result == pd.DataFrame().to_markdown() + genie_result = genie.poll_for_result("123", "456") + assert genie_result.result == pd.DataFrame().to_markdown() def test_poll_for_result_executing_query(genie, mock_workspace_client): @@ -84,32 +84,32 @@ def test_poll_for_result_executing_query(genie, mock_workspace_client): } }, ] - result = genie.poll_for_result("123", "456") - assert result == pd.DataFrame().to_markdown() + genie_result = genie.poll_for_result("123", "456") + assert genie_result.result == pd.DataFrame().to_markdown() def test_poll_for_result_failed(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ {"status": "FAILED"}, ] - result = genie.poll_for_result("123", "456") - assert result is None + genie_result = genie.poll_for_result("123", "456") + assert genie_result.result is None def test_poll_for_result_cancelled(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ {"status": "CANCELLED"}, ] - result = genie.poll_for_result("123", "456") - assert result is None + genie_result = genie.poll_for_result("123", "456") + assert genie_result.result is None def test_poll_for_result_expired(genie, mock_workspace_client): mock_workspace_client.genie._api.do.side_effect = [ {"status": "QUERY_RESULT_EXPIRED"}, ] - result = genie.poll_for_result("123", "456") - assert result is None + genie_result = genie.poll_for_result("123", "456") + assert genie_result.result is None def test_poll_for_result_max_iterations(genie, mock_workspace_client): @@ -148,8 +148,8 @@ def test_ask_question(genie, mock_workspace_client): {"conversation_id": "123", "message_id": "456"}, {"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]}, ] - result = genie.ask_question("What is the meaning of life?") - assert result == "Answer" + genie_result = genie.ask_question("What is the meaning of life?") + assert genie_result.result == "Answer" def test_parse_query_result_empty():