feat: Add DocumentLanguageClassifier 2.0 (#6037)

* add DocumentLanguageClassifier and tests
* reno
* fix import, rename DocumentCleaner
* mark example usage as python code
* add assertions to e2e test
* use deserialized document_store
* Apply suggestions from code review
  Co-authored-by: Massimiliano Pippi <[email protected]>
* remove from/to_dict
* use renamed InMemoryDocumentStore
* adapt to Document refactoring
* improve docstring
* fix test for new Document

---------

Co-authored-by: Massimiliano Pippi <[email protected]>
Co-authored-by: Stefano Fiorucci <[email protected]>
Co-authored-by: anakin87 <[email protected]>
1 parent 209e349 · commit 29b1fef
Showing 5 changed files with 222 additions and 1 deletion.
New end-to-end test for the preprocessing pipeline (84 additions, 0 deletions):

import json

from haystack.preview import Pipeline
from haystack.preview.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.preview.components.file_converters import TextFileToDocument
from haystack.preview.components.preprocessors import TextDocumentSplitter, DocumentCleaner, DocumentLanguageClassifier
from haystack.preview.components.routers import FileTypeRouter
from haystack.preview.components.writers import DocumentWriter
from haystack.preview.document_stores import InMemoryDocumentStore


def test_preprocessing_pipeline(tmp_path):
    # Create the pipeline and its components
    document_store = InMemoryDocumentStore()
    preprocessing_pipeline = Pipeline()
    preprocessing_pipeline.add_component(instance=FileTypeRouter(mime_types=["text/plain"]), name="file_type_router")
    preprocessing_pipeline.add_component(instance=TextFileToDocument(), name="text_file_converter")
    preprocessing_pipeline.add_component(instance=DocumentLanguageClassifier(), name="language_classifier")
    preprocessing_pipeline.add_component(instance=DocumentCleaner(), name="cleaner")
    preprocessing_pipeline.add_component(
        instance=TextDocumentSplitter(split_by="sentence", split_length=1), name="splitter"
    )
    preprocessing_pipeline.add_component(
        instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
        name="embedder",
    )
    preprocessing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="writer")
    preprocessing_pipeline.connect("file_type_router.text/plain", "text_file_converter.paths")
    preprocessing_pipeline.connect("text_file_converter.documents", "language_classifier.documents")
    preprocessing_pipeline.connect("language_classifier.en", "cleaner.documents")
    preprocessing_pipeline.connect("cleaner.documents", "splitter.documents")
    preprocessing_pipeline.connect("splitter.documents", "embedder.documents")
    preprocessing_pipeline.connect("embedder.documents", "writer.documents")

    # Draw the pipeline
    preprocessing_pipeline.draw(tmp_path / "test_preprocessing_pipeline.png")

    # Serialize the pipeline to JSON
    with open(tmp_path / "test_preprocessing_pipeline.json", "w") as f:
        print(json.dumps(preprocessing_pipeline.to_dict(), indent=4))
        json.dump(preprocessing_pipeline.to_dict(), f)

    # Load the pipeline back
    with open(tmp_path / "test_preprocessing_pipeline.json", "r") as f:
        preprocessing_pipeline = Pipeline.from_dict(json.load(f))
    # Write an English text file; the newlines and the empty line appear to have been
    # lost in extraction and are restored here so the comments and assertions hold
    with open(tmp_path / "test_file_english.txt", "w") as f:
        f.write(
            "This is an english sentence. There is more to it. It's a long text.\n"
            "Spans multiple lines.\n"
            "\n"
            "Even contains empty lines. And extra whitespaces."
        )
    # Write a German text file
    with open(tmp_path / "test_file_german.txt", "w") as f:
        f.write("Ein deutscher Satz ohne Verb.")

    # Add two txt files and one non-txt file
    paths = [
        tmp_path / "test_file_english.txt",
        tmp_path / "test_file_german.txt",
        tmp_path / "test_preprocessing_pipeline.json",
    ]

    result = preprocessing_pipeline.run({"file_type_router": {"sources": paths}})

    assert result["writer"]["documents_written"] == 6
    filled_document_store = preprocessing_pipeline.get_component("writer").document_store
    assert filled_document_store.count_documents() == 6

    # Check preprocessed texts and mime_types
    stored_documents = filled_document_store.filter_documents()
    expected_texts = [
        "This is an english sentence.",
        " There is more to it.",
        " It's a long text.",
        "Spans multiple lines.",
        "Even contains empty lines.",
        " And extra whitespaces.",
    ]
    assert expected_texts == [document.content for document in stored_documents]
    assert all(document.mime_type == "text/plain" for document in stored_documents)
Preprocessors package __init__ (2 additions, 1 deletion):

@@ -1,5 +1,6 @@
 from haystack.preview.components.preprocessors.text_document_cleaner import DocumentCleaner
 from haystack.preview.components.preprocessors.text_document_splitter import TextDocumentSplitter
+from haystack.preview.components.preprocessors.document_language_classifier import DocumentLanguageClassifier
 from haystack.preview.components.preprocessors.text_language_classifier import TextLanguageClassifier

-__all__ = ["TextDocumentSplitter", "DocumentCleaner", "TextLanguageClassifier"]
+__all__ = ["TextDocumentSplitter", "DocumentCleaner", "TextLanguageClassifier", "DocumentLanguageClassifier"]
haystack/preview/components/preprocessors/document_language_classifier.py (81 additions, 0 deletions):
import logging
from typing import List, Dict, Optional

from haystack.preview import component, Document
from haystack.preview.lazy_imports import LazyImport

logger = logging.getLogger(__name__)

with LazyImport("Run 'pip install langdetect'") as langdetect_import:
    import langdetect


@component
class DocumentLanguageClassifier:
    """
    Routes documents onto different output connections depending on their language.
    This is useful for routing documents to different models in a pipeline depending on their language.
    The set of supported languages can be specified.
    For routing plain text using the same logic, use the related TextLanguageClassifier component instead.

    Example usage within an indexing pipeline, storing in a Document Store
    only documents written in English:

    ```python
    document_store = InMemoryDocumentStore()
    p = Pipeline()
    p.add_component(instance=TextFileToDocument(), name="text_file_converter")
    p.add_component(instance=DocumentLanguageClassifier(), name="language_classifier")
    p.add_component(instance=DocumentWriter(document_store=document_store), name="writer")
    p.connect("text_file_converter.documents", "language_classifier.documents")
    p.connect("language_classifier.en", "writer.documents")
    ```
    """

    def __init__(self, languages: Optional[List[str]] = None):
        """
        :param languages: A list of ISO language codes, each corresponding to a different output connection
            (see the [`langdetect` documentation](https://github.com/Mimino666/langdetect#languages)).
            By default, only ["en"] is supported and Documents of any other language are routed to "unmatched".
        """
        langdetect_import.check()
        if not languages:
            languages = ["en"]
        self.languages = languages
        component.set_output_types(
            self, unmatched=List[Document], **{language: List[Document] for language in languages}
        )

    def run(self, documents: List[Document]):
        """
        Run the DocumentLanguageClassifier. This method routes the documents to different edges based on their language.
        If a Document's text does not match any of the languages specified at initialization, it is routed to
        a connection named "unmatched".

        :param documents: A list of documents to route to different edges.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "DocumentLanguageClassifier expects a list of Document as input. "
                "In case you want to classify a text, please use the TextLanguageClassifier."
            )

        output: Dict[str, List[Document]] = {language: [] for language in self.languages}
        output["unmatched"] = []

        for document in documents:
            detected_language = self.detect_language(document)
            if detected_language in self.languages:
                output[detected_language].append(document)
            else:
                output["unmatched"].append(document)

        return output

    def detect_language(self, document: Document) -> Optional[str]:
        """Detect the language of a single Document's content, returning None when detection fails."""
        try:
            language = langdetect.detect(document.content)
        except langdetect.LangDetectException:
            logger.warning("Langdetect cannot detect the language of Document with id: %s", document.id)
            language = None
        return language
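The classifier can also be exercised outside a pipeline by calling run() directly. Below is a minimal standalone sketch, assuming langdetect is installed and detects the two sample sentences as "en" and "de", as the unit tests further down suggest:

```python
from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentLanguageClassifier

# One output edge per configured language, plus "unmatched" for everything else.
classifier = DocumentLanguageClassifier(languages=["en", "de"])
result = classifier.run(
    documents=[
        Document(content="This is an english sentence."),
        Document(content="Ein deutscher Satz ohne Verb."),
    ]
)
# `result` maps edge names to lists of Documents, e.g.:
# {"en": [<the English document>], "de": [<the German document>], "unmatched": []}
```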
releasenotes/notes/document-language-classifier-1ec0b3c4d08989c0.yaml (4 additions, 0 deletions):
@@ -0,0 +1,4 @@ | ||
--- | ||
preview: | ||
- | | ||
Added DocumentLanguageClassifier component so that Documents can be routed to different components based on the detected language for example during preprocessing. |
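To illustrate that routing, here is a hypothetical sketch (the two-store setup is an assumption for illustration; the component and connection syntax come from the code above) that keeps non-English documents in a separate store instead of dropping them:

```python
from haystack.preview import Pipeline
from haystack.preview.components.preprocessors import DocumentLanguageClassifier
from haystack.preview.components.writers import DocumentWriter
from haystack.preview.document_stores import InMemoryDocumentStore

en_store = InMemoryDocumentStore()     # English documents land here
other_store = InMemoryDocumentStore()  # unmatched documents are kept for later inspection

p = Pipeline()
p.add_component(instance=DocumentLanguageClassifier(), name="language_classifier")
p.add_component(instance=DocumentWriter(document_store=en_store), name="en_writer")
p.add_component(instance=DocumentWriter(document_store=other_store), name="unmatched_writer")
p.connect("language_classifier.en", "en_writer.documents")
p.connect("language_classifier.unmatched", "unmatched_writer.documents")
```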
test/preview/components/preprocessors/test_document_language_classifier.py (51 additions, 0 deletions):
import logging
import pytest

from haystack.preview import Document
from haystack.preview.components.preprocessors import DocumentLanguageClassifier


class TestDocumentLanguageClassifier:
    @pytest.mark.unit
    def test_init(self):
        component = DocumentLanguageClassifier()
        assert component.languages == ["en"]

    @pytest.mark.unit
    def test_non_document_input(self):
        with pytest.raises(TypeError, match="DocumentLanguageClassifier expects a list of Document as input."):
            classifier = DocumentLanguageClassifier()
            classifier.run(documents="This is an english sentence.")

    @pytest.mark.unit
    def test_single_document(self):
        with pytest.raises(TypeError, match="DocumentLanguageClassifier expects a list of Document as input."):
            classifier = DocumentLanguageClassifier()
            classifier.run(documents=Document(content="This is an english sentence."))

    @pytest.mark.unit
    def test_empty_list(self):
        classifier = DocumentLanguageClassifier()
        result = classifier.run(documents=[])
        assert result == {"en": [], "unmatched": []}

    @pytest.mark.unit
    def test_detect_language(self):
        classifier = DocumentLanguageClassifier()
        detected_language = classifier.detect_language(Document(content="This is an english sentence."))
        assert detected_language == "en"

    @pytest.mark.unit
    def test_route_to_en_and_unmatched(self):
        classifier = DocumentLanguageClassifier()
        english_document = Document(content="This is an english sentence.")
        german_document = Document(content="Ein deutscher Satz ohne Verb.")
        result = classifier.run(documents=[english_document, german_document])
        assert result == {"en": [english_document], "unmatched": [german_document]}

    @pytest.mark.unit
    def test_warning_if_no_language_detected(self, caplog):
        with caplog.at_level(logging.WARNING):
            classifier = DocumentLanguageClassifier()
            classifier.run(documents=[Document(content=".")])
            assert "Langdetect cannot detect the language of Document with id" in caplog.text