-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add PromptBuilder * Update release note * Add test
- Loading branch information
1 parent
a5b8156
commit 2acc41e
Showing
4 changed files
with
93 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from typing import Dict, Any | ||
|
||
from jinja2 import Template, meta | ||
|
||
from haystack.preview import component | ||
from haystack.preview import default_to_dict, default_from_dict | ||
|
||
|
||
@component | ||
class PromptBuilder: | ||
""" | ||
PromptBuilder is a component that renders a prompt from a template string using Jinja2 engine. | ||
The template variables found in the template string are used as input types for the component and are all required. | ||
Usage: | ||
```python | ||
template = "Translate the following context to {{ target_language }}. Context: {{ snippet }}; Translation:" | ||
builder = PromptBuilder(template=template) | ||
builder.run(target_language="spanish", snippet="I can't speak spanish.") | ||
``` | ||
""" | ||
|
||
def __init__(self, template: str): | ||
""" | ||
Initialize the component with a template string. | ||
:param template: Jinja2 template string, e.g. "Summarize this document: {documents}\nSummary:" | ||
:type template: str | ||
""" | ||
self._template_string = template | ||
self.template = Template(template) | ||
ast = self.template.environment.parse(template) | ||
template_variables = meta.find_undeclared_variables(ast) | ||
component.set_input_types(self, **{var: Any for var in template_variables}) | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
return default_to_dict(self, template=self._template_string) | ||
|
||
@classmethod | ||
def from_dict(cls, data) -> "PromptBuilder": | ||
return default_from_dict(cls, data) | ||
|
||
@component.output_types(prompt=str) | ||
def run(self, **kwargs): | ||
return {"prompt": self.template.render(kwargs)} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,7 @@ dependencies = [ | |
|
||
# Preview | ||
"canals==0.8.0", | ||
"Jinja2", | ||
|
||
# Agent events | ||
"events", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
--- | ||
preview: | ||
- Add PromptBuilder component to render prompts from template strings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import pytest | ||
|
||
from haystack.preview.components.generators.prompt_builder import PromptBuilder | ||
|
||
|
||
@pytest.mark.unit | ||
def test_init(): | ||
builder = PromptBuilder(template="This is a {{ variable }}") | ||
assert builder._template_string == "This is a {{ variable }}" | ||
|
||
|
||
@pytest.mark.unit | ||
def test_to_dict(): | ||
builder = PromptBuilder(template="This is a {{ variable }}") | ||
res = builder.to_dict() | ||
assert res == {"type": "PromptBuilder", "init_parameters": {"template": "This is a {{ variable }}"}} | ||
|
||
|
||
@pytest.mark.unit | ||
def test_from_dict(): | ||
data = {"type": "PromptBuilder", "init_parameters": {"template": "This is a {{ variable }}"}} | ||
builder = PromptBuilder.from_dict(data) | ||
builder._template_string == "This is a {{ variable }}" | ||
|
||
|
||
@pytest.mark.unit | ||
def test_run(): | ||
builder = PromptBuilder(template="This is a {{ variable }}") | ||
res = builder.run(variable="test") | ||
assert res == {"prompt": "This is a test"} | ||
|
||
|
||
@pytest.mark.unit | ||
def test_run_without_input(): | ||
builder = PromptBuilder(template="This is a template without input") | ||
res = builder.run() | ||
assert res == {"prompt": "This is a template without input"} | ||
|
||
|
||
@pytest.mark.unit | ||
def test_run_with_missing_input(): | ||
builder = PromptBuilder(template="This is a {{ variable }}") | ||
res = builder.run() | ||
assert res == {"prompt": "This is a "} |