Skip to content

Commit

Permalink
fix(prompt-lab): app shouldn't crash when no prompts found (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer authored Oct 2, 2024
1 parent fdc5790 commit 877b35f
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gradio as gr
import jinja2
from pydantic import BaseModel
from rich.console import Console

from ragbits.core.llms import LiteLLM
from ragbits.core.llms.clients import LiteLLMOptions
Expand Down Expand Up @@ -38,28 +39,6 @@ class PromptState:
temp_field_name: str = ""


def load_prompts_list(pattern: str, state: gr.State) -> gr.State:
"""
Fetches a list of prompts based on provided paths and updates the application state.
This function takes a path-pattern for discovering prompt definition files and uses the
PromptDiscovery class to discover prompts within those files. The discovered prompts are then
stored in the application state object.
Args:
pattern (str): A pattern for looking up prompt files.
state (gr.State): The Gradio state object to update with discovered prompts.
Returns:
gr.State: The updated Gradio state object containing the list of discovered prompts.
"""
obj = PromptDiscovery(file_pattern=pattern)
discovered_prompts = list(obj.discover())
state.value.prompts = discovered_prompts

return state


def render_prompt(
index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any
) -> tuple[str, str, gr.State]:
Expand Down Expand Up @@ -91,9 +70,7 @@ def render_prompt(
prompt_object = prompt_class(input_data=input_data)
state.current_prompt = prompt_object

chat_dict = {entry["role"]: entry["content"] for entry in prompt_object.chat}

return chat_dict["system"], chat_dict["user"], state
return prompt_object.system_message, prompt_object.user_message, state


def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]:
Expand Down Expand Up @@ -165,13 +142,24 @@ def lab_app( # pylint: disable=missing-param-doc
Launches the interactive application for listing, rendering, and testing prompts
defined within the current project.
"""
prompts = PromptDiscovery(file_pattern=file_pattern).discover()

if not prompts:
Console(stderr=True).print(
f"""No prompts were found for the given file pattern: [b]{file_pattern}[/b].
Please make sure that you are executing the command from the correct directory \
or provide a custom file pattern using the [b]--file-pattern[/b] flag."""
)
return

with gr.Blocks() as gr_app:
prompt_state_obj = PromptState()
prompt_state_obj.llm_model_name = llm_model
prompt_state_obj.llm_api_key = llm_api_key
prompt_state_obj.prompts = list(prompts)

prompts_state = gr.State(value=prompt_state_obj)
prompts_state = load_prompts_list(pattern=file_pattern, state=prompts_state)
prompt_selection_dropdown = gr.Dropdown(
choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt"
)
Expand Down

0 comments on commit 877b35f

Please sign in to comment.