Skip to content

Commit

Permalink
allow users to enter custom instructions for llm
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Jul 15, 2024
1 parent dda75f6 commit 9036519
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
27 changes: 24 additions & 3 deletions wren-ai-service/eval/data_curation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,20 @@
st.title("WrenAI Data Curation App")


LLM_OPTIONS = ["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o"]

llm_client = get_openai_client()


# session states
if "llm_model" not in st.session_state:
st.session_state["llm_model"] = None
st.session_state["llm_model"] = LLM_OPTIONS[0]
if "deployment_id" not in st.session_state:
st.session_state["deployment_id"] = str(uuid.uuid4())
if "mdl_json" not in st.session_state:
st.session_state["mdl_json"] = None
if "custom_instructions_for_llm" not in st.session_state:
st.session_state["custom_instructions_for_llm"] = ""
if "llm_question_sql_pairs" not in st.session_state:
st.session_state["llm_question_sql_pairs"] = []
if "user_question_sql_pair" not in st.session_state:
Expand All @@ -54,11 +58,20 @@ def on_change_upload_eval_dataset():
st.session_state["candidate_dataset"] = doc["eval_dataset"]


def on_change_custom_instructions_for_llm():
st.session_state["custom_instructions_for_llm"] = st.session_state[
"custom_instructions_text_area"
]


def on_click_generate_question_sql_pairs(llm_client: AsyncClient):
st.toast("Generating question-sql-pairs...")
st.session_state["llm_question_sql_pairs"] = asyncio.run(
get_question_sql_pairs(
llm_client, st.session_state["llm_model"], st.session_state["mdl_json"]
llm_client,
st.session_state["llm_model"],
st.session_state["mdl_json"],
st.session_state["custom_instructions_for_llm"],
)
)

Expand Down Expand Up @@ -178,7 +191,7 @@ def on_click_remove_candidate_dataset_button(i: int):

st.selectbox(
label="Select which LLM model you want to use",
options=["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o"],
options=LLM_OPTIONS,
index=0,
key="select_llm_model",
on_change=on_change_llm_model,
Expand Down Expand Up @@ -229,6 +242,14 @@ def on_click_remove_candidate_dataset_button(i: int):
)

with tab_generated_by_llm:
st.text_area(
"Custom Instructions for generating question-sql-pairs (Optional)",
key="custom_instructions_text_area",
value=st.session_state["custom_instructions_for_llm"],
placeholder="You can specify the custom instructions on how LLM should generate question-sql-pairs here, for example: what type of questions you want to generate.",
on_change=on_change_custom_instructions_for_llm,
)

st.button(
"Generate 10 question-sql-pairs",
key="generate_question_sql_pairs",
Expand Down
20 changes: 14 additions & 6 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,11 @@ def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]):
if "sortings" in result:
contexts += _compose_contexts_of_sorting_type(result["sortings"])

print(
f'SQL ANALYSIS RESULTS: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode("utf-8")}'
)
print(f"CONTEXTS: {sorted(set(contexts))}")
print("\n\n")
# print(
# f'SQL ANALYSIS RESULTS: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode("utf-8")}'
# )
# print(f"CONTEXTS: {sorted(set(contexts))}")
# print("\n\n")

return sorted(set(contexts))

Expand All @@ -301,7 +301,11 @@ def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]):


async def get_question_sql_pairs(
llm_client: AsyncClient, llm_model: str, mdl_json: dict, num_pairs: int = 10
llm_client: AsyncClient,
llm_model: str,
mdl_json: dict,
custom_instructions: str,
num_pairs: int = 10,
) -> list[dict]:
messages = [
{
Expand Down Expand Up @@ -329,6 +333,10 @@ async def get_question_sql_pairs(
]
}}
### Custom Instructions ###
{custom_instructions}
### Input ###
Data Model: {get_ddl_commands(mdl_json)}
Expand Down

0 comments on commit 9036519

Please sign in to comment.