-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add
AnswersBuilder
component (2.0) (#5701)
* Add AnswersBuilder * Add tests for AnswersBuilder * Add release note * PR feedback * Fix mypy * Remove redundant check for number of groups * docstrings upd --------- Co-authored-by: Daria Fokina <[email protected]>
- Loading branch information
1 parent
c5369a3
commit a5b8156
Showing
5 changed files
with
334 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
171 changes: 171 additions & 0 deletions
171
haystack/preview/components/builders/answers_builder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
import logging | ||
import re | ||
from typing import List, Dict, Any, Optional | ||
|
||
from haystack.preview import component, GeneratedAnswer, Document, default_to_dict, default_from_dict | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@component | ||
class AnswersBuilder: | ||
""" | ||
A component to parse the output of a Generator to `Answer` objects using regular expressions. | ||
""" | ||
|
||
def __init__(self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None): | ||
""" | ||
:param pattern: The regular expression pattern to use to extract the answer text from the generator output. | ||
If not specified, the whole string is used as the answer. The regular expression can have at | ||
most one capture group. If a capture group is present, the text matched by the capture group | ||
is used as the answer. If no capture group is present, the whole match is used as the answer. | ||
Examples: | ||
`[^\\n]+$` finds "this is an answer" in a string "this is an argument.\nthis is an answer". | ||
`Answer: (.*)` finds "this is an answer" in a string "this is an argument. Answer: this is an answer". | ||
Default: `None`. | ||
:param reference_pattern: The regular expression pattern to use for parsing the document references. | ||
We assume that references are specified as indices of the input documents and that | ||
indices start at 1. | ||
Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". | ||
If not specified, no parsing is done, and all documents are referenced. | ||
Default: `None`. | ||
""" | ||
if pattern: | ||
AnswersBuilder._check_num_groups_in_regex(pattern) | ||
|
||
self.pattern = pattern | ||
self.reference_pattern = reference_pattern | ||
|
||
@component.output_types(answers=List[List[GeneratedAnswer]]) | ||
def run( | ||
self, | ||
queries: List[str], | ||
replies: List[List[str]], | ||
metadata: List[List[Dict[str, Any]]], | ||
documents: Optional[List[List[Document]]] = None, | ||
pattern: Optional[str] = None, | ||
reference_pattern: Optional[str] = None, | ||
): | ||
""" | ||
Parse the output of a Generator to `Answer` objects using regular expressions. | ||
:param queries: The queries used in the prompts for the Generator. A list of strings. | ||
:param replies: The output of the Generator. A list of lists of strings. | ||
:param metadata: The metadata returned by the Generator. A list of lists of dictionaries. | ||
:param documents: The documents used as input to the Generator. A list of lists of `Document` objects. If | ||
`documents` are specified, they are added to the `Answer` objects. | ||
If both `documents` and `reference_pattern` are specified, the documents referenced in the | ||
Generator output are extracted from the input documents and added to the `Answer` objects. | ||
Default: `None`. | ||
:param pattern: The regular expression pattern to use to extract the answer text from the generator output. | ||
If not specified, the whole string is used as the answer. The regular expression can have at | ||
most one capture group. If a capture group is present, the text matched by the capture group | ||
is used as the answer. If no capture group is present, the whole match is used as the answer. | ||
Examples: | ||
`[^\\n]+$` finds "this is an answer" in a string "this is an argument.\nthis is an answer". | ||
`Answer: (.*)` finds "this is an answer" in a string "this is an argument. Answer: this is an answer". | ||
Default: `None`. | ||
:param reference_pattern: The regular expression pattern to use for parsing the document references. | ||
We assume that references are specified as indices of the input documents and that | ||
indices start at 1. | ||
Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". | ||
If not specified, no parsing is done, and all documents are referenced. | ||
Default: `None`. | ||
""" | ||
if len(queries) != len(replies) != len(metadata): | ||
raise ValueError( | ||
f"Number of queries ({len(queries)}), replies ({len(replies)}), and metadata " | ||
f"({len(metadata)}) must match." | ||
) | ||
|
||
if pattern: | ||
AnswersBuilder._check_num_groups_in_regex(pattern) | ||
|
||
documents = documents or [] | ||
pattern = pattern or self.pattern | ||
reference_pattern = reference_pattern or self.reference_pattern | ||
|
||
all_answers = [] | ||
for i, (query, reply_list, meta_list) in enumerate(zip(queries, replies, metadata)): | ||
doc_list = documents[i] if i < len(documents) else [] | ||
|
||
extracted_answer_strings = AnswersBuilder._extract_answer_strings(reply_list, pattern) | ||
|
||
if doc_list and reference_pattern: | ||
reference_idxs = AnswersBuilder._extract_reference_idxs(reply_list, reference_pattern) | ||
else: | ||
reference_idxs = [[doc_idx for doc_idx, _ in enumerate(doc_list)] for _ in reply_list] | ||
|
||
answers_for_cur_query = [] | ||
for answer_string, doc_idxs, meta in zip(extracted_answer_strings, reference_idxs, meta_list): | ||
referenced_docs = [] | ||
for idx in doc_idxs: | ||
if idx < len(doc_list): | ||
referenced_docs.append(doc_list[idx]) | ||
else: | ||
logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1) | ||
|
||
answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta) | ||
answers_for_cur_query.append(answer) | ||
|
||
all_answers.append(answers_for_cur_query) | ||
|
||
return all_answers | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
Serialize this component to a dictionary. | ||
""" | ||
return default_to_dict(self, pattern=self.pattern, reference_pattern=self.reference_pattern) | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> "AnswersBuilder": | ||
""" | ||
Deserialize this component from a dictionary. | ||
""" | ||
return default_from_dict(cls, data) | ||
|
||
@staticmethod | ||
def _extract_answer_strings(replies: List[str], pattern: Optional[str] = None) -> List[str]: | ||
""" | ||
Extract the answer strings from the generator output using the specified pattern. | ||
If no pattern is specified, the whole string is used as the answer. | ||
:param replies: The output of the Generator. A list of strings. | ||
:param pattern: The regular expression pattern to use to extract the answer text from the generator output. | ||
""" | ||
if pattern is None: | ||
return replies | ||
|
||
extracted_answers = [] | ||
for reply in replies: | ||
if match := re.search(pattern, reply): | ||
# No capture group in pattern -> use the whole match as answer | ||
if not match.lastindex: | ||
extracted_answers.append(match.group(0)) | ||
# One capture group in pattern -> use the capture group as answer | ||
else: | ||
extracted_answers.append(match.group(1)) | ||
else: | ||
extracted_answers.append("") | ||
|
||
return extracted_answers | ||
|
||
@staticmethod | ||
def _extract_reference_idxs(replies: List[str], reference_pattern: str) -> List[List[int]]: | ||
reference_idxs = [] | ||
for reply in replies: | ||
document_idxs = re.findall(reference_pattern, reply) | ||
reference_idxs.append([int(idx) - 1 for idx in document_idxs]) | ||
|
||
return reference_idxs | ||
|
||
@staticmethod | ||
def _check_num_groups_in_regex(pattern: str): | ||
num_groups = re.compile(pattern).groups | ||
if num_groups > 1: | ||
raise ValueError( | ||
f"Pattern '{pattern}' contains multiple capture groups. " | ||
f"Please specify a pattern with at most one capture group." | ||
) |
4 changes: 4 additions & 0 deletions
4
releasenotes/notes/add-answersbuilder-2.0-5dd255eeba68041f.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
--- | ||
preview: | ||
- | | ||
Add the `AnswersBuilder` component for Haystack 2.0 that creates Answer objects from the string output of Generators. |
Empty file.
159 changes: 159 additions & 0 deletions
159
test/preview/components/builders/test_answers_builders.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
import logging | ||
|
||
import pytest | ||
|
||
from haystack.preview import GeneratedAnswer, Document | ||
from haystack.preview.components.builders.answers_builder import AnswersBuilder | ||
|
||
|
||
class TestAnswersBuilder: | ||
@pytest.mark.unit | ||
def test_to_dict(self): | ||
component = AnswersBuilder() | ||
data = component.to_dict() | ||
assert data == {"type": "AnswersBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}} | ||
|
||
@pytest.mark.unit | ||
def test_to_dict_with_custom_init_parameters(self): | ||
component = AnswersBuilder(pattern="pattern", reference_pattern="reference_pattern") | ||
data = component.to_dict() | ||
assert data == { | ||
"type": "AnswersBuilder", | ||
"init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"}, | ||
} | ||
|
||
@pytest.mark.unit | ||
def test_from_dict(self): | ||
data = { | ||
"type": "AnswersBuilder", | ||
"init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"}, | ||
} | ||
component = AnswersBuilder.from_dict(data) | ||
assert component.pattern == "pattern" | ||
assert component.reference_pattern == "reference_pattern" | ||
|
||
@pytest.mark.unit | ||
def test_run_unmatching_input_len(self): | ||
component = AnswersBuilder() | ||
with pytest.raises(ValueError): | ||
component.run(queries=["query"], replies=[["reply1"], ["reply2"]], metadata=[[]]) | ||
|
||
def test_run_without_pattern(self): | ||
component = AnswersBuilder() | ||
answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]]) | ||
assert len(answers) == 1 | ||
assert len(answers[0]) == 1 | ||
assert answers[0][0].data == "Answer: AnswerString" | ||
assert answers[0][0].metadata == {} | ||
assert answers[0][0].query == "test query" | ||
assert answers[0][0].documents == [] | ||
assert isinstance(answers[0][0], GeneratedAnswer) | ||
|
||
def test_run_with_pattern_with_capturing_group(self): | ||
component = AnswersBuilder(pattern=r"Answer: (.*)") | ||
answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]]) | ||
assert len(answers) == 1 | ||
assert len(answers[0]) == 1 | ||
assert answers[0][0].data == "AnswerString" | ||
assert answers[0][0].metadata == {} | ||
assert answers[0][0].query == "test query" | ||
assert answers[0][0].documents == [] | ||
assert isinstance(answers[0][0], GeneratedAnswer) | ||
|
||
def test_run_with_pattern_without_capturing_group(self): | ||
component = AnswersBuilder(pattern=r"'.*'") | ||
answers = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]]) | ||
assert len(answers) == 1 | ||
assert len(answers[0]) == 1 | ||
assert answers[0][0].data == "'AnswerString'" | ||
assert answers[0][0].metadata == {} | ||
assert answers[0][0].query == "test query" | ||
assert answers[0][0].documents == [] | ||
assert isinstance(answers[0][0], GeneratedAnswer) | ||
|
||
def test_run_with_pattern_with_more_than_one_capturing_group(self): | ||
with pytest.raises(ValueError, match="contains multiple capture groups"): | ||
component = AnswersBuilder(pattern=r"Answer: (.*), (.*)") | ||
|
||
def test_run_with_pattern_set_at_runtime(self): | ||
component = AnswersBuilder(pattern="unused pattern") | ||
answers = component.run( | ||
queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)" | ||
) | ||
assert len(answers) == 1 | ||
assert len(answers[0]) == 1 | ||
assert answers[0][0].data == "AnswerString" | ||
assert answers[0][0].metadata == {} | ||
assert answers[0][0].query == "test query" | ||
assert answers[0][0].documents == [] | ||
assert isinstance(answers[0][0], GeneratedAnswer) | ||
|
||
def test_run_with_documents_without_reference_pattern(self): | ||
component = AnswersBuilder() | ||
answers = component.run( | ||
queries=["test query"], | ||
replies=[["Answer: AnswerString"]], | ||
metadata=[[{}]], | ||
documents=[[Document(content="test doc 1"), Document(content="test doc 2")]], | ||
) | ||
assert len(answers) == 1 | ||
assert len(answers[0]) == 1 | ||
assert answers[0][0].data == "Answer: AnswerString" | ||
assert answers[0][0].metadata == {} | ||
assert answers[0][0].query == "test query" | ||
assert len(answers[0][0].documents) == 2 | ||
assert answers[0][0].documents[0].content == "test doc 1" | ||
assert answers[0][0].documents[1].content == "test doc 2" | ||
|
||
def test_run_with_documents_with_reference_pattern(self): | ||
component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]") | ||
answers = component.run( | ||
queries=["test query"], | ||
replies=[["Answer: AnswerString[2]"]], | ||
metadata=[[{}]], | ||
documents=[[Document(content="test doc 1"), Document(content="test doc 2")]], | ||
) | ||
assert len(answers) == 1 | ||
assert len(answers[0]) == 1 | ||
assert answers[0][0].data == "Answer: AnswerString[2]" | ||
assert answers[0][0].metadata == {} | ||
assert answers[0][0].query == "test query" | ||
assert len(answers[0][0].documents) == 1 | ||
assert answers[0][0].documents[0].content == "test doc 2" | ||
|
||
def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog): | ||
component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]") | ||
with caplog.at_level(logging.WARNING): | ||
answers = component.run( | ||
queries=["test query"], | ||
replies=[["Answer: AnswerString[3]"]], | ||
metadata=[[{}]], | ||
documents=[[Document(content="test doc 1"), Document(content="test doc 2")]], | ||
) | ||
assert len(answers) == 1 | ||
assert len(answers[0]) == 1 | ||
assert answers[0][0].data == "Answer: AnswerString[3]" | ||
assert answers[0][0].metadata == {} | ||
assert answers[0][0].query == "test query" | ||
assert len(answers[0][0].documents) == 0 | ||
assert "Document index '3' referenced in Generator output is out of range." in caplog.text | ||
|
||
def test_run_with_reference_pattern_set_at_runtime(self): | ||
component = AnswersBuilder(reference_pattern="unused pattern") | ||
answers = component.run( | ||
queries=["test query"], | ||
replies=[["Answer: AnswerString[2][3]"]], | ||
metadata=[[{}]], | ||
documents=[ | ||
[Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")] | ||
], | ||
reference_pattern="\\[(\\d+)\\]", | ||
) | ||
assert len(answers) == 1 | ||
assert len(answers[0]) == 1 | ||
assert answers[0][0].data == "Answer: AnswerString[2][3]" | ||
assert answers[0][0].metadata == {} | ||
assert answers[0][0].query == "test query" | ||
assert len(answers[0][0].documents) == 2 | ||
assert answers[0][0].documents[0].content == "test doc 2" | ||
assert answers[0][0].documents[1].content == "test doc 3" |