Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Genie tracing #32

Merged
merged 8 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/langchain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pip install databricks-langchain
### Install from source

```sh
pip install git+ssh://[email protected]/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain
pip install git+https://[email protected]/databricks/databricks-ai-bridge.git#subdirectory=integrations/langchain
```

## Get started
Expand Down
5 changes: 3 additions & 2 deletions integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from databricks_ai_bridge.genie import Genie
import mlflow


@mlflow.trace()
def _concat_messages_array(messages):
concatenated_message = "\n".join(
[
Expand All @@ -12,7 +13,7 @@ def _concat_messages_array(messages):
)
return concatenated_message


@mlflow.trace()
def _query_genie_as_agent(input, genie_space_id, genie_agent_name):
from langchain_core.messages import AIMessage

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"pandas",
"tiktoken>=0.8.0",
"tabulate",
"mlflow",
prithvikannan marked this conversation as resolved.
Show resolved Hide resolved
]

[project.license]
Expand Down
5 changes: 4 additions & 1 deletion src/databricks_ai_bridge/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import tiktoken
from databricks.sdk import WorkspaceClient

import mlflow

MAX_TOKENS_OF_DATA = 20000 # max tokens of data in markdown format
MAX_ITERATIONS = 50 # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS

Expand All @@ -16,7 +18,7 @@ def _count_tokens(text):
encoding = tiktoken.encoding_for_model("gpt-4o")
return len(encoding.encode(text))


@mlflow.trace()
def _parse_query_result(resp) -> Union[str, pd.DataFrame]:
columns = resp["manifest"]["schema"]["columns"]
header = [str(col["name"]) for col in columns]
Expand Down Expand Up @@ -92,6 +94,7 @@ def create_message(self, conversation_id, content):
return resp

def poll_for_result(self, conversation_id, message_id):
@mlflow.trace()
def poll_result():
iteration_count = 0
while iteration_count < MAX_ITERATIONS:
Expand Down
Loading