Skip to content

Commit

Permalink
feat: Add AnswersBuilder component (2.0) (#5701)
Browse files Browse the repository at this point in the history
* Add AnswersBuilder

* Add tests for AnswersBuilder

* Add release note

* PR feedback

* Fix mypy

* Remove redundant check for number of groups

* docstrings upd

---------

Co-authored-by: Daria Fokina <[email protected]>
  • Loading branch information
bogdankostic and dfokina authored Sep 4, 2023
1 parent c5369a3 commit a5b8156
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 0 deletions.
Empty file.
171 changes: 171 additions & 0 deletions haystack/preview/components/builders/answers_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import logging
import re
from typing import List, Dict, Any, Optional

from haystack.preview import component, GeneratedAnswer, Document, default_to_dict, default_from_dict


# Module-level logger; used by AnswersBuilder to warn about out-of-range document references.
logger = logging.getLogger(__name__)


@component
class AnswersBuilder:
    """
    A component to parse the output of a Generator to `GeneratedAnswer` objects using regular expressions.
    """

    def __init__(self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None):
        """
        :param pattern: The regular expression pattern to use to extract the answer text from the generator output.
                        If not specified, the whole string is used as the answer. The regular expression can have at
                        most one capture group. If a capture group is present, the text matched by the capture group
                        is used as the answer. If no capture group is present, the whole match is used as the answer.
                        Examples:
                            `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer".
                            `Answer: (.*)` finds "this is an answer" in a string "this is an argument. Answer: this is an answer".
                        Default: `None`.
        :param reference_pattern: The regular expression pattern to use for parsing the document references.
                                  We assume that references are specified as indices of the input documents and that
                                  indices start at 1.
                                  Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
                                  If not specified, no parsing is done, and all documents are referenced.
                                  Default: `None`.
        :raises ValueError: If `pattern` contains more than one capture group.
        """
        if pattern:
            AnswersBuilder._check_num_groups_in_regex(pattern)

        self.pattern = pattern
        self.reference_pattern = reference_pattern

    @component.output_types(answers=List[List[GeneratedAnswer]])
    def run(
        self,
        queries: List[str],
        replies: List[List[str]],
        metadata: List[List[Dict[str, Any]]],
        documents: Optional[List[List[Document]]] = None,
        pattern: Optional[str] = None,
        reference_pattern: Optional[str] = None,
    ):
        """
        Parse the output of a Generator to `GeneratedAnswer` objects using regular expressions.

        :param queries: The queries used in the prompts for the Generator. A list of strings.
        :param replies: The output of the Generator. A list of lists of strings.
        :param metadata: The metadata returned by the Generator. A list of lists of dictionaries.
        :param documents: The documents used as input to the Generator. A list of lists of `Document` objects. If
                          `documents` are specified, they are added to the `GeneratedAnswer` objects.
                          If both `documents` and `reference_pattern` are specified, the documents referenced in the
                          Generator output are extracted from the input documents and added to the
                          `GeneratedAnswer` objects.
                          Default: `None`.
        :param pattern: The regular expression pattern to use to extract the answer text from the generator output.
                        If not specified, the pattern given at initialization is used; if that is not set either,
                        the whole string is used as the answer. The regular expression can have at
                        most one capture group. If a capture group is present, the text matched by the capture group
                        is used as the answer. If no capture group is present, the whole match is used as the answer.
                        Examples:
                            `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer".
                            `Answer: (.*)` finds "this is an answer" in a string "this is an argument. Answer: this is an answer".
                        Default: `None`.
        :param reference_pattern: The regular expression pattern to use for parsing the document references.
                                  We assume that references are specified as indices of the input documents and that
                                  indices start at 1.
                                  Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
                                  If not specified, the reference pattern given at initialization is used; if that is
                                  not set either, no parsing is done, and all documents are referenced.
                                  Default: `None`.
        :returns: A list with one entry per query; each entry is the list of `GeneratedAnswer` objects built from
                  that query's replies.
        :raises ValueError: If `queries`, `replies`, and `metadata` do not all have the same length, or if
                            `pattern` contains more than one capture group.
        """
        # NOTE: `a != b != c` is a chained comparison meaning `a != b and b != c`, which misses
        # cases such as (1, 1, 2). Negating the equality chain catches every mismatch.
        if not len(queries) == len(replies) == len(metadata):
            raise ValueError(
                f"Number of queries ({len(queries)}), replies ({len(replies)}), and metadata "
                f"({len(metadata)}) must match."
            )

        if pattern:
            AnswersBuilder._check_num_groups_in_regex(pattern)

        # Runtime arguments take precedence over the values given at initialization.
        documents = documents or []
        pattern = pattern or self.pattern
        reference_pattern = reference_pattern or self.reference_pattern

        all_answers = []
        for i, (query, reply_list, meta_list) in enumerate(zip(queries, replies, metadata)):
            # Queries without a matching documents entry get no documents attached.
            doc_list = documents[i] if i < len(documents) else []

            extracted_answer_strings = AnswersBuilder._extract_answer_strings(reply_list, pattern)

            if doc_list and reference_pattern:
                reference_idxs = AnswersBuilder._extract_reference_idxs(reply_list, reference_pattern)
            else:
                # Without reference parsing, every reply references all documents of its query.
                all_doc_idxs = list(range(len(doc_list)))
                reference_idxs = [list(all_doc_idxs) for _ in reply_list]

            answers_for_cur_query = []
            for answer_string, doc_idxs, meta in zip(extracted_answer_strings, reference_idxs, meta_list):
                referenced_docs = []
                for idx in doc_idxs:
                    # The lower bound matters: a reference such as "[0]" becomes idx == -1, which
                    # Python's negative indexing would otherwise silently resolve to the last document.
                    if 0 <= idx < len(doc_list):
                        referenced_docs.append(doc_list[idx])
                    else:
                        logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1)

                answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta)
                answers_for_cur_query.append(answer)

            all_answers.append(answers_for_cur_query)

        return all_answers

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(self, pattern=self.pattern, reference_pattern=self.reference_pattern)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AnswersBuilder":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    @staticmethod
    def _extract_answer_strings(replies: List[str], pattern: Optional[str] = None) -> List[str]:
        """
        Extract the answer strings from the generator output using the specified pattern.
        If no pattern is specified, the whole string is used as the answer.

        :param replies: The output of the Generator. A list of strings.
        :param pattern: The regular expression pattern to use to extract the answer text from the generator output.
        :returns: One answer string per reply; a reply that the pattern does not match yields an empty string.
        """
        if pattern is None:
            return replies

        # Compile once instead of looking the pattern up for every reply.
        compiled_pattern = re.compile(pattern)

        extracted_answers = []
        for reply in replies:
            match = compiled_pattern.search(reply)
            if match is None:
                extracted_answers.append("")
            elif not match.lastindex:
                # No capture group took part in the match -> use the whole match as answer
                extracted_answers.append(match.group(0))
            else:
                # One capture group in pattern -> use the capture group as answer
                extracted_answers.append(match.group(1))

        return extracted_answers

    @staticmethod
    def _extract_reference_idxs(replies: List[str], reference_pattern: str) -> List[List[int]]:
        """
        Extract the referenced document indices from each reply.

        Every match of `reference_pattern` is converted with `int()` and shifted from the 1-based index used in
        the reply to a 0-based list index. No range check is done here; callers must validate the indices.

        :param replies: The output of the Generator. A list of strings.
        :param reference_pattern: The regular expression pattern used to find the references. Each `re.findall`
                                  result must be convertible with `int()`, e.g. a single capture group around
                                  the digits as in `\\[(\\d+)\\]`.
        :returns: For each reply, the list of 0-based document indices found in it.
        """
        compiled_pattern = re.compile(reference_pattern)
        return [[int(idx) - 1 for idx in compiled_pattern.findall(reply)] for reply in replies]

    @staticmethod
    def _check_num_groups_in_regex(pattern: str) -> None:
        """
        Ensure that the pattern contains at most one capture group.

        :param pattern: The regular expression pattern to check.
        :raises ValueError: If the pattern contains more than one capture group.
        """
        num_groups = re.compile(pattern).groups
        if num_groups > 1:
            raise ValueError(
                f"Pattern '{pattern}' contains multiple capture groups. "
                "Please specify a pattern with at most one capture group."
            )
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
Add the `AnswersBuilder` component for Haystack 2.0 that creates `GeneratedAnswer` objects from the string output of Generators.
Empty file.
159 changes: 159 additions & 0 deletions test/preview/components/builders/test_answers_builders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import logging

import pytest

from haystack.preview import GeneratedAnswer, Document
from haystack.preview.components.builders.answers_builder import AnswersBuilder


class TestAnswersBuilder:
    """Unit tests for the AnswersBuilder component."""

    @pytest.mark.unit
    def test_to_dict(self):
        component = AnswersBuilder()
        data = component.to_dict()
        assert data == {"type": "AnswersBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}}

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        component = AnswersBuilder(pattern="pattern", reference_pattern="reference_pattern")
        data = component.to_dict()
        assert data == {
            "type": "AnswersBuilder",
            "init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
        }

    @pytest.mark.unit
    def test_from_dict(self):
        data = {
            "type": "AnswersBuilder",
            "init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
        }
        component = AnswersBuilder.from_dict(data)
        assert component.pattern == "pattern"
        assert component.reference_pattern == "reference_pattern"

    @pytest.mark.unit
    def test_run_unmatching_input_len(self):
        component = AnswersBuilder()
        with pytest.raises(ValueError):
            component.run(queries=["query"], replies=[["reply1"], ["reply2"]], metadata=[[]])

    # Every test carries the `unit` marker so that none is dropped when the
    # suite is selected with `-m unit`.
    @pytest.mark.unit
    def test_run_without_pattern(self):
        component = AnswersBuilder()
        answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert answers[0][0].documents == []
        assert isinstance(answers[0][0], GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_pattern_with_capturing_group(self):
        component = AnswersBuilder(pattern=r"Answer: (.*)")
        answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "AnswerString"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert answers[0][0].documents == []
        assert isinstance(answers[0][0], GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_pattern_without_capturing_group(self):
        component = AnswersBuilder(pattern=r"'.*'")
        answers = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]])
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "'AnswerString'"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert answers[0][0].documents == []
        assert isinstance(answers[0][0], GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_pattern_with_more_than_one_capturing_group(self):
        # The constructor itself must raise, so the result is not bound to a name.
        with pytest.raises(ValueError, match="contains multiple capture groups"):
            AnswersBuilder(pattern=r"Answer: (.*), (.*)")

    @pytest.mark.unit
    def test_run_with_pattern_set_at_runtime(self):
        component = AnswersBuilder(pattern="unused pattern")
        answers = component.run(
            queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)"
        )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "AnswerString"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert answers[0][0].documents == []
        assert isinstance(answers[0][0], GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_documents_without_reference_pattern(self):
        component = AnswersBuilder()
        answers = component.run(
            queries=["test query"],
            replies=[["Answer: AnswerString"]],
            metadata=[[{}]],
            documents=[[Document(content="test doc 1"), Document(content="test doc 2")]],
        )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert len(answers[0][0].documents) == 2
        assert answers[0][0].documents[0].content == "test doc 1"
        assert answers[0][0].documents[1].content == "test doc 2"

    @pytest.mark.unit
    def test_run_with_documents_with_reference_pattern(self):
        component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]")
        answers = component.run(
            queries=["test query"],
            replies=[["Answer: AnswerString[2]"]],
            metadata=[[{}]],
            documents=[[Document(content="test doc 1"), Document(content="test doc 2")]],
        )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString[2]"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert len(answers[0][0].documents) == 1
        assert answers[0][0].documents[0].content == "test doc 2"

    @pytest.mark.unit
    def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog):
        component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]")
        with caplog.at_level(logging.WARNING):
            answers = component.run(
                queries=["test query"],
                replies=[["Answer: AnswerString[3]"]],
                metadata=[[{}]],
                documents=[[Document(content="test doc 1"), Document(content="test doc 2")]],
            )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString[3]"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert len(answers[0][0].documents) == 0
        assert "Document index '3' referenced in Generator output is out of range." in caplog.text

    @pytest.mark.unit
    def test_run_with_reference_pattern_set_at_runtime(self):
        component = AnswersBuilder(reference_pattern="unused pattern")
        answers = component.run(
            queries=["test query"],
            replies=[["Answer: AnswerString[2][3]"]],
            metadata=[[{}]],
            documents=[
                [Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")]
            ],
            reference_pattern="\\[(\\d+)\\]",
        )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString[2][3]"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert len(answers[0][0].documents) == 2
        assert answers[0][0].documents[0].content == "test doc 2"
        assert answers[0][0].documents[1].content == "test doc 3"

0 comments on commit a5b8156

Please sign in to comment.