From 811b93db918f39aa81b48faac0b1622e0922b81d Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Mon, 13 May 2024 19:44:02 +0200
Subject: [PATCH 1/2] feat: Set ByteStream's mime_type attribute for web based resources (#7681)

---
 haystack/components/fetchers/link_content.py | 1 +
 haystack/components/routers/file_type_router.py | 2 +-
 ...nhanced-mime-type-handling-182fb64a0f5fb852.yaml | 4 ++++
 test/components/routers/test_file_router.py | 13 ++++++++-----
 4 files changed, 14 insertions(+), 6 deletions(-)
 create mode 100644 releasenotes/notes/enhanced-mime-type-handling-182fb64a0f5fb852.yaml

diff --git a/haystack/components/fetchers/link_content.py b/haystack/components/fetchers/link_content.py
index 00cbdeb667..0d86ff8529 100644
--- a/haystack/components/fetchers/link_content.py
+++ b/haystack/components/fetchers/link_content.py
@@ -151,6 +151,7 @@ def run(self, urls: List[str]):
         for stream_metadata, stream in results:  # type: ignore
             if stream_metadata is not None and stream is not None:
                 stream.meta.update(stream_metadata)
+                stream.mime_type = stream.meta.get("content_type", None)
                 streams.append(stream)

         return {"streams": streams}
diff --git a/haystack/components/routers/file_type_router.py b/haystack/components/routers/file_type_router.py
index 08c615c7e1..fdf8830c9f 100644
--- a/haystack/components/routers/file_type_router.py
+++ b/haystack/components/routers/file_type_router.py
@@ -90,7 +90,7 @@ def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Uni
             if isinstance(source, Path):
                 mime_type = self._get_mime_type(source)
             elif isinstance(source, ByteStream):
-                mime_type = source.meta.get("content_type", None)
+                mime_type = source.mime_type
             else:
                 raise ValueError(f"Unsupported data source type: {type(source).__name__}")
diff --git a/releasenotes/notes/enhanced-mime-type-handling-182fb64a0f5fb852.yaml b/releasenotes/notes/enhanced-mime-type-handling-182fb64a0f5fb852.yaml
new file mode 100644
index 0000000000..c0e7d07445
--- /dev/null
+++ b/releasenotes/notes/enhanced-mime-type-handling-182fb64a0f5fb852.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Improved MIME type management by directly setting MIME types on ByteStreams, enhancing the overall handling and routing of different file types. This update makes MIME type data more consistently accessible and simplifies the process of working with various document formats.
diff --git a/test/components/routers/test_file_router.py b/test/components/routers/test_file_router.py
index b64be66336..32d1e99dd1 100644
--- a/test/components/routers/test_file_router.py
+++ b/test/components/routers/test_file_router.py
@@ -50,7 +50,7 @@ def test_run_with_bytestreams(self, test_files_path):
         byte_streams = []
         for path, mime_type in zip(file_paths, mime_types):
             stream = ByteStream(path.read_bytes())
-            stream.meta["content_type"] = mime_type
+            stream.mime_type = mime_type
             byte_streams.append(stream)

         # add unclassified ByteStream
@@ -81,7 +81,7 @@ def test_run_with_bytestreams_and_file_paths(self, test_files_path):
         byte_stream_sources = []
         for path, mime_type in zip(file_paths, mime_types):
             stream = ByteStream(path.read_bytes())
-            stream.meta["content_type"] = mime_type
+            stream.mime_type = mime_type
             byte_stream_sources.append(stream)

         mixed_sources = file_paths[:2] + byte_stream_sources[2:]
@@ -165,9 +165,12 @@ def test_exact_mime_type_matching(self, mock_file):
         """
         Test if the component correctly matches mime types exactly, without regex patterns.
         """
-        txt_stream = ByteStream(io.BytesIO(b"Text file content"), meta={"content_type": "text/plain"})
-        jpg_stream = ByteStream(io.BytesIO(b"JPEG file content"), meta={"content_type": "image/jpeg"})
-        mp3_stream = ByteStream(io.BytesIO(b"MP3 file content"), meta={"content_type": "audio/mpeg"})
+        txt_stream = ByteStream(io.BytesIO(b"Text file content").read())
+        txt_stream.mime_type = "text/plain"
+        jpg_stream = ByteStream(io.BytesIO(b"JPEG file content").read())
+        jpg_stream.mime_type = "image/jpeg"
+        mp3_stream = ByteStream(io.BytesIO(b"MP3 file content").read())
+        mp3_stream.mime_type = "audio/mpeg"

         byte_streams = [txt_stream, jpg_stream, mp3_stream]

From a2be90b95a402f71ef158b38523f5249dc94ffef Mon Sep 17 00:00:00 2001
From: Sebastian Husch Lee
Date: Tue, 14 May 2024 08:36:14 +0200
Subject: [PATCH 2/2] fix: Update device deserialization for components that use local models (#7686)

* fix: Update device deserialization for SentenceTransformersTextEmbedder

* Add unit test

* Fix unit test

* Make same change to doc embedder

* Add release notes

* Add same change to Diversity Ranker and Named Entity Extractor

* Add unit test

* Add the same for whisper local

* Update release notes

---
 haystack/components/audio/whisper_local.py | 6 ++--
 ...sentence_transformers_document_embedder.py | 6 ++--
 .../sentence_transformers_text_embedder.py | 6 ++--
 .../extractors/named_entity_extractor.py | 3 +-
 .../sentence_transformers_diversity.py | 6 ++--
 ...lization-st-embedder-c4efad96dd3869d5.yaml | 5 +++
 test/components/audio/test_whisper_local.py | 11 +++++++
 ...sentence_transformers_document_embedder.py | 32 +++++++++++++++++++
 ...est_sentence_transformers_text_embedder.py | 26 +++++++++++++++
 .../extractors/test_named_entity_extractor.py | 14 ++++++++
 .../test_sentence_transformers_diversity.py | 31 ++++++++++++++++++
 11 files changed, 133 insertions(+), 13 deletions(-)
 create mode 100644 releasenotes/notes/fix-device-deserialization-st-embedder-c4efad96dd3869d5.yaml

diff --git a/haystack/components/audio/whisper_local.py b/haystack/components/audio/whisper_local.py
index f60c76f57b..5a96f40e48 100644
--- a/haystack/components/audio/whisper_local.py
+++ b/haystack/components/audio/whisper_local.py
@@ -90,9 +90,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber":
         :returns:
             The deserialized component.
         """
-        serialized_device = data["init_parameters"]["device"]
-        data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
-
+        init_params = data["init_parameters"]
+        if init_params["device"] is not None:
+            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         return default_from_dict(cls, data)

     @component.output_types(documents=List[Document])
diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py
index 010e4938d0..40550d5729 100644
--- a/haystack/components/embedders/sentence_transformers_document_embedder.py
+++ b/haystack/components/embedders/sentence_transformers_document_embedder.py
@@ -125,9 +125,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedde
         :returns:
             Deserialized component.
         """
-        serialized_device = data["init_parameters"]["device"]
-        data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
-
+        init_params = data["init_parameters"]
+        if init_params["device"] is not None:
+            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         return default_from_dict(cls, data)
diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py
index 5907fbb27c..0457f8815d 100644
--- a/haystack/components/embedders/sentence_transformers_text_embedder.py
+++ b/haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -115,9 +115,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
         :returns:
             Deserialized component.
         """
-        serialized_device = data["init_parameters"]["device"]
-        data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
-
+        init_params = data["init_parameters"]
+        if init_params["device"] is not None:
+            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         return default_from_dict(cls, data)
diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py
index 649b7dc59f..a8c6c15bc4 100644
--- a/haystack/components/extractors/named_entity_extractor.py
+++ b/haystack/components/extractors/named_entity_extractor.py
@@ -215,7 +215,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
         """
         try:
             init_params = data["init_parameters"]
-            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
+            if init_params["device"] is not None:
+                init_params["device"] = ComponentDevice.from_dict(init_params["device"])
             return default_from_dict(cls, data)
         except Exception as e:
             raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py
index 43f5a2417a..c1a216533a 100644
--- a/haystack/components/rankers/sentence_transformers_diversity.py
+++ b/haystack/components/rankers/sentence_transformers_diversity.py
@@ -141,9 +141,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker
         :returns:
             The deserialized component.
         """
-        serialized_device = data["init_parameters"]["device"]
-        data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
-
+        init_params = data["init_parameters"]
+        if init_params["device"] is not None:
+            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         return default_from_dict(cls, data)
diff --git a/releasenotes/notes/fix-device-deserialization-st-embedder-c4efad96dd3869d5.yaml b/releasenotes/notes/fix-device-deserialization-st-embedder-c4efad96dd3869d5.yaml
new file mode 100644
index 0000000000..6bb0a4d2b9
--- /dev/null
+++ b/releasenotes/notes/fix-device-deserialization-st-embedder-c4efad96dd3869d5.yaml
@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    Updates the from_dict method of SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder, NamedEntityExtractor, SentenceTransformersDiversityRanker and LocalWhisperTranscriber to allow None as a valid value for device when deserializing from a YAML file.
+    This allows a deserialized pipeline to auto-determine which device to use via the ComponentDevice.resolve_device logic.
diff --git a/test/components/audio/test_whisper_local.py b/test/components/audio/test_whisper_local.py
index 6a6c3a8f23..6cbd43575a 100644
--- a/test/components/audio/test_whisper_local.py
+++ b/test/components/audio/test_whisper_local.py
@@ -72,6 +72,17 @@ def test_from_dict(self):
         assert transcriber.whisper_params == {}
         assert transcriber._model is None

+    def test_from_dict_none_device(self):
+        data = {
+            "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
+            "init_parameters": {"model": "tiny", "device": None, "whisper_params": {}},
+        }
+        transcriber = LocalWhisperTranscriber.from_dict(data)
+        assert transcriber.model == "tiny"
+        assert transcriber.device == ComponentDevice.resolve_device(None)
+        assert transcriber.whisper_params == {}
+        assert transcriber._model is None
+
     def test_warmup(self):
         with patch("haystack.components.audio.whisper_local.whisper") as mocked_whisper:
             transcriber = LocalWhisperTranscriber(model="large-v2", device=ComponentDevice.from_str("cpu"))
diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py
index e9fc3e3c6e..75564188af 100644
--- a/test/components/embedders/test_sentence_transformers_document_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_document_embedder.py
@@ -137,6 +137,38 @@ def test_from_dict(self):
         assert component.trust_remote_code
         assert component.meta_fields_to_embed == ["meta_field"]

+    def test_from_dict_none_device(self):
+        init_parameters = {
+            "model": "model",
+            "device": None,
+            "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
+            "prefix": "prefix",
+            "suffix": "suffix",
+            "batch_size": 64,
+            "progress_bar": False,
+            "normalize_embeddings": True,
+            "embedding_separator": " - ",
+            "meta_fields_to_embed": ["meta_field"],
+            "trust_remote_code": True,
+        }
+        component = SentenceTransformersDocumentEmbedder.from_dict(
+            {
+                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+                "init_parameters": init_parameters,
+            }
+        )
+        assert component.model == "model"
+        assert component.device == ComponentDevice.resolve_device(None)
+        assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
+        assert component.prefix == "prefix"
+        assert component.suffix == "suffix"
+        assert component.batch_size == 64
+        assert component.progress_bar is False
+        assert component.normalize_embeddings is True
+        assert component.embedding_separator == " - "
+        assert component.trust_remote_code
+        assert component.meta_fields_to_embed == ["meta_field"]
+
     @patch(
         "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
     )
diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py
index 433a512524..ec9234b6c9 100644
--- a/test/components/embedders/test_sentence_transformers_text_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_text_embedder.py
@@ -122,6 +122,32 @@ def test_from_dict(self):
         assert component.normalize_embeddings is False
         assert component.trust_remote_code is False

+    def test_from_dict_none_device(self):
+        data = {
+            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+            "init_parameters": {
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "model": "model",
+                "device": None,
+                "prefix": "",
+                "suffix": "",
+                "batch_size": 32,
+                "progress_bar": True,
+                "normalize_embeddings": False,
+                "trust_remote_code": False,
+            },
+        }
+        component = SentenceTransformersTextEmbedder.from_dict(data)
+        assert component.model == "model"
+        assert component.device == ComponentDevice.resolve_device(None)
+        assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
+        assert component.prefix == ""
+        assert component.suffix == ""
+        assert component.batch_size == 32
+        assert component.progress_bar is True
+        assert component.normalize_embeddings is False
+        assert component.trust_remote_code is False
+
     @patch(
         "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
     )
diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py
index 140752f261..d47ae69dd2 100644
--- a/test/components/extractors/test_named_entity_extractor.py
+++ b/test/components/extractors/test_named_entity_extractor.py
@@ -40,3 +40,17 @@ def test_named_entity_extractor_serde():
     with pytest.raises(DeserializationError, match=r"Couldn't deserialize"):
         serde_data["init_parameters"].pop("backend")
         _ = NamedEntityExtractor.from_dict(serde_data)
+
+
+@pytest.mark.unit
+def test_named_entity_extractor_serde_none_device():
+    extractor = NamedEntityExtractor(
+        backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None
+    )
+
+    serde_data = extractor.to_dict()
+    new_extractor = NamedEntityExtractor.from_dict(serde_data)
+
+    assert type(new_extractor._backend) == type(extractor._backend)
+    assert new_extractor._backend.model_name == extractor._backend.model_name
+    assert new_extractor._backend.device == extractor._backend.device
diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py
index a7fade57f9..b4885d3278 100644
--- a/test/components/rankers/test_sentence_transformers_diversity.py
+++ b/test/components/rankers/test_sentence_transformers_diversity.py
@@ -113,6 +113,37 @@ def test_from_dict(self):
         assert ranker.meta_fields_to_embed == []
         assert ranker.embedding_separator == "\n"

+    def test_from_dict_none_device(self):
+        data = {
+            "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
+            "init_parameters": {
+                "model": "sentence-transformers/all-MiniLM-L6-v2",
+                "top_k": 10,
+                "device": None,
+                "similarity": "cosine",
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "query_prefix": "",
+                "document_prefix": "",
+                "query_suffix": "",
+                "document_suffix": "",
+                "meta_fields_to_embed": [],
+                "embedding_separator": "\n",
+            },
+        }
+        ranker = SentenceTransformersDiversityRanker.from_dict(data)
+
+        assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
+        assert ranker.top_k == 10
+        assert ranker.device == ComponentDevice.resolve_device(None)
+        assert ranker.similarity == "cosine"
+        assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
+        assert ranker.query_prefix == ""
+        assert ranker.document_prefix == ""
+        assert ranker.query_suffix == ""
+        assert ranker.document_suffix == ""
+        assert ranker.meta_fields_to_embed == []
+        assert ranker.embedding_separator == "\n"
+
     def test_to_dict_with_custom_init_parameters(self):
         component = SentenceTransformersDiversityRanker(
             model="sentence-transformers/msmarco-distilbert-base-v4",