From bd77120cf347e4efcddb534ddca4dbb5f62882c4 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Mon, 18 Nov 2024 11:54:30 +0100 Subject: [PATCH] Fix `DocumentSplitter` not splitting by function (#8549) * Fix DocumentSplitter not splitting by function * Make the split_by mapping a constant --- .../preprocessors/document_splitter.py | 74 +++++++++---------- .../preprocessors/nltk_document_splitter.py | 2 +- .../split-by-function-62ce32fac70d8f8c.yaml | 5 ++ .../preprocessors/test_document_splitter.py | 37 ++++++---- 4 files changed, 67 insertions(+), 51 deletions(-) create mode 100644 releasenotes/notes/split-by-function-62ce32fac70d8f8c.yaml diff --git a/haystack/components/preprocessors/document_splitter.py b/haystack/components/preprocessors/document_splitter.py index 186f145e94..86d95f412a 100644 --- a/haystack/components/preprocessors/document_splitter.py +++ b/haystack/components/preprocessors/document_splitter.py @@ -13,6 +13,10 @@ logger = logging.getLogger(__name__) +# Maps the 'split_by' argument to the actual char used to split the Documents. +# 'function' is not in the mapping because it doesn't split on chars. 
+_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "sentence": ".", "word": " ", "line": "\n"} + @component class DocumentSplitter: @@ -73,7 +77,7 @@ def __init__( # pylint: disable=too-many-positional-arguments self.split_by = split_by if split_by not in ["function", "page", "passage", "sentence", "word", "line"]: - raise ValueError("split_by must be one of 'word', 'sentence', 'page', 'passage' or 'line'.") + raise ValueError("split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'.") if split_by == "function" and splitting_function is None: raise ValueError("When 'split_by' is set to 'function', a valid 'splitting_function' must be provided.") if split_length <= 0: @@ -108,7 +112,7 @@ def run(self, documents: List[Document]): if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): raise TypeError("DocumentSplitter expects a List of Documents as input.") - split_docs = [] + split_docs: List[Document] = [] for doc in documents: if doc.content is None: raise ValueError( @@ -117,42 +121,38 @@ def run(self, documents: List[Document]): if doc.content == "": logger.warning("Document ID {doc_id} has an empty content. 
Skipping this document.", doc_id=doc.id) continue - units = self._split_into_units(doc.content, self.split_by) - text_splits, splits_pages, splits_start_idxs = self._concatenate_units( - units, self.split_length, self.split_overlap, self.split_threshold - ) - metadata = deepcopy(doc.meta) - metadata["source_id"] = doc.id - split_docs += self._create_docs_from_splits( - text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata - ) + split_docs += self._split(doc) return {"documents": split_docs} - def _split_into_units( - self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"] - ) -> List[str]: - if split_by == "page": - self.split_at = "\f" - elif split_by == "passage": - self.split_at = "\n\n" - elif split_by == "sentence": - self.split_at = "." - elif split_by == "word": - self.split_at = " " - elif split_by == "line": - self.split_at = "\n" - elif split_by == "function" and self.splitting_function is not None: - return self.splitting_function(text) - else: - raise NotImplementedError( - """DocumentSplitter only supports 'function', 'line', 'page', - 'passage', 'sentence' or 'word' split_by options.""" - ) - units = text.split(self.split_at) + def _split(self, to_split: Document) -> List[Document]: + # We already check this before calling _split but + # we need to make linters happy + if to_split.content is None: + return [] + + if self.split_by == "function" and self.splitting_function is not None: + splits = self.splitting_function(to_split.content) + docs: List[Document] = [] + for s in splits: + meta = deepcopy(to_split.meta) + meta["source_id"] = to_split.id + docs.append(Document(content=s, meta=meta)) + return docs + + split_at = _SPLIT_BY_MAPPING[self.split_by] + units = to_split.content.split(split_at) # Add the delimiter back to all units except the last one for i in range(len(units) - 1): - units[i] += self.split_at - return units + units[i] += split_at + + text_splits, 
splits_pages, splits_start_idxs = self._concatenate_units( + units, self.split_length, self.split_overlap, self.split_threshold + ) + metadata = deepcopy(to_split.meta) + metadata["source_id"] = to_split.id + return self._create_docs_from_splits( + text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata + ) def _concatenate_units( self, elements: List[str], split_length: int, split_overlap: int, split_threshold: int @@ -166,8 +166,8 @@ def _concatenate_units( """ text_splits: List[str] = [] - splits_pages = [] - splits_start_idxs = [] + splits_pages: List[int] = [] + splits_start_idxs: List[int] = [] cur_start_idx = 0 cur_page = 1 segments = windowed(elements, n=split_length, step=split_length - split_overlap) @@ -200,7 +200,7 @@ def _concatenate_units( return text_splits, splits_pages, splits_start_idxs def _create_docs_from_splits( - self, text_splits: List[str], splits_pages: List[int], splits_start_idxs: List[int], meta: Dict + self, text_splits: List[str], splits_pages: List[int], splits_start_idxs: List[int], meta: Dict[str, Any] ) -> List[Document]: """ Creates Document objects from splits enriching them with page number and the metadata of the original document. diff --git a/haystack/components/preprocessors/nltk_document_splitter.py b/haystack/components/preprocessors/nltk_document_splitter.py index 6826c22ecc..f571396311 100644 --- a/haystack/components/preprocessors/nltk_document_splitter.py +++ b/haystack/components/preprocessors/nltk_document_splitter.py @@ -87,7 +87,7 @@ def __init__( # pylint: disable=too-many-positional-arguments self.language = language def _split_into_units( - self, text: str, split_by: Literal["word", "sentence", "passage", "page", "function"] + self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"] ) -> List[str]: """ Splits the text into units based on the specified split_by parameter. 
diff --git a/releasenotes/notes/split-by-function-62ce32fac70d8f8c.yaml b/releasenotes/notes/split-by-function-62ce32fac70d8f8c.yaml new file mode 100644 index 0000000000..495b896f4c --- /dev/null +++ b/releasenotes/notes/split-by-function-62ce32fac70d8f8c.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fix `DocumentSplitter` to handle custom `splitting_function` without requiring `split_length`. + Previously the `splitting_function` provided would not override other settings. diff --git a/test/components/preprocessors/test_document_splitter.py b/test/components/preprocessors/test_document_splitter.py index b21dcbf7d8..25872626c1 100644 --- a/test/components/preprocessors/test_document_splitter.py +++ b/test/components/preprocessors/test_document_splitter.py @@ -57,7 +57,7 @@ def test_empty_list(self): def test_unsupported_split_by(self): with pytest.raises( - ValueError, match="split_by must be one of 'word', 'sentence', 'page', 'passage' or 'line'." + ValueError, match="split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'." 
): DocumentSplitter(split_by="unsupported") @@ -177,25 +177,36 @@ def test_split_by_page(self): assert docs[2].meta["page_number"] == 3 def test_split_by_function(self): - splitting_function = lambda input_str: input_str.split(".") - splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1) + splitting_function = lambda s: s.split(".") + splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function) text = "This.Is.A.Test" - result = splitter.run(documents=[Document(content=text)]) + result = splitter.run(documents=[Document(id="1", content=text, meta={"key": "value"})]) docs = result["documents"] - word_list = ["This", "Is", "A", "Test"] assert len(docs) == 4 - for w_target, w_split in zip(word_list, docs): - assert w_split.content == w_target - - splitting_function = lambda input_str: re.split("[\s]{2,}", input_str) - splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1) + assert docs[0].content == "This" + assert docs[0].meta == {"key": "value", "source_id": "1"} + assert docs[1].content == "Is" + assert docs[1].meta == {"key": "value", "source_id": "1"} + assert docs[2].content == "A" + assert docs[2].meta == {"key": "value", "source_id": "1"} + assert docs[3].content == "Test" + assert docs[3].meta == {"key": "value", "source_id": "1"} + + splitting_function = lambda s: re.split(r"[\s]{2,}", s) + splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function) text = "This Is\n A Test" - result = splitter.run(documents=[Document(content=text)]) + result = splitter.run(documents=[Document(id="1", content=text, meta={"key": "value"})]) docs = result["documents"] assert len(docs) == 4 - for w_target, w_split in zip(word_list, docs): - assert w_split.content == w_target + assert docs[0].content == "This" + assert docs[0].meta == {"key": "value", "source_id": "1"} + assert docs[1].content == "Is" + assert docs[1].meta == 
{"key": "value", "source_id": "1"} + assert docs[2].content == "A" + assert docs[2].meta == {"key": "value", "source_id": "1"} + assert docs[3].content == "Test" + assert docs[3].meta == {"key": "value", "source_id": "1"} def test_split_by_word_with_overlap(self): splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2)