Skip to content

Commit

Permalink
feat: Add Literal["*"] option to required_variables in ChatPrompBuild…
Browse files Browse the repository at this point in the history
…er and PromptBuilder (#8572)

* Add new option for required_variables in PromptBuilder and ChatPromptBuilder

* Add reno note

* Add tests
  • Loading branch information
sjrl authored Nov 22, 2024
1 parent b5a2fad commit eace2a9
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 17 deletions.
22 changes: 14 additions & 8 deletions haystack/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Literal, Optional, Set, Union

from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
Expand Down Expand Up @@ -100,7 +100,7 @@ class ChatPromptBuilder:
def __init__(
self,
template: Optional[List[ChatMessage]] = None,
required_variables: Optional[List[str]] = None,
required_variables: Optional[Union[List[str], Literal["*"]]] = None,
variables: Optional[List[str]] = None,
):
"""
Expand All @@ -112,7 +112,8 @@ def __init__(
the `init` method` or the `run` method.
:param required_variables:
List variables that must be provided as input to ChatPromptBuilder.
If a variable listed as required is not provided, an exception is raised. Optional.
If a variable listed as required is not provided, an exception is raised.
If set to "*", all variables found in the prompt are required. Optional.
:param variables:
List input variables to use in prompt templates instead of the ones inferred from the
`template` parameter. For example, to use more variables during prompt engineering than the ones present
Expand All @@ -127,14 +128,15 @@ def __init__(
if template and not variables:
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infere variables from template
# infer variables from template
ast = self._env.parse(message.content)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
self.variables = variables

# setup inputs
for var in variables:
if var in self.required_variables:
for var in self.variables:
if self.required_variables == "*" or var in self.required_variables:
component.set_input_type(self, var, Any)
else:
component.set_input_type(self, var, Any, "")
Expand Down Expand Up @@ -211,12 +213,16 @@ def _validate_variables(self, provided_variables: Set[str]):
:raises ValueError:
If no template is provided or if all the required template variables are not provided.
"""
missing_variables = [var for var in self.required_variables if var not in provided_variables]
if self.required_variables == "*":
required_variables = sorted(self.variables)
else:
required_variables = self.required_variables
missing_variables = [var for var in required_variables if var not in provided_variables]
if missing_variables:
missing_vars_str = ", ".join(missing_variables)
raise ValueError(
f"Missing required input variables in ChatPromptBuilder: {missing_vars_str}. "
f"Required variables: {self.required_variables}. Provided variables: {provided_variables}."
f"Required variables: {required_variables}. Provided variables: {provided_variables}."
)

def to_dict(self) -> Dict[str, Any]:
Expand Down
24 changes: 16 additions & 8 deletions haystack/components/builders/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Literal, Optional, Set, Union

from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
Expand Down Expand Up @@ -137,7 +137,10 @@ class PromptBuilder:
"""

def __init__(
self, template: str, required_variables: Optional[List[str]] = None, variables: Optional[List[str]] = None
self,
template: str,
required_variables: Optional[Union[List[str], Literal["*"]]] = None,
variables: Optional[List[str]] = None,
):
"""
Constructs a PromptBuilder component.
Expand All @@ -150,7 +153,8 @@ def __init__(
unless explicitly specified.
If an optional variable is not provided, it's replaced with an empty string in the rendered prompt.
:param required_variables: List variables that must be provided as input to PromptBuilder.
If a variable listed as required is not provided, an exception is raised. Optional.
If a variable listed as required is not provided, an exception is raised.
If set to "*", all variables found in the prompt are required. Optional.
:param variables:
List input variables to use in prompt templates instead of the ones inferred from the
`template` parameter. For example, to use more variables during prompt engineering than the ones present
Expand All @@ -173,12 +177,12 @@ def __init__(
ast = self._env.parse(template)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)

variables = variables or []
self.variables = variables

# setup inputs
for var in variables:
if var in self.required_variables:
for var in self.variables:
if self.required_variables == "*" or var in self.required_variables:
component.set_input_type(self, var, Any)
else:
component.set_input_type(self, var, Any, "")
Expand Down Expand Up @@ -238,10 +242,14 @@ def _validate_variables(self, provided_variables: Set[str]):
:raises ValueError:
If any of the required template variables is not provided.
"""
missing_variables = [var for var in self.required_variables if var not in provided_variables]
if self.required_variables == "*":
required_variables = sorted(self.variables)
else:
required_variables = self.required_variables
missing_variables = [var for var in required_variables if var not in provided_variables]
if missing_variables:
missing_vars_str = ", ".join(missing_variables)
raise ValueError(
f"Missing required input variables in PromptBuilder: {missing_vars_str}. "
f"Required variables: {self.required_variables}. Provided variables: {provided_variables}."
f"Required variables: {required_variables}. Provided variables: {provided_variables}."
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Added a new option to the required_variables parameter to the PromptBuilder and ChatPromptBuilder.
By passing `required_variables="*"` you can automatically set all variables in the prompt to be required.
11 changes: 11 additions & 0 deletions test/components/builders/test_chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ def test_run_with_missing_required_input(self):
with pytest.raises(ValueError, match="foo, bar"):
builder.run()

def test_run_with_missing_required_input_using_star(self):
builder = ChatPromptBuilder(
template=[ChatMessage.from_user("This is a {{ foo }}, not a {{ bar }}")], required_variables="*"
)
with pytest.raises(ValueError, match="foo"):
builder.run(bar="bar")
with pytest.raises(ValueError, match="bar"):
builder.run(foo="foo")
with pytest.raises(ValueError, match="bar, foo"):
builder.run()

def test_run_with_variables(self):
variables = ["var1", "var2", "var3"]
template = [ChatMessage.from_user("Hello, {{ name }}! {{ var1 }}")]
Expand Down
11 changes: 10 additions & 1 deletion test/components/builders/test_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ def test_run_with_missing_required_input(self):
with pytest.raises(ValueError, match="foo, bar"):
builder.run()

def test_run_with_missing_required_input_using_star(self):
builder = PromptBuilder(template="This is a {{ foo }}, not a {{ bar }}", required_variables="*")
with pytest.raises(ValueError, match="foo"):
builder.run(bar="bar")
with pytest.raises(ValueError, match="bar"):
builder.run(foo="foo")
with pytest.raises(ValueError, match="bar, foo"):
builder.run()

def test_run_with_variables(self):
variables = ["var1", "var2", "var3"]
template = "Hello, {{ name }}! {{ var1 }}"
Expand Down Expand Up @@ -296,7 +305,7 @@ def test_date_with_addition_offset(self) -> None:

assert now_plus_2 == result

def test_date_with_substraction_offset(self) -> None:
def test_date_with_subtraction_offset(self) -> None:
template = "Time after 12 days is: {% now 'UTC' - 'days=12' %}"
builder = PromptBuilder(template=template)

Expand Down

0 comments on commit eace2a9

Please sign in to comment.