fix(prompt-lab): prevent crash when sending to LLM before rendering #57

Merged: 2 commits, Oct 3, 2024
Changes from 1 commit
89 changes: 48 additions & 41 deletions packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py
@@ -1,4 +1,5 @@
import asyncio
from dataclasses import dataclass, field, replace
from typing import Any

import gradio as gr
@@ -12,6 +13,7 @@
from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery


@dataclass(frozen=True)
class PromptState:
"""
Class to store the current state of the application.
@@ -20,28 +22,30 @@ class PromptState:

Attributes:
prompts (list): A list containing discovered prompts.
variable_values (dict): A dictionary to store values entered by the user for prompt input fields.
dynamic_tb (dict): A dictionary containing dynamically created textboxes based on prompt input fields.
current_prompt (Prompt): The currently processed Prompt object. This is created upon clicking the
"Render Prompt" button and reflects in the "Rendered Prompt" field.
It is used for communication with the LLM.
rendered_prompt (Prompt): The most recently rendered Prompt instance.
llm_model_name (str): The name of the selected LLM model.
llm_api_key (str | None): The API key for the chosen LLM model.
temp_field_name (str): Temporary field name used internally.
"""

prompts: list = []
variable_values: dict = {}
dynamic_tb: dict = {}
current_prompt: Prompt | None = None
prompts: list = field(default_factory=list)
rendered_prompt: Prompt | None = None
llm_model_name: str | None = None
llm_api_key: str | None = None
temp_field_name: str = ""

def copy(self, **kwargs: Any) -> "PromptState":
Collaborator:

tbh, for me using replace directly feels better. Also, copy suggests an exact copy, and here we do a copy-and-update, so it's confusing.

Collaborator (Author):

Ok, I'll switch to using replace directly.
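For reference, a minimal sketch of the two options discussed above: a `copy` wrapper versus calling `dataclasses.replace` directly on the frozen dataclass. The field name and values below are illustrative, not the actual `PromptState`.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class State:
    llm_model_name: str | None = None

    # Option in this commit: a thin wrapper that copies and updates in one call.
    def copy(self, **kwargs) -> "State":
        return replace(self, **kwargs)


s = State()
s_via_wrapper = s.copy(llm_model_name="gpt-4o")      # wrapper method
s_via_replace = replace(s, llm_model_name="gpt-4o")  # replace() called directly, as suggested
assert s_via_wrapper == s_via_replace                # both produce an updated copy
```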

"""
Creates a copy of the current state with updated values.

def render_prompt(
index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any
) -> tuple[str, str, gr.State]:
Args:
**kwargs (Any): The updated values to be applied to the copied state.

Returns:
PromptState: A copy of the current state with the updated values.
"""
return replace(self, **kwargs)


def render_prompt(index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any) -> gr.State:
"""
Renders a prompt based on the provided key, system prompt, user prompt, and input variables.

@@ -56,21 +60,20 @@ def render_prompt(
args (tuple): A tuple of input values for the prompt.

Returns:
tuple[str, str, PromptState]: A tuple containing the rendered system prompt, rendered user prompt,
and the updated application state.
gr.State: The updated application state object.
"""
variables = dict(zip(state.dynamic_tb.keys(), args))

prompt_class = state.prompts[index]
prompt_class.system_prompt_template = jinja2.Template(system_prompt)
prompt_class.user_prompt_template = jinja2.Template(user_prompt)

input_type = prompt_class.input_type
input_fields = get_input_type_fields(input_type)
variables = {field["field_name"]: value for field, value in zip(input_fields, args)}
input_data = input_type(**variables) if input_type is not None else None
prompt_object = prompt_class(input_data=input_data)
state.current_prompt = prompt_object
state = state.copy(rendered_prompt=prompt_object)

return prompt_object.system_message, prompt_object.user_message, state
return state


def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]:
@@ -102,12 +105,12 @@ def send_prompt_to_llm(state: gr.State) -> str:
Returns:
str: The response generated by the LLM.
"""
current_prompt = state.current_prompt

llm_client = LiteLLM(model_name=state.llm_model_name, api_key=state.llm_api_key)

try:
response = asyncio.run(llm_client.client.call(conversation=current_prompt.chat, options=LiteLLMOptions()))
response = asyncio.run(
llm_client.client.call(conversation=state.rendered_prompt.chat, options=LiteLLMOptions())
)
except Exception as e: # pylint: disable=broad-except
response = str(e)
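For context, a small illustration of the crash this PR prevents (the model name below is hypothetical): on a freshly created `PromptState`, `rendered_prompt` is still `None`, so reaching this call before rendering a prompt would fail when accessing `.chat`.

```python
# Illustration only, not part of the diff.
state = PromptState(llm_model_name="gpt-4o-mini")  # hypothetical model name
assert state.rendered_prompt is None
# state.rendered_prompt.chat  -> AttributeError: 'NoneType' object has no attribute 'chat'
```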

@@ -154,12 +157,14 @@ def lab_app( # pylint: disable=missing-param-doc
return

with gr.Blocks() as gr_app:
prompt_state_obj = PromptState()
prompt_state_obj.llm_model_name = llm_model
prompt_state_obj.llm_api_key = llm_api_key
prompt_state_obj.prompts = list(prompts)
prompts_state = gr.State(
PromptState(
llm_model_name=llm_model,
llm_api_key=llm_api_key,
prompts=list(prompts),
)
)

prompts_state = gr.State(value=prompt_state_obj)
prompt_selection_dropdown = gr.Dropdown(
choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt"
)
@@ -172,7 +177,6 @@ def show_split(index: int, state: gr.State) -> None:
with gr.Column(scale=1):
with gr.Tab("Inputs"):
input_fields: list = get_input_type_fields(prompt.input_type)
tb_dict = {}
for entry in input_fields:
with gr.Row():
var = gr.Textbox(
@@ -182,10 +186,6 @@ def show_split(index: int, state: gr.State) -> None:
)
list_of_vars.append(var)

tb_dict[entry["field_name"]] = var

state.dynamic_tb = tb_dict

render_prompt_button = gr.Button(value="Render prompts")

with gr.Column(scale=4):
@@ -197,9 +197,8 @@ def show_split(index: int, state: gr.State) -> None:
)

with gr.Column():
prompt_details_system_prompt_rendered = gr.Textbox(
label="Rendered System Prompt", value="", interactive=False
)
system_message = state.rendered_prompt.system_message if state.rendered_prompt else ""
gr.Textbox(label="Rendered System Prompt", value=system_message, interactive=False)

with gr.Row():
with gr.Column():
@@ -208,12 +207,19 @@ def show_split(index: int, state: gr.State) -> None:
)

with gr.Column():
prompt_details_user_prompt_rendered = gr.Textbox(
label="Rendered User Prompt", value="", interactive=False
)
user_message = state.rendered_prompt.user_message if state.rendered_prompt else ""
gr.Textbox(label="Rendered User Prompt", value=user_message, interactive=False)

llm_enabled = state.llm_model_name is not None
llm_request_button = gr.Button(value="Send to LLM", interactive=llm_enabled)
prompt_ready = state.rendered_prompt is not None
llm_request_button = gr.Button(
value="Send to LLM",
interactive=llm_enabled and prompt_ready,
)
gr.Markdown(
"To enable this button, select an LLM model when starting the app in CLI.", visible=not llm_enabled
)
gr.Markdown("To enable this button, render a prompt first.", visible=llm_enabled and not prompt_ready)
llm_prompt_response = gr.Textbox(lines=10, label="LLM response")

render_prompt_button.click(
@@ -225,8 +231,9 @@ def show_split(index: int, state: gr.State) -> None:
prompts_state,
*list_of_vars,
],
[prompt_details_system_prompt_rendered, prompt_details_user_prompt_rendered, prompts_state],
[prompts_state],
)
llm_request_button.click(send_prompt_to_llm, prompts_state, llm_prompt_response)
prompt_selection_dropdown.change(list_prompt_choices, prompts_state)

gr_app.launch()
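Taken together, the change moves the UI to an immutable state object that event handlers replace rather than mutate. A minimal self-contained sketch of that pattern, assuming a simplified state class and a plain Gradio app (not the actual ragbits code):

```python
from dataclasses import dataclass, field, replace

import gradio as gr


@dataclass(frozen=True)
class DemoState:
    rendered_prompt: str | None = None
    history: list = field(default_factory=list)


def render(text: str, state: DemoState) -> DemoState:
    # Return an updated copy instead of mutating the gr.State value in place.
    return replace(state, rendered_prompt=text, history=[*state.history, text])


with gr.Blocks() as demo:
    state = gr.State(DemoState())
    prompt_box = gr.Textbox(label="Prompt")
    render_button = gr.Button("Render")
    render_button.click(render, [prompt_box, state], [state])

# demo.launch()  # uncomment to run locally
```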