# Origin: patch a5b815690ed7343882603a675c621ffc4c129c9b
# From: bogdankostic
# Date: Mon, 4 Sep 2023 21:16:20 +0200
# Subject: [PATCH] feat: Add `AnswersBuilder` component (2.0) (#5701)
#
# * Add AnswersBuilder
# * Add tests for AnswersBuilder
# * Add release note
# * PR feedback
# * Fix mypy
# * Remove redundant check for number of groups
# * docstrings upd
#
# Co-authored-by: Daria Fokina
#
# Files touched by the patch (5 files changed, 334 insertions):
#   haystack/preview/components/builders/__init__.py            (new, empty)
#   haystack/preview/components/builders/answers_builder.py     (new, this module)
#   releasenotes/notes/add-answersbuilder-2.0-5dd255eeba68041f.yaml (new)
#   test/preview/components/builders/__init__.py                (new, empty)
#   test/preview/components/builders/test_answers_builders.py   (new)
#
# ---- haystack/preview/components/builders/answers_builder.py ----
import logging
import re
from typing import List, Dict, Any, Optional

from haystack.preview import component, GeneratedAnswer, Document, default_to_dict, default_from_dict


logger = logging.getLogger(__name__)


@component
class AnswersBuilder:
    """
    A component to parse the output of a Generator to `GeneratedAnswer` objects using regular expressions.
    """

    def __init__(self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None):
        """
        :param pattern: The regular expression pattern to use to extract the answer text from the generator output.
                        If not specified, the whole string is used as the answer. The regular expression can have at
                        most one capture group. If a capture group is present, the text matched by the capture group
                        is used as the answer. If no capture group is present, the whole match is used as the answer.
                        Examples:
                            `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer".
                            `Answer: (.*)` finds "this is an answer" in a string
                            "this is an argument. Answer: this is an answer".
                        Default: `None`.
        :param reference_pattern: The regular expression pattern to use for parsing the document references.
                                  We assume that references are specified as indices of the input documents and that
                                  indices start at 1.
                                  Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
                                  If not specified, no parsing is done, and all documents are referenced.
                                  Default: `None`.
        :raises ValueError: If `pattern` contains more than one capture group.
        """
        if pattern:
            AnswersBuilder._check_num_groups_in_regex(pattern)

        self.pattern = pattern
        self.reference_pattern = reference_pattern

    @component.output_types(answers=List[List[GeneratedAnswer]])
    def run(
        self,
        queries: List[str],
        replies: List[List[str]],
        metadata: List[List[Dict[str, Any]]],
        documents: Optional[List[List[Document]]] = None,
        pattern: Optional[str] = None,
        reference_pattern: Optional[str] = None,
    ):
        """
        Parse the output of a Generator to `GeneratedAnswer` objects using regular expressions.

        :param queries: The queries used in the prompts for the Generator. A list of strings.
        :param replies: The output of the Generator. A list of lists of strings.
        :param metadata: The metadata returned by the Generator. A list of lists of dictionaries.
        :param documents: The documents used as input to the Generator. A list of lists of `Document` objects. If
                          `documents` are specified, they are added to the `GeneratedAnswer` objects.
                          If both `documents` and `reference_pattern` are specified, the documents referenced in the
                          Generator output are extracted from the input documents and added to the
                          `GeneratedAnswer` objects.
                          Default: `None`.
        :param pattern: The regular expression pattern to use to extract the answer text from the generator output.
                        If not specified, the pattern given at initialization is used; if that is not set either,
                        the whole string is used as the answer. The regular expression can have at
                        most one capture group. If a capture group is present, the text matched by the capture group
                        is used as the answer. If no capture group is present, the whole match is used as the answer.
                        Examples:
                            `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer".
                            `Answer: (.*)` finds "this is an answer" in a string
                            "this is an argument. Answer: this is an answer".
                        Default: `None`.
        :param reference_pattern: The regular expression pattern to use for parsing the document references.
                                  We assume that references are specified as indices of the input documents and that
                                  indices start at 1.
                                  Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
                                  If not specified, the reference pattern given at initialization is used; if that
                                  is not set either, no parsing is done, and all documents are referenced.
                                  Default: `None`.
        :returns: One list of `GeneratedAnswer` objects per query, in the order of `queries`.
        :raises ValueError: If `queries`, `replies` and `metadata` do not all have the same length, or if the
                            `pattern` passed at runtime contains more than one capture group.
        """
        # NOTE: `a != b != c` is a chained comparison meaning `a != b and b != c`, which lets mismatches such as
        # (1, 2, 2) or (2, 2, 1) through. Require all three lengths to be equal instead.
        if not len(queries) == len(replies) == len(metadata):
            raise ValueError(
                f"Number of queries ({len(queries)}), replies ({len(replies)}), and metadata "
                f"({len(metadata)}) must match."
            )

        if pattern:
            AnswersBuilder._check_num_groups_in_regex(pattern)

        documents = documents or []
        # Runtime arguments take precedence over the values given at initialization.
        pattern = pattern or self.pattern
        reference_pattern = reference_pattern or self.reference_pattern

        all_answers = []
        for i, (query, reply_list, meta_list) in enumerate(zip(queries, replies, metadata)):
            # `documents` may be shorter than `queries`; queries without documents get an empty list.
            doc_list = documents[i] if i < len(documents) else []

            extracted_answer_strings = AnswersBuilder._extract_answer_strings(reply_list, pattern)

            if doc_list and reference_pattern:
                reference_idxs = AnswersBuilder._extract_reference_idxs(reply_list, reference_pattern)
            else:
                # Without a reference pattern every reply references all of the query's documents.
                reference_idxs = [list(range(len(doc_list))) for _ in reply_list]

            answers_for_cur_query = []
            for answer_string, doc_idxs, meta in zip(extracted_answer_strings, reference_idxs, meta_list):
                referenced_docs = []
                for idx in doc_idxs:
                    # The lower bound matters: a reference to "0" becomes idx == -1, which would otherwise
                    # silently pick the last document through Python's negative indexing.
                    if 0 <= idx < len(doc_list):
                        referenced_docs.append(doc_list[idx])
                    else:
                        logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1)

                answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta)
                answers_for_cur_query.append(answer)

            all_answers.append(answers_for_cur_query)

        return all_answers

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(self, pattern=self.pattern, reference_pattern=self.reference_pattern)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AnswersBuilder":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    @staticmethod
    def _extract_answer_strings(replies: List[str], pattern: Optional[str] = None) -> List[str]:
        """
        Extract the answer strings from the generator output using the specified pattern.
        If no pattern is specified, the whole string is used as the answer.

        :param replies: The output of the Generator. A list of strings.
        :param pattern: The regular expression pattern to use to extract the answer text from the generator output.
        :returns: One answer string per reply; the empty string for replies the pattern does not match.
        """
        if pattern is None:
            return replies

        # Compile once instead of looking the pattern up again for every reply.
        compiled_pattern = re.compile(pattern)

        extracted_answers = []
        for reply in replies:
            if match := compiled_pattern.search(reply):
                # No capture group in pattern -> use the whole match as answer
                if not match.lastindex:
                    extracted_answers.append(match.group(0))
                # One capture group in pattern -> use the capture group as answer
                else:
                    extracted_answers.append(match.group(1))
            else:
                extracted_answers.append("")

        return extracted_answers

    @staticmethod
    def _extract_reference_idxs(replies: List[str], reference_pattern: str) -> List[List[int]]:
        """
        Extract the referenced document indices from each reply.

        :param replies: The output of the Generator. A list of strings.
        :param reference_pattern: The regular expression pattern whose matches are converted with `int`.
        :returns: For every reply, the list of referenced indices converted from 1-based to 0-based.
        """
        compiled_pattern = re.compile(reference_pattern)
        return [[int(idx) - 1 for idx in compiled_pattern.findall(reply)] for reply in replies]

    @staticmethod
    def _check_num_groups_in_regex(pattern: str):
        """
        Validate that the pattern has at most one capture group.

        :param pattern: The regular expression pattern to check.
        :raises ValueError: If the pattern contains more than one capture group.
        """
        num_groups = re.compile(pattern).groups
        if num_groups > 1:
            raise ValueError(
                f"Pattern '{pattern}' contains multiple capture groups. "
                f"Please specify a pattern with at most one capture group."
            )


# ---- releasenotes/notes/add-answersbuilder-2.0-5dd255eeba68041f.yaml ----
# ---
# preview:
#   - |
#     Add the `AnswersBuilder` component for Haystack 2.0 that creates Answer objects from the string output of Generators.
# ---- test/preview/components/builders/__init__.py (new, empty) ----
# ---- test/preview/components/builders/test_answers_builders.py ----
import logging

import pytest

from haystack.preview import GeneratedAnswer, Document
from haystack.preview.components.builders.answers_builder import AnswersBuilder


class TestAnswersBuilder:
    """
    Unit tests for `AnswersBuilder`: serialization, answer extraction and document reference parsing.

    Every test carries the `unit` marker so that none is left out when tests are selected by marker.
    """

    @pytest.mark.unit
    def test_to_dict(self):
        component = AnswersBuilder()
        data = component.to_dict()
        assert data == {"type": "AnswersBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}}

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        component = AnswersBuilder(pattern="pattern", reference_pattern="reference_pattern")
        data = component.to_dict()
        assert data == {
            "type": "AnswersBuilder",
            "init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
        }

    @pytest.mark.unit
    def test_from_dict(self):
        data = {
            "type": "AnswersBuilder",
            "init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
        }
        component = AnswersBuilder.from_dict(data)
        assert component.pattern == "pattern"
        assert component.reference_pattern == "reference_pattern"

    @pytest.mark.unit
    def test_run_unmatching_input_len(self):
        component = AnswersBuilder()
        with pytest.raises(ValueError):
            component.run(queries=["query"], replies=[["reply1"], ["reply2"]], metadata=[[]])

    @pytest.mark.unit
    def test_run_without_pattern(self):
        # Without a pattern the whole reply is used as the answer.
        component = AnswersBuilder()
        answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert answers[0][0].documents == []
        assert isinstance(answers[0][0], GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_pattern_with_capturing_group(self):
        # With one capture group, only the captured text is used as the answer.
        component = AnswersBuilder(pattern=r"Answer: (.*)")
        answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "AnswerString"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert answers[0][0].documents == []
        assert isinstance(answers[0][0], GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_pattern_without_capturing_group(self):
        # Without a capture group, the whole match is used as the answer.
        component = AnswersBuilder(pattern=r"'.*'")
        answers = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]])
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "'AnswerString'"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert answers[0][0].documents == []
        assert isinstance(answers[0][0], GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_pattern_with_more_than_one_capturing_group(self):
        # The constructor itself must reject the pattern, so no instance is kept.
        with pytest.raises(ValueError, match="contains multiple capture groups"):
            AnswersBuilder(pattern=r"Answer: (.*), (.*)")

    @pytest.mark.unit
    def test_run_with_pattern_set_at_runtime(self):
        # A pattern passed to `run` overrides the one given at initialization.
        component = AnswersBuilder(pattern="unused pattern")
        answers = component.run(
            queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)"
        )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "AnswerString"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert answers[0][0].documents == []
        assert isinstance(answers[0][0], GeneratedAnswer)

    @pytest.mark.unit
    def test_run_with_documents_without_reference_pattern(self):
        # Without a reference pattern all input documents are attached to the answer.
        component = AnswersBuilder()
        answers = component.run(
            queries=["test query"],
            replies=[["Answer: AnswerString"]],
            metadata=[[{}]],
            documents=[[Document(content="test doc 1"), Document(content="test doc 2")]],
        )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert len(answers[0][0].documents) == 2
        assert answers[0][0].documents[0].content == "test doc 1"
        assert answers[0][0].documents[1].content == "test doc 2"

    @pytest.mark.unit
    def test_run_with_documents_with_reference_pattern(self):
        # References are 1-based: "[2]" selects the second document.
        component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]")
        answers = component.run(
            queries=["test query"],
            replies=[["Answer: AnswerString[2]"]],
            metadata=[[{}]],
            documents=[[Document(content="test doc 1"), Document(content="test doc 2")]],
        )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString[2]"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert len(answers[0][0].documents) == 1
        assert answers[0][0].documents[0].content == "test doc 2"

    @pytest.mark.unit
    def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog):
        # A reference beyond the number of documents is dropped and reported with a warning.
        component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]")
        with caplog.at_level(logging.WARNING):
            answers = component.run(
                queries=["test query"],
                replies=[["Answer: AnswerString[3]"]],
                metadata=[[{}]],
                documents=[[Document(content="test doc 1"), Document(content="test doc 2")]],
            )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString[3]"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert len(answers[0][0].documents) == 0
        assert "Document index '3' referenced in Generator output is out of range." in caplog.text

    @pytest.mark.unit
    def test_run_with_reference_pattern_set_at_runtime(self):
        # A reference pattern passed to `run` overrides the one given at initialization.
        component = AnswersBuilder(reference_pattern="unused pattern")
        answers = component.run(
            queries=["test query"],
            replies=[["Answer: AnswerString[2][3]"]],
            metadata=[[{}]],
            documents=[
                [Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")]
            ],
            reference_pattern="\\[(\\d+)\\]",
        )
        assert len(answers) == 1
        assert len(answers[0]) == 1
        assert answers[0][0].data == "Answer: AnswerString[2][3]"
        assert answers[0][0].metadata == {}
        assert answers[0][0].query == "test query"
        assert len(answers[0][0].documents) == 2
        assert answers[0][0].documents[0].content == "test doc 2"
        assert answers[0][0].documents[1].content == "test doc 3"