Skip to content

Commit

Permalink
Genie tracing
Browse files Browse the repository at this point in the history
Signed-off-by: Prithvi Kannan <[email protected]>
  • Loading branch information
prithvikannan committed Dec 12, 2024
1 parent c9946ad commit cda59a7
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion integrations/langchain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pip install databricks-langchain
### Install from source

```sh
pip install git+ssh://[email protected]/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain
pip install git+https://[email protected]/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain
```

## Get started
Expand Down
5 changes: 3 additions & 2 deletions integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from databricks_ai_bridge.genie import Genie
import mlflow


@mlflow.trace()
def _concat_messages_array(messages):
concatenated_message = "\n".join(
[
Expand All @@ -12,7 +13,7 @@ def _concat_messages_array(messages):
)
return concatenated_message


@mlflow.trace()
def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
from langchain_core.messages import AIMessage

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"pandas",
"tiktoken>=0.8.0",
"tabulate",
"mlflow",
]

[project.license]
Expand Down
5 changes: 4 additions & 1 deletion src/databricks_ai_bridge/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import tiktoken
from databricks.sdk import WorkspaceClient

import mlflow

MAX_TOKENS_OF_DATA = 20000 # max tokens of data in markdown format
MAX_ITERATIONS = 50 # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS

Expand All @@ -16,7 +18,7 @@ def _count_tokens(text):
encoding = tiktoken.encoding_for_model("gpt-4o")
return len(encoding.encode(text))


@mlflow.trace()
def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
columns = resp["manifest"]["schema"]["columns"]
header = [str(col["name"]) for col in columns]
Expand Down Expand Up @@ -92,6 +94,7 @@ def create_message(self, conversation_id, content):
return resp

def poll_for_result(self, conversation_id, message_id):
@mlflow.trace()
def poll_result():
iteration_count = 0
while iteration_count < MAX_ITERATIONS:
Expand Down

0 comments on commit cda59a7

Please sign in to comment.