diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py
index 540d97e..228ae3c 100644
--- a/src/databricks_ai_bridge/genie.py
+++ b/src/databricks_ai_bridge/genie.py
@@ -54,18 +54,11 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
     query_result = pd.DataFrame(rows, columns=header).to_markdown()
 
     # trim down from the total rows until we get under the token limit
-    trimmed_rows = len(rows)
     tokens_used = _count_tokens(query_result)
-    while trimmed_rows > 0 and tokens_used > MAX_TOKENS_OF_DATA:
-        # convert to markdown
-        query_result = pd.DataFrame(rows, columns=header).head(trimmed_rows).to_markdown()
-        # keep trimming down until we get under the token limit
-        trimmed_rows -= 5
-        # worst case, return None, which the Agent will handle and not display the query results
+    while tokens_used > MAX_TOKENS_OF_DATA:
+        rows.pop()
+        query_result = pd.DataFrame(rows, columns=header).to_markdown()
         tokens_used = _count_tokens(query_result)
-    if trimmed_rows == 0:
-        query_result = None
-        tokens_used = 0
 
     return query_result.strip() if query_result else query_result
 
@@ -102,6 +95,7 @@ def poll_for_result(self, conversation_id, message_id):
         def poll_result():
             iteration_count = 0
             while True and iteration_count < MAX_ITERATIONS:
+                iteration_count += 1
                 resp = self.genie._api.do(
                     "GET",
                     f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
@@ -126,6 +120,7 @@
         def poll_query_results():
             iteration_count = 0
             while True and iteration_count < MAX_ITERATIONS:
+                iteration_count += 1
                 resp = self.genie._api.do(
                     "GET",
                     f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/query-result",
diff --git a/tests/databricks_ai_bridge/test_genie.py b/tests/databricks_ai_bridge/test_genie.py
index 0d1cc2f..3957a32 100644
--- a/tests/databricks_ai_bridge/test_genie.py
+++ b/tests/databricks_ai_bridge/test_genie.py
@@ -4,7 +4,7 @@
 import pandas as pd
 import pytest
 
-from databricks_ai_bridge.genie import Genie, _parse_query_result
+from databricks_ai_bridge.genie import Genie, _parse_query_result, _count_tokens
 
 
 @pytest.fixture
@@ -67,6 +67,38 @@ def test_poll_for_result_executing_query(genie, mock_workspace_client):
     result = genie.poll_for_result("123", "456")
     assert result == pd.DataFrame().to_markdown()
 
+def test_poll_for_result_failed(genie, mock_workspace_client):
+    mock_workspace_client.genie._api.do.side_effect = [
+        {"status": "FAILED"},
+    ]
+    result = genie.poll_for_result("123", "456")
+    assert result is None
+
+
+def test_poll_for_result_max_iterations(genie, mock_workspace_client):
+    # patch MAX_ITERATIONS to 2 for this test and patch time.sleep to avoid delays
+    with patch("databricks_ai_bridge.genie.MAX_ITERATIONS", 2), \
+         patch("time.sleep", return_value=None):
+        mock_workspace_client.genie._api.do.side_effect = [
+            {"status": "EXECUTING_QUERY", "attachments": [{"query": {"query": "SELECT *"}}]},
+            {
+                "statement_response": {
+                    "status": {"state": "RUNNING"},
+                }
+            },
+            {
+                "statement_response": {
+                    "status": {"state": "RUNNING"},
+                }
+            },
+            {
+                "statement_response": {
+                    "status": {"state": "RUNNING"},
+                }
+            }
+        ]
+        result = genie.poll_for_result("123", "456")
+        assert result is None
 
 def test_ask_question(genie, mock_workspace_client):
     mock_workspace_client.genie._api.do.side_effect = [
@@ -139,3 +171,45 @@
         }
     )
     assert result == expected_df.to_markdown()
+
+def test_parse_query_result_trims_large_data():
+    # patch MAX_TOKENS_OF_DATA to 100 for this test
+    with patch("databricks_ai_bridge.genie.MAX_TOKENS_OF_DATA", 100):
+        resp = {
+            "manifest": {
+                "schema": {
+                    "columns": [
+                        {"name": "id", "type_name": "INT"},
+                        {"name": "name", "type_name": "STRING"},
+                        {"name": "created_at", "type_name": "TIMESTAMP"},
+                    ]
+                }
+            },
+            "result": {
+                "data_typed_array": [
+                    {"values": [{"str": "1"}, {"str": "Alice"}, {"str": "2023-10-01T00:00:00Z"}]},
+                    {"values": [{"str": "2"}, {"str": "Bob"}, {"str": "2023-10-02T00:00:00Z"}]},
+                    {"values": [{"str": "3"}, {"str": "Charlie"}, {"str": "2023-10-03T00:00:00Z"}]},
+                    {"values": [{"str": "4"}, {"str": "David"}, {"str": "2023-10-04T00:00:00Z"}]},
+                    {"values": [{"str": "5"}, {"str": "Eve"}, {"str": "2023-10-05T00:00:00Z"}]},
+                    {"values": [{"str": "6"}, {"str": "Frank"}, {"str": "2023-10-06T00:00:00Z"}]},
+                    {"values": [{"str": "7"}, {"str": "Grace"}, {"str": "2023-10-07T00:00:00Z"}]},
+                    {"values": [{"str": "8"}, {"str": "Hank"}, {"str": "2023-10-08T00:00:00Z"}]},
+                    {"values": [{"str": "9"}, {"str": "Ivy"}, {"str": "2023-10-09T00:00:00Z"}]},
+                    {"values": [{"str": "10"}, {"str": "Jack"}, {"str": "2023-10-10T00:00:00Z"}]},
+                ]
+            },
+        }
+        result = _parse_query_result(resp)
+        assert result == pd.DataFrame(
+            {
+                "id": [1, 2, 3],
+                "name": ["Alice", "Bob", "Charlie"],
+                "created_at": [
+                    datetime(2023, 10, 1).date(),
+                    datetime(2023, 10, 2).date(),
+                    datetime(2023, 10, 3).date(),
+                ]
+            }
+        ).to_markdown()
+        assert _count_tokens(result) <= 100
\ No newline at end of file