Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Prithvi Kannan <[email protected]>
  • Loading branch information
prithvikannan committed Dec 18, 2024
1 parent 37ad675 commit f815b62
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 74 deletions.
5 changes: 3 additions & 2 deletions integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def _concat_messages_array(messages):
)
return concatenated_message


@mlflow.trace()
def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
from langchain_core.messages import AIMessage
Expand All @@ -28,8 +29,8 @@ def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
# Send the message and wait for a response
genie_response = genie.ask_question(message)

if genie_response:
return {"messages": [AIMessage(content=genie_response)]}
if query_result := genie_response.result:
return {"messages": [AIMessage(content=query_result)]}
else:
return {"messages": [AIMessage(content="")]}

Expand Down
5 changes: 3 additions & 2 deletions integrations/langchain/tests/unit_tests/test_genie.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import patch

from databricks_ai_bridge.genie import GenieResponse
from langchain_core.messages import AIMessage

from databricks_langchain.genie import (
Expand Down Expand Up @@ -44,7 +45,7 @@ def __init__(self, role, content):
def test_query_genie_as_agent(MockGenie):
# Mock the Genie class and its response
mock_genie = MockGenie.return_value
mock_genie.ask_question.return_value = "It is sunny."
mock_genie.ask_question.return_value = GenieResponse(result="It is sunny.")

input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]}
result = _query_genie_as_agent(input_data, "space-id", "Genie")
Expand All @@ -53,7 +54,7 @@ def test_query_genie_as_agent(MockGenie):
assert result == expected_message

# Test the case when genie_response is empty
mock_genie.ask_question.return_value = None
mock_genie.ask_question.return_value = GenieResponse(result=None)
result = _query_genie_as_agent(input_data, "space-id", "Genie")

expected_message = {"messages": [AIMessage(content="")]}
Expand Down
105 changes: 49 additions & 56 deletions src/databricks_ai_bridge/genie.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
import logging
import time
from datetime import datetime
from typing import Union
from typing import Optional, Union

import mlflow
import pandas as pd
import tiktoken
from databricks.sdk import WorkspaceClient

MAX_TOKENS_OF_DATA = 20000 # max tokens of data in markdown format
MAX_ITERATIONS = 50 # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS
MAX_TOKENS_OF_DATA = 20000
MAX_ITERATIONS = 50


# Define a function to count tokens
def _count_tokens(text):
encoding = tiktoken.encoding_for_model("gpt-4o")
return len(encoding.encode(text))

@mlflow.trace()

class GenieResponse:
def __init__(
self,
result: Union[str, pd.DataFrame],
query: Optional[str] = "",
description: Optional[str] = "",
):
self.result = result
self.query = query
self.description = description


@mlflow.trace(name="My Span")
def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
columns = resp["manifest"]["schema"]["columns"]
header = [str(col["name"]) for col in columns]
Expand All @@ -41,9 +54,7 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
row.append(float(str_value))
elif type_name == "BOOLEAN":
row.append(str_value.lower() == "true")
elif type_name == "DATE":
row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
elif type_name == "TIMESTAMP":
elif type_name == "DATE" or type_name == "TIMESTAMP":
row.append(datetime.strptime(str_value[:10], "%Y-%m-%d").date())
elif type_name == "BINARY":
row.append(bytes(str_value, "utf-8"))
Expand All @@ -54,7 +65,6 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]:

query_result = pd.DataFrame(rows, columns=header).to_markdown()

# trim down from the total rows until we get under the token limit
tokens_used = _count_tokens(query_result)
while tokens_used > MAX_TOKENS_OF_DATA:
rows.pop()
Expand Down Expand Up @@ -97,73 +107,56 @@ def create_message(self, conversation_id, content):
@mlflow.trace()
def poll_for_result(self, conversation_id, message_id):
@mlflow.trace()
def poll_result():
def poll_query_results(query, description):
iteration_count = 0
while iteration_count < MAX_ITERATIONS:
iteration_count += 1
resp = self.genie._api.do(
"GET",
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result",
headers=self.headers,
)
if resp["status"] == "EXECUTING_QUERY":
query = next(r for r in resp["attachments"] if "query" in r)["query"]
description = query.get("description", "")
sql = query.get("query", "")
logging.debug(f"Description: {description}")
logging.debug(f"SQL: {sql}")
return poll_query_results()
elif resp["status"] == "COMPLETED":
# Check if there is a query object in the attachments for the COMPLETED status
query_attachment = next((r for r in resp["attachments"] if "query" in r), None)
if query_attachment:
query = query_attachment["query"]
description = query.get("description", "")
sql = query.get("query", "")
logging.debug(f"Description: {description}")
logging.debug(f"SQL: {sql}")
return poll_query_results()
else:
# Handle the text object in the COMPLETED status
return next(r for r in resp["attachments"] if "text" in r)["text"][
"content"
]
elif resp["status"] == "FAILED":
logging.debug("Genie failed to execute the query")
return None
elif resp["status"] == "CANCELLED":
logging.debug("Genie query cancelled")
return None
elif resp["status"] == "QUERY_RESULT_EXPIRED":
logging.debug("Genie query result expired")
return None
else:
logging.debug(f"Waiting...: {resp['status']}")
)["statement_response"]
state = resp["status"]["state"]
if state == "SUCCEEDED":
result = _parse_query_result(resp)
return GenieResponse(result, query, description)
elif state in ["RUNNING", "PENDING"]:
logging.debug("Waiting for query result...")
time.sleep(5)
else:
logging.debug(f"No query result: {resp['state']}")
return GenieResponse(None, query, description)

@mlflow.trace()
def poll_query_results():
def poll_result():
iteration_count = 0
while iteration_count < MAX_ITERATIONS:
iteration_count += 1
resp = self.genie._api.do(
"GET",
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result",
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
headers=self.headers,
)["statement_response"]
state = resp["status"]["state"]
if state == "SUCCEEDED":
return _parse_query_result(resp)
elif state == "RUNNING" or state == "PENDING":
logging.debug("Waiting for query result...")
time.sleep(5)
)
if resp["status"] == "EXECUTING_QUERY" or resp["status"] == "COMPLETED":
query_attachment = next((r for r in resp["attachments"] if "query" in r), None)
if query_attachment:
query = query_attachment["query"]["query"]
description = query_attachment["query"].get("description", "")
return poll_query_results(query, description)
if resp["status"] == "COMPLETED":
text_content = next(r for r in resp["attachments"] if "text" in r)["text"][
"content"
]
return GenieResponse(result=text_content)
elif resp["status"] in ["FAILED", "CANCELLED", "QUERY_RESULT_EXPIRED"]:
logging.debug(f"Genie query {resp['status'].lower()}.")
return GenieResponse(result=None)
else:
logging.debug(f"No query result: {resp['state']}")
return None
logging.debug(f"Waiting...: {resp['status']}")
time.sleep(5)

return poll_result()

def ask_question(self, question):
resp = self.start_conversation(question)
# TODO (prithvi): return the query and the result
return self.poll_for_result(resp["conversation_id"], resp["message_id"])
28 changes: 14 additions & 14 deletions tests/databricks_ai_bridge/test_genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_poll_for_result_completed_with_text(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "COMPLETED", "attachments": [{"text": {"content": "Result"}}]},
]
result = genie.poll_for_result("123", "456")
assert result == "Result"
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result == "Result"


def test_poll_for_result_completed_with_query(genie, mock_workspace_client):
Expand All @@ -64,8 +64,8 @@ def test_poll_for_result_completed_with_query(genie, mock_workspace_client):
}
},
]
result = genie.poll_for_result("123", "456")
assert result == pd.DataFrame().to_markdown()
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result == pd.DataFrame().to_markdown()


def test_poll_for_result_executing_query(genie, mock_workspace_client):
Expand All @@ -84,32 +84,32 @@ def test_poll_for_result_executing_query(genie, mock_workspace_client):
}
},
]
result = genie.poll_for_result("123", "456")
assert result == pd.DataFrame().to_markdown()
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result == pd.DataFrame().to_markdown()


def test_poll_for_result_failed(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "FAILED"},
]
result = genie.poll_for_result("123", "456")
assert result is None
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result is None


def test_poll_for_result_cancelled(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "CANCELLED"},
]
result = genie.poll_for_result("123", "456")
assert result is None
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result is None


def test_poll_for_result_expired(genie, mock_workspace_client):
mock_workspace_client.genie._api.do.side_effect = [
{"status": "QUERY_RESULT_EXPIRED"},
]
result = genie.poll_for_result("123", "456")
assert result is None
genie_result = genie.poll_for_result("123", "456")
assert genie_result.result is None


def test_poll_for_result_max_iterations(genie, mock_workspace_client):
Expand Down Expand Up @@ -148,8 +148,8 @@ def test_ask_question(genie, mock_workspace_client):
{"conversation_id": "123", "message_id": "456"},
{"status": "COMPLETED", "attachments": [{"text": {"content": "Answer"}}]},
]
result = genie.ask_question("What is the meaning of life?")
assert result == "Answer"
genie_result = genie.ask_question("What is the meaning of life?")
assert genie_result.result == "Answer"


def test_parse_query_result_empty():
Expand Down

0 comments on commit f815b62

Please sign in to comment.