Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Prithvi Kannan <[email protected]>
  • Loading branch information
prithvikannan committed Dec 17, 2024
1 parent cda59a7 commit 37ad675
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion integrations/langchain/src/databricks_langchain/genie.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from databricks_ai_bridge.genie import Genie
import mlflow
from databricks_ai_bridge.genie import Genie


@mlflow.trace()
def _concat_messages_array(messages):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies = [
"pandas",
"tiktoken>=0.8.0",
"tabulate",
"mlflow",
"mlflow-skinny",
]

[project.license]
Expand Down
7 changes: 5 additions & 2 deletions src/databricks_ai_bridge/genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from datetime import datetime
from typing import Union

import mlflow
import pandas as pd
import tiktoken
from databricks.sdk import WorkspaceClient

import mlflow

MAX_TOKENS_OF_DATA = 20000 # max tokens of data in markdown format
MAX_ITERATIONS = 50 # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS

Expand Down Expand Up @@ -75,6 +74,7 @@ def __init__(self, space_id):
"Content-Type": "application/json",
}

@mlflow.trace()
def start_conversation(self, content):
resp = self.genie._api.do(
"POST",
Expand All @@ -84,6 +84,7 @@ def start_conversation(self, content):
)
return resp

@mlflow.trace()
def create_message(self, conversation_id, content):
resp = self.genie._api.do(
"POST",
Expand All @@ -93,6 +94,7 @@ def create_message(self, conversation_id, content):
)
return resp

@mlflow.trace()
def poll_for_result(self, conversation_id, message_id):
@mlflow.trace()
def poll_result():
Expand Down Expand Up @@ -139,6 +141,7 @@ def poll_result():
logging.debug(f"Waiting...: {resp['status']}")
time.sleep(5)

@mlflow.trace()
def poll_query_results():
iteration_count = 0
while iteration_count < MAX_ITERATIONS:
Expand Down

0 comments on commit 37ad675

Please sign in to comment.