Skip to content

Commit

Permalink
feat(prompt-discovery): Look for prompts in arbitrary files using pat…
Browse files Browse the repository at this point in the history
…terns
  • Loading branch information
ludwiktrammer committed Sep 25, 2024
1 parent 7833615 commit 83c0272
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 163 deletions.
140 changes: 62 additions & 78 deletions packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from ragbits.core.llms import LiteLLM
from ragbits.core.llms.clients import LiteLLMOptions
from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import PromptDiscovery
from ragbits.core.prompt import Prompt
from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery


class PromptState:
Expand All @@ -18,8 +19,8 @@ class PromptState:
This class holds various data structures used throughout the application's lifecycle.
Attributes:
prompts_state (dict): A dictionary containing discovered prompts. Keys are prompt names,
values are details about the corresponding Prompt-based class.
prompts (list): A list containing discovered prompts.
prompts_names (list): A list of names of discovered prompts.
variable_values (dict): A dictionary to store values entered by the user for prompt input fields.
dynamic_tb (dict): A dictionary containing dynamically created textboxes based on prompt input fields.
current_prompt (Prompt): The currently processed Prompt object. This is created upon clicking the
Expand All @@ -30,64 +31,42 @@ class PromptState:
temp_field_name (str): Temporary field name used internally.
"""

prompts_state: dict = {}
prompts: list = []
prompts_names: list = []
variable_values: dict = {}
dynamic_tb: dict = {}
current_prompt = None
llm_model_name: str = ""
llm_api_key: str | None = ""
current_prompt: Prompt | None = None
llm_model_name: str | None = None
llm_api_key: str | None = None
temp_field_name: str = ""


def get_prompts_list(path: str, state: gr.State) -> gr.State:
def load_prompts_list(pattern: str, state: gr.State) -> gr.State:
"""
Fetches a list of prompts based on provided paths.
Fetches a list of prompts based on provided paths and updates the application state.
This function takes a comma-separated string of paths to prompt definition files and uses the
This function takes a path-pattern for discovering prompt definition files and uses the
PromptDiscovery class to discover prompts within those files. The discovered prompts are then
stored in the application state object.
Args:
path (str): A comma-separated string of paths to prompt definition files.
pattern (str): A pattern for looking up prompt files.
state (gr.State): The Gradio state object to update with discovered prompts.
Returns:
gr.State: The updated Gradio state object containing the list of discovered prompts.
"""
prompt_paths: list[str] = path.split(",")

obj = PromptDiscovery(file_paths=prompt_paths)
discovered_prompts = obj.discover()
state.value.prompts_state.update(discovered_prompts)
obj = PromptDiscovery(file_pattern=pattern)
discovered_prompts = list(obj.discover())
state.value.prompts = discovered_prompts
names = [prompt.__name__ for prompt in discovered_prompts]
state.value.prompts_names = names

return state


def display_data(key: str, state: PromptState) -> tuple[gr.Textbox, gr.Textbox]:
"""
Displays system and user prompts for a given key from the prompts state.
This function retrieves the system and user prompt text associated with a particular prompt key
from the application state and creates Gradio Textbox elements for them.
Args:
key (str): The key of the prompt to display data for in the prompts state.
state (PromptState): The application state object containing prompt data.
Returns:
tuple[gr.Textbox, gr.Textbox]: A tuple containing two Gradio Textbox elements,
one for the system prompt and one for the user prompt.
"""
data = state.prompts_state[key]

system_prompt_text = gr.Textbox(label="System Prompt", value=data["system_prompt"])
user_prompt_text = gr.Textbox(label="User Prompt", value=data["user_prompt"])

return system_prompt_text, user_prompt_text


def render_prompt(
key: str, system_prompt: str, user_prompt: str, state: gr.State, *args: Any
index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any
) -> tuple[str, str, gr.State]:
"""
Renders a prompt based on the provided key, system prompt, user prompt, and input variables.
Expand All @@ -96,7 +75,7 @@ def render_prompt(
associated with the given key. It then updates the current prompt in the application state.
Args:
key (str): The key of the prompt to render.
index (int): The index of the prompt to render in the prompts state.
system_prompt (str): The system prompt template for the prompt.
user_prompt (str): The user prompt template for the prompt.
state (PromptState): The application state object.
Expand All @@ -108,20 +87,36 @@ def render_prompt(
"""
variables = dict(zip(state.dynamic_tb.keys(), args))

prompt_constructor = state.prompts_state[key]["object"]
prompt_constructor.system_prompt_template = jinja2.Template(system_prompt)
prompt_constructor.user_prompt_template = jinja2.Template(user_prompt)
prompt_class = state.prompts[index]
prompt_class.system_prompt_template = jinja2.Template(system_prompt)
prompt_class.user_prompt_template = jinja2.Template(user_prompt)

input_constructor = state.prompts_state[key]["input_type"]

prompt_object = prompt_constructor(input_data=input_constructor(**variables))
input_type = prompt_class.input_type
input_data = input_type(**variables) if input_type is not None else None
prompt_object = prompt_class(input_data=input_data)
state.current_prompt = prompt_object

chat_dict = {entry["role"]: entry["content"] for entry in prompt_object.chat}

return chat_dict["system"], chat_dict["user"], state


def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]:
"""
Returns a list of prompt choices based on the discovered prompts.
This function generates a list of tuples containing the names of discovered prompts and their
corresponding indices.
Args:
state (gr.State): The application state object.
Returns:
list[tuple[str, int]]: A list of tuples containing prompt names and their indices.
"""
return [(name, idx) for idx, name in enumerate(state.value.prompts_names)]


def send_prompt_to_llm(state: gr.State) -> str:
"""
Sends the current prompt to the LLM and returns the response.
Expand All @@ -147,7 +142,7 @@ def send_prompt_to_llm(state: gr.State) -> str:
return response


def get_input_type_fields(obj: BaseModel) -> list[dict]:
def get_input_type_fields(obj: BaseModel | None) -> list[dict]:
"""
Retrieves the field names and default values from the input type of a prompt.
Expand All @@ -160,6 +155,8 @@ def get_input_type_fields(obj: BaseModel) -> list[dict]:
Returns:
list[dict]: A list of dictionaries, each containing a field name and its default value.
"""
if obj is None:
return []
return [
{"field_name": k, "field_default_value": v["schema"].get("default", None)}
for (k, v) in obj.__pydantic_core_schema__["schema"]["fields"].items()
Expand All @@ -170,57 +167,48 @@ def get_input_type_fields(obj: BaseModel) -> list[dict]:


@typer_app.command()
def run_app(prompts_paths: str, llm_model: str, llm_api_key: str | None = None) -> None:
def run_app(
file_pattern: str = DEFAULT_FILE_PATTERN, llm_model: str | None = None, llm_api_key: str | None = None
) -> None:
"""
Launches the interactive application for working with Large Language Models (LLMs).
This function serves as the entry point for the application. It performs several key tasks:
1. Initializes the application state using the PromptState class.
2. Sets the LLM model name and API key based on user-provided arguments.
3. Fetches a list of prompts from the specified paths using the get_prompts_list function.
3. Fetches a list of prompts from the specified paths using the load_prompts_list function.
4. Creates a Gradio interface with various UI elements:
- A dropdown menu for selecting prompts.
- Textboxes for displaying and potentially modifying system and user prompts.
- Textboxes for entering input values based on the selected prompt.
- Buttons for rendering prompts, sending prompts to the LLM, and displaying the response.
Args:
prompts_paths (str): A comma-separated string of paths to prompt definition files.
file_pattern (str): A pattern for looking up prompt files.
llm_model (str): The name of the LLM model to use.
llm_api_key (str): The API key for the chosen LLM model.
"""
with gr.Blocks() as gr_app:
prompt_state_obj = PromptState()
prompt_state_obj.llm_model_name = llm_model
prompt_state_obj.llm_api_key = llm_api_key

prompts_state = gr.State(value=prompt_state_obj)

prompts_state = get_prompts_list(path=prompts_paths, state=prompts_state)

# Load some values on initialization of the project
# 1) List of available prompts (based on paths provided by user
# 2) Select first prompt as loaded by default
# 3) Set init values of "rendered" fields

prompts_list = list(prompts_state.value.prompts_state.keys())
initial_prompt_key = prompts_list[0]
system_prompt_value = prompts_state.value.prompts_state[initial_prompt_key]["system_prompt"]
user_prompt_value = prompts_state.value.prompts_state[initial_prompt_key]["user_prompt"]

prompt_selection_dropdown = gr.Dropdown(choices=prompts_list, value=initial_prompt_key, label="Select Prompt")

list_of_vars = []
prompts_state = load_prompts_list(pattern=file_pattern, state=prompts_state)
prompt_selection_dropdown = gr.Dropdown(
choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt"
)

@gr.render(inputs=[prompt_selection_dropdown, prompts_state])
def show_split(key: str, state: gr.State) -> None:
def show_split(index: int, state: gr.State) -> None:
prompt = state.prompts[index]
list_of_vars = []
with gr.Row():
with gr.Column(scale=1):
with gr.Tab("Inputs"):
if key != "":
input_fields: list = get_input_type_fields(state.prompts_state[key]["input_type"])
if index is not None:
input_fields: list = get_input_type_fields(prompt.input_type)
tb_dict = {}
for entry in input_fields:
with gr.Row():
Expand All @@ -242,7 +230,7 @@ def show_split(key: str, state: gr.State) -> None:
with gr.Row():
with gr.Column():
prompt_details_system_prompt = gr.Textbox(
label="System Prompt", value=system_prompt_value, interactive=True
label="System Prompt", value=prompt.system_prompt, interactive=True
)

with gr.Column():
Expand All @@ -253,22 +241,18 @@ def show_split(key: str, state: gr.State) -> None:
with gr.Row():
with gr.Column():
prompt_details_user_prompt = gr.Textbox(
label="User Prompt", value=user_prompt_value, interactive=True
label="User Prompt", value=prompt.user_prompt, interactive=True
)

with gr.Column():
prompt_details_user_prompt_rendered = gr.Textbox(
label="Rendered User Prompt", value="", interactive=False
)

# TODO: Gray out the "Send to LLM" button if LLM has not been configured
llm_request_button = gr.Button(value="Send to LLM")
llm_prompt_response = gr.Textbox(lines=10, label="LLM response")

prompt_selection_dropdown.change(
display_data,
[prompt_selection_dropdown, prompts_state],
[prompt_details_system_prompt, prompt_details_user_prompt],
)
render_prompt_button.click(
render_prompt,
[
Expand Down
Loading

0 comments on commit 83c0272

Please sign in to comment.