Skip to content

Commit

Permalink
fix(prompt-lab): prevent crash when sending to LLM before rendering
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Oct 2, 2024
1 parent 0c808cd commit 592ef56
Showing 1 changed file with 48 additions and 41 deletions.
89 changes: 48 additions & 41 deletions packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from dataclasses import dataclass, field, replace
from typing import Any

import gradio as gr
Expand All @@ -12,6 +13,7 @@
from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery


@dataclass(frozen=True)
class PromptState:
"""
Class to store the current state of the application.
Expand All @@ -20,28 +22,30 @@ class PromptState:
Attributes:
prompts (list): A list containing discovered prompts.
variable_values (dict): A dictionary to store values entered by the user for prompt input fields.
dynamic_tb (dict): A dictionary containing dynamically created textboxes based on prompt input fields.
current_prompt (Prompt): The currently processed Prompt object. This is created upon clicking the
"Render Prompt" button and reflects in the "Rendered Prompt" field.
It is used for communication with the LLM.
rendered_prompt (Prompt): The most recently rendered Prompt instance.
llm_model_name (str): The name of the selected LLM model.
llm_api_key (str | None): The API key for the chosen LLM model.
temp_field_name (str): Temporary field name used internally.
"""

prompts: list = []
variable_values: dict = {}
dynamic_tb: dict = {}
current_prompt: Prompt | None = None
prompts: list = field(default_factory=list)
rendered_prompt: Prompt | None = None
llm_model_name: str | None = None
llm_api_key: str | None = None
temp_field_name: str = ""

def copy(self, **kwargs: Any) -> "PromptState":
"""
Creates a copy of the current state with updated values.
def render_prompt(
index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any
) -> tuple[str, str, gr.State]:
Args:
**kwargs (Any): The updated values to be applied to the copied state.
Returns:
PromptState: A copy of the current state with the updated values.
"""
return replace(self, **kwargs)


def render_prompt(index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any) -> gr.State:
"""
Renders a prompt based on the provided key, system prompt, user prompt, and input variables.
Expand All @@ -56,21 +60,20 @@ def render_prompt(
args (tuple): A tuple of input values for the prompt.
Returns:
tuple[str, str, PromptState]: A tuple containing the rendered system prompt, rendered user prompt,
and the updated application state.
gr.State: The updated application state object.
"""
variables = dict(zip(state.dynamic_tb.keys(), args))

prompt_class = state.prompts[index]
prompt_class.system_prompt_template = jinja2.Template(system_prompt)
prompt_class.user_prompt_template = jinja2.Template(user_prompt)

input_type = prompt_class.input_type
input_fields = get_input_type_fields(input_type)
variables = {field["field_name"]: value for field, value in zip(input_fields, args)}
input_data = input_type(**variables) if input_type is not None else None
prompt_object = prompt_class(input_data=input_data)
state.current_prompt = prompt_object
state = state.copy(rendered_prompt=prompt_object)

return prompt_object.system_message, prompt_object.user_message, state
return state


def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]:
Expand Down Expand Up @@ -102,12 +105,12 @@ def send_prompt_to_llm(state: gr.State) -> str:
Returns:
str: The response generated by the LLM.
"""
current_prompt = state.current_prompt

llm_client = LiteLLM(model_name=state.llm_model_name, api_key=state.llm_api_key)

try:
response = asyncio.run(llm_client.client.call(conversation=current_prompt.chat, options=LiteLLMOptions()))
response = asyncio.run(
llm_client.client.call(conversation=state.rendered_prompt.chat, options=LiteLLMOptions())
)
except Exception as e: # pylint: disable=broad-except
response = str(e)

Expand Down Expand Up @@ -154,12 +157,14 @@ def lab_app( # pylint: disable=missing-param-doc
return

with gr.Blocks() as gr_app:
prompt_state_obj = PromptState()
prompt_state_obj.llm_model_name = llm_model
prompt_state_obj.llm_api_key = llm_api_key
prompt_state_obj.prompts = list(prompts)
prompts_state = gr.State(
PromptState(
llm_model_name=llm_model,
llm_api_key=llm_api_key,
prompts=list(prompts),
)
)

prompts_state = gr.State(value=prompt_state_obj)
prompt_selection_dropdown = gr.Dropdown(
choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt"
)
Expand All @@ -172,7 +177,6 @@ def show_split(index: int, state: gr.State) -> None:
with gr.Column(scale=1):
with gr.Tab("Inputs"):
input_fields: list = get_input_type_fields(prompt.input_type)
tb_dict = {}
for entry in input_fields:
with gr.Row():
var = gr.Textbox(
Expand All @@ -182,10 +186,6 @@ def show_split(index: int, state: gr.State) -> None:
)
list_of_vars.append(var)

tb_dict[entry["field_name"]] = var

state.dynamic_tb = tb_dict

render_prompt_button = gr.Button(value="Render prompts")

with gr.Column(scale=4):
Expand All @@ -197,9 +197,8 @@ def show_split(index: int, state: gr.State) -> None:
)

with gr.Column():
prompt_details_system_prompt_rendered = gr.Textbox(
label="Rendered System Prompt", value="", interactive=False
)
system_message = state.rendered_prompt.system_message if state.rendered_prompt else ""
gr.Textbox(label="Rendered System Prompt", value=system_message, interactive=False)

with gr.Row():
with gr.Column():
Expand All @@ -208,12 +207,19 @@ def show_split(index: int, state: gr.State) -> None:
)

with gr.Column():
prompt_details_user_prompt_rendered = gr.Textbox(
label="Rendered User Prompt", value="", interactive=False
)
user_message = state.rendered_prompt.user_message if state.rendered_prompt else ""
gr.Textbox(label="Rendered User Prompt", value=user_message, interactive=False)

llm_enabled = state.llm_model_name is not None
llm_request_button = gr.Button(value="Send to LLM", interactive=llm_enabled)
prompt_ready = state.rendered_prompt is not None
llm_request_button = gr.Button(
value="Send to LLM",
interactive=llm_enabled and prompt_ready,
)
gr.Markdown(
"To enable this button, select an LLM model when starting the app in CLI.", visible=not llm_enabled
)
gr.Markdown("To enable this button, render a prompt first.", visible=llm_enabled and not prompt_ready)
llm_prompt_response = gr.Textbox(lines=10, label="LLM response")

render_prompt_button.click(
Expand All @@ -225,8 +231,9 @@ def show_split(index: int, state: gr.State) -> None:
prompts_state,
*list_of_vars,
],
[prompt_details_system_prompt_rendered, prompt_details_user_prompt_rendered, prompts_state],
[prompts_state],
)
llm_request_button.click(send_prompt_to_llm, prompts_state, llm_prompt_response)
prompt_selection_dropdown.change(list_prompt_choices, prompts_state)

gr_app.launch()

0 comments on commit 592ef56

Please sign in to comment.