Updated sql migration assistant UI (#260)
Revamps the UI for the migration assistant and adds workflow automation.
robertwhiffin authored Oct 2, 2024
1 parent 3b1fdf2 commit 7348f7d
Showing 19 changed files with 1,639 additions and 299 deletions.
1 change: 1 addition & 0 deletions sql_migration_assistant/__init__.py
@@ -9,6 +9,7 @@ def hello():
    w = WorkspaceClient(product="sql_migration_assistant", product_version="0.0.1")
    p = Prompts()
    setter_upper = SetUpMigrationAssistant()
+    setter_upper.check_cloud(w)
    final_config = setter_upper.setup_migration_assistant(w, p)
    current_path = Path(__file__).parent.resolve()

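For context, a minimal sketch of how this entry point might be invoked after the change; `hello` is the function named in the hunk header above, while the import path and the __main__ guard are assumptions for illustration.

# Hedged usage sketch: assumes `hello` is importable from the package root.
from sql_migration_assistant import hello

if __name__ == "__main__":
    hello()  # the new check_cloud(w) call now runs before setup_migration_assistant(w, p)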
38 changes: 21 additions & 17 deletions sql_migration_assistant/app/llm.py
@@ -1,24 +1,15 @@
import logging
+import gradio as gr

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole

-w = WorkspaceClient()
-foundation_llm_name = "databricks-meta-llama-3-1-405b-instruct"
-max_token = 4096
-messages = [
-    ChatMessage(role=ChatMessageRole.SYSTEM, content="You are an unhelpful assistant"),
-    ChatMessage(role=ChatMessageRole.USER, content="What is RAG?"),
-]


class LLMCalls:
-    def __init__(self, foundation_llm_name, max_tokens):
+    def __init__(self, foundation_llm_name):
        self.w = WorkspaceClient()
        self.foundation_llm_name = foundation_llm_name
-        self.max_tokens = int(max_tokens)

-    def call_llm(self, messages):
+    def call_llm(self, messages, max_tokens, temperature):
        """
        Function to call the LLM model and return the response.
        :param messages: list of messages like
@@ -29,8 +20,17 @@ def call_llm(self, messages):
            ]
        :return: the response from the model
        """

+        max_tokens = int(max_tokens)
+        temperature = float(temperature)
+        # check to make sure temperature is between 0.0 and 1.0
+        if temperature < 0.0 or temperature > 1.0:
+            raise gr.Error("Temperature must be between 0.0 and 1.0")
        response = self.w.serving_endpoints.query(
-            name=foundation_llm_name, max_tokens=max_token, messages=messages
+            name=self.foundation_llm_name,
+            max_tokens=max_tokens,
+            messages=messages,
+            temperature=temperature,
        )
        message = response.choices[0].message.content
        return message
@@ -53,14 +53,16 @@ def convert_chat_to_llm_input(self, system_prompt, chat):

    # this is called to actually send a request and receive response from the llm endpoint.

-    def llm_translate(self, system_prompt, input_code):
+    def llm_translate(self, system_prompt, input_code, max_tokens, temperature):
        messages = [
            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
            ChatMessage(role=ChatMessageRole.USER, content=input_code),
        ]

        # call the LLM end point.
-        llm_answer = self.call_llm(messages=messages)
+        llm_answer = self.call_llm(
+            messages=messages, max_tokens=max_tokens, temperature=temperature
+        )
        # Extract the code from in between the triple backticks (```), since LLM often prints the code like this.
        # Also removes the 'sql' prefix always added by the LLM.
        translation = llm_answer  # .split("Final answer:\n")[1].replace(">>", "").replace("<<", "")
@@ -73,12 +75,14 @@ def llm_chat(self, system_prompt, query, chat_history):
        llm_answer = self.call_llm(messages=messages)
        return llm_answer

-    def llm_intent(self, system_prompt, input_code):
+    def llm_intent(self, system_prompt, input_code, max_tokens, temperature):
        messages = [
            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
            ChatMessage(role=ChatMessageRole.USER, content=input_code),
        ]

        # call the LLM end point.
-        llm_answer = self.call_llm(messages=messages)
+        llm_answer = self.call_llm(
+            messages=messages, max_tokens=max_tokens, temperature=temperature
+        )
        return llm_answer
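
For reference, a minimal usage sketch of the revised LLMCalls interface, in which max_tokens and temperature are now supplied per call rather than at construction time. The import path, prompts, inputs, and parameter values below are assumptions; the endpoint name is the constant removed from module scope in this commit.

# Hedged usage sketch: LLMCalls, llm_translate, and llm_intent come from the diff above;
# prompts, inputs, and parameter values are placeholders for illustration only.
from sql_migration_assistant.app.llm import LLMCalls

llm = LLMCalls(foundation_llm_name="databricks-meta-llama-3-1-405b-instruct")

translation = llm.llm_translate(
    system_prompt="Translate the following T-SQL to Databricks SQL.",
    input_code="SELECT TOP 10 * FROM sales",
    max_tokens=4096,
    temperature=0.0,
)

intent = llm.llm_intent(
    system_prompt="Summarise the intent of the following SQL.",
    input_code="SELECT TOP 10 * FROM sales",
    max_tokens=2048,
    temperature=0.1,
)

Note that call_llm now rejects temperatures outside 0.0 to 1.0 with a gr.Error, so values in the sketch are kept within that range.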