From 2857978dff681cd583a782a13002aafe0dca4f85 Mon Sep 17 00:00:00 2001
From: Ludwik Trammer
Date: Thu, 3 Oct 2024 11:18:41 +0200
Subject: [PATCH] fix(prompt-lab): prevent crash when sending to LLM before rendering (#57)

---
 .../src/ragbits/dev_kit/prompt_lab/app.py | 76 +++++++++----------
 1 file changed, 35 insertions(+), 41 deletions(-)

diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py
index a6db0b904..8b494ae4c 100644
--- a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py
+++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py
@@ -1,4 +1,5 @@
 import asyncio
+from dataclasses import dataclass, field, replace
 from typing import Any
 
 import gradio as gr
@@ -12,6 +13,7 @@
 from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery
 
 
+@dataclass(frozen=True)
 class PromptState:
     """
     Class to store the current state of the application.
@@ -20,28 +22,18 @@ class PromptState:
 
     Attributes:
         prompts (list): A list containing discovered prompts.
-        variable_values (dict): A dictionary to store values entered by the user for prompt input fields.
-        dynamic_tb (dict): A dictionary containing dynamically created textboxes based on prompt input fields.
-        current_prompt (Prompt): The currently processed Prompt object. This is created upon clicking the
-            "Render Prompt" button and reflects in the "Rendered Prompt" field.
-            It is used for communication with the LLM.
+        rendered_prompt (Prompt): The most recently rendered Prompt instance.
         llm_model_name (str): The name of the selected LLM model.
         llm_api_key (str | None): The API key for the chosen LLM model.
-        temp_field_name (str): Temporary field name used internally.
     """
 
-    prompts: list = []
-    variable_values: dict = {}
-    dynamic_tb: dict = {}
-    current_prompt: Prompt | None = None
+    prompts: list = field(default_factory=list)
+    rendered_prompt: Prompt | None = None
     llm_model_name: str | None = None
     llm_api_key: str | None = None
-    temp_field_name: str = ""
 
 
-def render_prompt(
-    index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any
-) -> tuple[str, str, gr.State]:
+def render_prompt(index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any) -> gr.State:
     """
     Renders a prompt based on the provided key, system prompt, user prompt, and input variables.
 
@@ -56,21 +48,20 @@ def render_prompt(
         args (tuple): A tuple of input values for the prompt.
 
     Returns:
-        tuple[str, str, PromptState]: A tuple containing the rendered system prompt, rendered user prompt,
-            and the updated application state.
+        gr.State: The updated application state object.
""" - variables = dict(zip(state.dynamic_tb.keys(), args)) - prompt_class = state.prompts[index] prompt_class.system_prompt_template = jinja2.Template(system_prompt) prompt_class.user_prompt_template = jinja2.Template(user_prompt) input_type = prompt_class.input_type + input_fields = get_input_type_fields(input_type) + variables = {field["field_name"]: value for field, value in zip(input_fields, args)} input_data = input_type(**variables) if input_type is not None else None prompt_object = prompt_class(input_data=input_data) - state.current_prompt = prompt_object + state = replace(state, rendered_prompt=prompt_object) - return prompt_object.system_message, prompt_object.user_message, state + return state def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]: @@ -102,12 +93,12 @@ def send_prompt_to_llm(state: gr.State) -> str: Returns: str: The response generated by the LLM. """ - current_prompt = state.current_prompt - llm_client = LiteLLM(model_name=state.llm_model_name, api_key=state.llm_api_key) try: - response = asyncio.run(llm_client.client.call(conversation=current_prompt.chat, options=LiteLLMOptions())) + response = asyncio.run( + llm_client.client.call(conversation=state.rendered_prompt.chat, options=LiteLLMOptions()) + ) except Exception as e: # pylint: disable=broad-except response = str(e) @@ -154,12 +145,14 @@ def lab_app( # pylint: disable=missing-param-doc return with gr.Blocks() as gr_app: - prompt_state_obj = PromptState() - prompt_state_obj.llm_model_name = llm_model - prompt_state_obj.llm_api_key = llm_api_key - prompt_state_obj.prompts = list(prompts) + prompts_state = gr.State( + PromptState( + llm_model_name=llm_model, + llm_api_key=llm_api_key, + prompts=list(prompts), + ) + ) - prompts_state = gr.State(value=prompt_state_obj) prompt_selection_dropdown = gr.Dropdown( choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt" ) @@ -172,7 +165,6 @@ def show_split(index: int, state: gr.State) -> None: with gr.Column(scale=1): with gr.Tab("Inputs"): input_fields: list = get_input_type_fields(prompt.input_type) - tb_dict = {} for entry in input_fields: with gr.Row(): var = gr.Textbox( @@ -182,10 +174,6 @@ def show_split(index: int, state: gr.State) -> None: ) list_of_vars.append(var) - tb_dict[entry["field_name"]] = var - - state.dynamic_tb = tb_dict - render_prompt_button = gr.Button(value="Render prompts") with gr.Column(scale=4): @@ -197,9 +185,8 @@ def show_split(index: int, state: gr.State) -> None: ) with gr.Column(): - prompt_details_system_prompt_rendered = gr.Textbox( - label="Rendered System Prompt", value="", interactive=False - ) + system_message = state.rendered_prompt.system_message if state.rendered_prompt else "" + gr.Textbox(label="Rendered System Prompt", value=system_message, interactive=False) with gr.Row(): with gr.Column(): @@ -208,12 +195,19 @@ def show_split(index: int, state: gr.State) -> None: ) with gr.Column(): - prompt_details_user_prompt_rendered = gr.Textbox( - label="Rendered User Prompt", value="", interactive=False - ) + user_message = state.rendered_prompt.user_message if state.rendered_prompt else "" + gr.Textbox(label="Rendered User Prompt", value=user_message, interactive=False) llm_enabled = state.llm_model_name is not None - llm_request_button = gr.Button(value="Send to LLM", interactive=llm_enabled) + prompt_ready = state.rendered_prompt is not None + llm_request_button = gr.Button( + value="Send to LLM", + interactive=llm_enabled and prompt_ready, + ) + gr.Markdown( + "To enable this button, select an 
LLM model when starting the app in CLI.", visible=not llm_enabled + ) + gr.Markdown("To enable this button, render a prompt first.", visible=llm_enabled and not prompt_ready) llm_prompt_response = gr.Textbox(lines=10, label="LLM response") render_prompt_button.click( @@ -225,7 +219,7 @@ def show_split(index: int, state: gr.State) -> None: prompts_state, *list_of_vars, ], - [prompt_details_system_prompt_rendered, prompt_details_user_prompt_rendered, prompts_state], + [prompts_state], ) llm_request_button.click(send_prompt_to_llm, prompts_state, llm_prompt_response)
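
For reviewers, a minimal sketch of the state-handling pattern this patch introduces: PromptState becomes a frozen dataclass, so callbacks cannot mutate it in place and instead return an updated copy built with dataclasses.replace(), and the "Send to LLM" button is only enabled once rendered_prompt is set. The snippet below uses plain-Python stand-ins (a string instead of the rendered Prompt object, no Gradio widgets); apart from PromptState, rendered_prompt, and llm_model_name, the names and values are illustrative only.

    from dataclasses import dataclass, field, replace


    @dataclass(frozen=True)
    class PromptState:
        # Immutable state shared between callbacks, as in the patch.
        prompts: list = field(default_factory=list)
        rendered_prompt: str | None = None  # stand-in for the rendered Prompt object
        llm_model_name: str | None = None


    def render_prompt(state: PromptState, rendered: str) -> PromptState:
        # Handlers return a new state instead of mutating the existing one.
        return replace(state, rendered_prompt=rendered)


    def send_button_enabled(state: PromptState) -> bool:
        # Mirrors the patch: the button needs both an LLM model and a rendered prompt.
        return state.llm_model_name is not None and state.rendered_prompt is not None


    state = PromptState(llm_model_name="example-model")
    assert not send_button_enabled(state)  # before rendering: button stays disabled

    state = render_prompt(state, "Hello, model!")
    assert send_button_enabled(state)  # after rendering: button can be enabled

Because the dataclass is frozen, an in-place assignment such as state.rendered_prompt = ... raises dataclasses.FrozenInstanceError, which is why render_prompt() in the patch returns the new state through gr.State rather than relying on mutation.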