Merge branch 'main' into nee_ser_fix
vblagoje authored May 14, 2024
2 parents a401935 + a2be90b commit 29cf933
Showing 15 changed files with 148 additions and 20 deletions.
6 changes: 3 additions & 3 deletions haystack/components/audio/whisper_local.py
@@ -90,9 +90,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber":
         :returns:
             The deserialized component.
         """
-        serialized_device = data["init_parameters"]["device"]
-        data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
-
+        init_params = data["init_parameters"]
+        if init_params["device"] is not None:
+            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         return default_from_dict(cls, data)

     @component.output_types(documents=List[Document])
@@ -125,9 +125,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
         :returns:
             Deserialized component.
         """
-        serialized_device = data["init_parameters"]["device"]
-        data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
-
+        init_params = data["init_parameters"]
+        if init_params["device"] is not None:
+            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         return default_from_dict(cls, data)

@@ -115,9 +115,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
         :returns:
             Deserialized component.
         """
-        serialized_device = data["init_parameters"]["device"]
-        data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
-
+        init_params = data["init_parameters"]
+        if init_params["device"] is not None:
+            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         return default_from_dict(cls, data)

5 changes: 3 additions & 2 deletions haystack/components/extractors/named_entity_extractor.py
@@ -214,8 +214,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
             Deserialized component.
         """
         try:
-            init_params = data["init_parameters"]
-            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
+            init_params = data["init_parameters"]
+            if init_params["device"] is not None:
+                init_params["device"] = ComponentDevice.from_dict(init_params["device"])
             init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]
             return default_from_dict(cls, data)
         except Exception as e:
1 change: 1 addition & 0 deletions haystack/components/fetchers/link_content.py
@@ -151,6 +151,7 @@ def run(self, urls: List[str]):
         for stream_metadata, stream in results:  # type: ignore
             if stream_metadata is not None and stream is not None:
                 stream.meta.update(stream_metadata)
+                stream.mime_type = stream.meta.get("content_type", None)
                 streams.append(stream)

         return {"streams": streams}
@@ -141,9 +141,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
         :returns:
             The deserialized component.
         """
-        serialized_device = data["init_parameters"]["device"]
-        data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
-
+        init_params = data["init_parameters"]
+        if init_params["device"] is not None:
+            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         return default_from_dict(cls, data)

2 changes: 1 addition & 1 deletion haystack/components/routers/file_type_router.py
@@ -90,7 +90,7 @@ def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Uni
             if isinstance(source, Path):
                 mime_type = self._get_mime_type(source)
             elif isinstance(source, ByteStream):
-                mime_type = source.meta.get("content_type", None)
+                mime_type = source.mime_type
             else:
                 raise ValueError(f"Unsupported data source type: {type(source).__name__}")

@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Improved MIME type management by directly setting MIME types on ByteStreams, enhancing the overall handling and routing of different file types. This update makes MIME type data more consistently accessible and simplifies the process of working with various document formats.
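For illustration, a minimal sketch of the new pattern this change introduces: setting mime_type directly on a ByteStream and letting FileTypeRouter read it. The FileTypeRouter constructor arguments and import paths shown here are assumptions based on typical Haystack 2.x usage, not part of this diff.

    from haystack.components.routers import FileTypeRouter
    from haystack.dataclasses import ByteStream

    # Set the MIME type directly on the ByteStream object,
    # instead of stashing it in stream.meta["content_type"].
    stream = ByteStream(b"Text file content")
    stream.mime_type = "text/plain"

    # FileTypeRouter now reads source.mime_type when routing ByteStreams.
    router = FileTypeRouter(mime_types=["text/plain"])
    result = router.run(sources=[stream])
    print(result)  # expected: {"text/plain": [stream]}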
@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    Updates the from_dict method of SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder, NamedEntityExtractor, SentenceTransformersDiversityRanker, and LocalWhisperTranscriber to allow None as a valid value for device when deserializing from a YAML file.
+    This allows a deserialized pipeline to auto-determine which device to use via the ComponentDevice.resolve_device logic.
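As a minimal sketch of what this fix enables (the example is not part of the diff; init parameters omitted here are assumed to fall back to their defaults, and the import paths follow standard Haystack 2.x conventions):

    from haystack.components.embedders import SentenceTransformersTextEmbedder
    from haystack.utils import ComponentDevice

    # A serialized component with "device": None now deserializes cleanly;
    # the device is resolved automatically instead of failing on None.
    data = {
        "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
        "init_parameters": {"model": "sentence-transformers/all-MiniLM-L6-v2", "device": None},
    }
    embedder = SentenceTransformersTextEmbedder.from_dict(data)
    assert embedder.device == ComponentDevice.resolve_device(None)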
11 changes: 11 additions & 0 deletions test/components/audio/test_whisper_local.py
@@ -72,6 +72,17 @@ def test_from_dict(self):
         assert transcriber.whisper_params == {}
         assert transcriber._model is None

+    def test_from_dict_none_device(self):
+        data = {
+            "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
+            "init_parameters": {"model": "tiny", "device": None, "whisper_params": {}},
+        }
+        transcriber = LocalWhisperTranscriber.from_dict(data)
+        assert transcriber.model == "tiny"
+        assert transcriber.device == ComponentDevice.resolve_device(None)
+        assert transcriber.whisper_params == {}
+        assert transcriber._model is None
+
     def test_warmup(self):
         with patch("haystack.components.audio.whisper_local.whisper") as mocked_whisper:
             transcriber = LocalWhisperTranscriber(model="large-v2", device=ComponentDevice.from_str("cpu"))
@@ -137,6 +137,38 @@ def test_from_dict(self):
         assert component.trust_remote_code
         assert component.meta_fields_to_embed == ["meta_field"]

+    def test_from_dict_none_device(self):
+        init_parameters = {
+            "model": "model",
+            "device": None,
+            "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
+            "prefix": "prefix",
+            "suffix": "suffix",
+            "batch_size": 64,
+            "progress_bar": False,
+            "normalize_embeddings": True,
+            "embedding_separator": " - ",
+            "meta_fields_to_embed": ["meta_field"],
+            "trust_remote_code": True,
+        }
+        component = SentenceTransformersDocumentEmbedder.from_dict(
+            {
+                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+                "init_parameters": init_parameters,
+            }
+        )
+        assert component.model == "model"
+        assert component.device == ComponentDevice.resolve_device(None)
+        assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
+        assert component.prefix == "prefix"
+        assert component.suffix == "suffix"
+        assert component.batch_size == 64
+        assert component.progress_bar is False
+        assert component.normalize_embeddings is True
+        assert component.embedding_separator == " - "
+        assert component.trust_remote_code
+        assert component.meta_fields_to_embed == ["meta_field"]
+
     @patch(
         "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
     )
@@ -122,6 +122,32 @@ def test_from_dict(self):
         assert component.normalize_embeddings is False
         assert component.trust_remote_code is False

+    def test_from_dict_none_device(self):
+        data = {
+            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+            "init_parameters": {
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "model": "model",
+                "device": None,
+                "prefix": "",
+                "suffix": "",
+                "batch_size": 32,
+                "progress_bar": True,
+                "normalize_embeddings": False,
+                "trust_remote_code": False,
+            },
+        }
+        component = SentenceTransformersTextEmbedder.from_dict(data)
+        assert component.model == "model"
+        assert component.device == ComponentDevice.resolve_device(None)
+        assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
+        assert component.prefix == ""
+        assert component.suffix == ""
+        assert component.batch_size == 32
+        assert component.progress_bar is True
+        assert component.normalize_embeddings is False
+        assert component.trust_remote_code is False
+
     @patch(
         "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
     )
14 changes: 14 additions & 0 deletions test/components/extractors/test_named_entity_extractor.py
@@ -54,3 +54,17 @@ def test_named_entity_extractor_pipeline_serde(tmp_path):
         q = Pipeline.load(f)

     assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with NamedEntityExtractor failed."
+
+
+@pytest.mark.unit
+def test_named_entity_extractor_serde_none_device():
+    extractor = NamedEntityExtractor(
+        backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None
+    )
+
+    serde_data = extractor.to_dict()
+    new_extractor = NamedEntityExtractor.from_dict(serde_data)
+
+    assert type(new_extractor._backend) == type(extractor._backend)
+    assert new_extractor._backend.model_name == extractor._backend.model_name
+    assert new_extractor._backend.device == extractor._backend.device
31 changes: 31 additions & 0 deletions test/components/rankers/test_sentence_transformers_diversity.py
@@ -113,6 +113,37 @@ def test_from_dict(self):
         assert ranker.meta_fields_to_embed == []
         assert ranker.embedding_separator == "\n"

+    def test_from_dict_none_device(self):
+        data = {
+            "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
+            "init_parameters": {
+                "model": "sentence-transformers/all-MiniLM-L6-v2",
+                "top_k": 10,
+                "device": None,
+                "similarity": "cosine",
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "query_prefix": "",
+                "document_prefix": "",
+                "query_suffix": "",
+                "document_suffix": "",
+                "meta_fields_to_embed": [],
+                "embedding_separator": "\n",
+            },
+        }
+        ranker = SentenceTransformersDiversityRanker.from_dict(data)
+
+        assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
+        assert ranker.top_k == 10
+        assert ranker.device == ComponentDevice.resolve_device(None)
+        assert ranker.similarity == "cosine"
+        assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
+        assert ranker.query_prefix == ""
+        assert ranker.document_prefix == ""
+        assert ranker.query_suffix == ""
+        assert ranker.document_suffix == ""
+        assert ranker.meta_fields_to_embed == []
+        assert ranker.embedding_separator == "\n"
+
     def test_to_dict_with_custom_init_parameters(self):
         component = SentenceTransformersDiversityRanker(
             model="sentence-transformers/msmarco-distilbert-base-v4",
13 changes: 8 additions & 5 deletions test/components/routers/test_file_router.py
@@ -50,7 +50,7 @@ def test_run_with_bytestreams(self, test_files_path):
         byte_streams = []
         for path, mime_type in zip(file_paths, mime_types):
             stream = ByteStream(path.read_bytes())
-            stream.meta["content_type"] = mime_type
+            stream.mime_type = mime_type
             byte_streams.append(stream)

         # add unclassified ByteStream
@@ -81,7 +81,7 @@ def test_run_with_bytestreams_and_file_paths(self, test_files_path):
         byte_stream_sources = []
         for path, mime_type in zip(file_paths, mime_types):
             stream = ByteStream(path.read_bytes())
-            stream.meta["content_type"] = mime_type
+            stream.mime_type = mime_type
             byte_stream_sources.append(stream)

         mixed_sources = file_paths[:2] + byte_stream_sources[2:]
@@ -165,9 +165,12 @@ def test_exact_mime_type_matching(self, mock_file):
         """
         Test if the component correctly matches mime types exactly, without regex patterns.
         """
-        txt_stream = ByteStream(io.BytesIO(b"Text file content"), meta={"content_type": "text/plain"})
-        jpg_stream = ByteStream(io.BytesIO(b"JPEG file content"), meta={"content_type": "image/jpeg"})
-        mp3_stream = ByteStream(io.BytesIO(b"MP3 file content"), meta={"content_type": "audio/mpeg"})
+        txt_stream = ByteStream(io.BytesIO(b"Text file content").read())
+        txt_stream.mime_type = "text/plain"
+        jpg_stream = ByteStream(io.BytesIO(b"JPEG file content").read())
+        jpg_stream.mime_type = "image/jpeg"
+        mp3_stream = ByteStream(io.BytesIO(b"MP3 file content").read())
+        mp3_stream.mime_type = "audio/mpeg"

         byte_streams = [txt_stream, jpg_stream, mp3_stream]
