Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: SentenceWindowRetriever returns List[Document] with docs ordered by split_idx_start #8590

Merged
merged 10 commits into from
Dec 4, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class InMemoryBM25Retriever:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
document_store: InMemoryDocumentStore,
filters: Optional[Dict[str, Any]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class InMemoryEmbeddingRetriever:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
document_store: InMemoryDocumentStore,
filters: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -143,7 +143,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryEmbeddingRetriever":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
def run( # pylint: disable=too-many-positional-arguments
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def run(self, retrieved_documents: List[Document], window_size: Optional[int] =
}
)
context_text.append(self.merge_documents_text(context_docs))
context_documents.append(context_docs)
context_docs_sorted = sorted(context_docs, key=lambda doc: doc.meta["split_idx_start"])
context_documents.extend(context_docs_sorted)

return {"context_windows": context_text, "context_documents": context_documents}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
upgrade:
- |
The SentenceWindowRetriever output key `context_documents` now outputs a List[Document] containing the retrieved documents together with the context windows ordered by `split_idx_start`.
37 changes: 35 additions & 2 deletions test/components/retrievers/test_sentence_window_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,39 @@ def test_constructor_parameter_does_not_change(self):
retriever.run(retrieved_documents=[Document.from_dict(doc)], window_size=1)
assert retriever.window_size == 5

def test_context_documents_returned_are_ordered_by_split_idx_start(self):
docs = []
accumulated_length = 0
for sent in range(10):
content = f"Sentence {sent}."
docs.append(
Document(
content=content,
meta={
"id": f"doc_{sent}",
"split_idx_start": accumulated_length,
"source_id": "source1",
"split_id": sent,
},
)
)
accumulated_length += len(content)

import random

random.shuffle(docs)

doc_store = InMemoryDocumentStore()
doc_store.write_documents(docs)
retriever = SentenceWindowRetriever(document_store=doc_store, window_size=3)

# run the retriever with a document whose content = "Sentence 4."
result = retriever.run(retrieved_documents=[doc for doc in docs if doc.content == "Sentence 4."])

# assert that the context documents are in the correct order
assert len(result["context_documents"]) == 7
assert [doc.meta["split_idx_start"] for doc in result["context_documents"]] == [11, 22, 33, 44, 55, 66, 77]

@pytest.mark.integration
def test_run_with_pipeline(self):
splitter = DocumentSplitter(split_length=1, split_overlap=0, split_by="sentence")
Expand All @@ -165,13 +198,13 @@ def test_run_with_pipeline(self):
"This is a text with some words. There is a second sentence. And there is also a third sentence. "
"It also contains a fourth sentence. And a fifth sentence."
]
assert len(result["sentence_window_retriever"]["context_documents"][0]) == 5
assert len(result["sentence_window_retriever"]["context_documents"]) == 5

result = pipe.run({"bm25_retriever": {"query": "third"}, "sentence_window_retriever": {"window_size": 1}})
assert result["sentence_window_retriever"]["context_windows"] == [
" There is a second sentence. And there is also a third sentence. It also contains a fourth sentence."
]
assert len(result["sentence_window_retriever"]["context_documents"][0]) == 3
assert len(result["sentence_window_retriever"]["context_documents"]) == 3

@pytest.mark.integration
def test_serialization_deserialization_in_pipeline(self):
Expand Down
Loading