From 37ad67569055300cb3950c99276588c407711f21 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 17 Dec 2024 15:59:03 -0800 Subject: [PATCH] fix Signed-off-by: Prithvi Kannan --- integrations/langchain/src/databricks_langchain/genie.py | 3 ++- pyproject.toml | 2 +- src/databricks_ai_bridge/genie.py | 7 +++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index 088e85c..1b7c1f1 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -1,5 +1,6 @@ -from databricks_ai_bridge.genie import Genie import mlflow +from databricks_ai_bridge.genie import Genie + @mlflow.trace() def _concat_messages_array(messages): diff --git a/pyproject.toml b/pyproject.toml index 0ee60ec..9d68541 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "pandas", "tiktoken>=0.8.0", "tabulate", - "mlflow", + "mlflow-skinny", ] [project.license] diff --git a/src/databricks_ai_bridge/genie.py b/src/databricks_ai_bridge/genie.py index 0023929..8715b9e 100644 --- a/src/databricks_ai_bridge/genie.py +++ b/src/databricks_ai_bridge/genie.py @@ -3,12 +3,11 @@ from datetime import datetime from typing import Union +import mlflow import pandas as pd import tiktoken from databricks.sdk import WorkspaceClient -import mlflow - MAX_TOKENS_OF_DATA = 20000 # max tokens of data in markdown format MAX_ITERATIONS = 50 # max times to poll the API when polling for either result or the query results, each iteration is ~1 second, so max latency == 2 * MAX_ITERATIONS @@ -75,6 +74,7 @@ def __init__(self, space_id): "Content-Type": "application/json", } + @mlflow.trace() def start_conversation(self, content): resp = self.genie._api.do( "POST", @@ -84,6 +84,7 @@ def start_conversation(self, content): ) return resp + @mlflow.trace() def create_message(self, conversation_id, content): resp = self.genie._api.do( "POST", @@ -93,6 +94,7 @@ def create_message(self, conversation_id, content): ) return resp + @mlflow.trace() def poll_for_result(self, conversation_id, message_id): @mlflow.trace() def poll_result(): @@ -139,6 +141,7 @@ def poll_result(): logging.debug(f"Waiting...: {resp['status']}") time.sleep(5) + @mlflow.trace() def poll_query_results(): iteration_count = 0 while iteration_count < MAX_ITERATIONS: