diff --git a/integrations/mistral/examples/indexing_pipeline.py b/integrations/mistral/examples/indexing_pipeline.py index 0db67a7a1..0329fab8c 100644 --- a/integrations/mistral/examples/indexing_pipeline.py +++ b/integrations/mistral/examples/indexing_pipeline.py @@ -1,14 +1,13 @@ # To run this example, you will need an to set a `MISTRAL_API_KEY` environment variable. # This example streams chat replies to the console. -from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder - from haystack import Pipeline -from haystack.components.fetchers import LinkContentFetcher from haystack.components.converters import HTMLToDocument +from haystack.components.fetchers import LinkContentFetcher from haystack.components.preprocessors import DocumentSplitter -from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.components.writers import DocumentWriter +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder document_store = InMemoryDocumentStore() fetcher = LinkContentFetcher() @@ -30,4 +29,4 @@ indexing.connect("chunker", "embedder") indexing.connect("embedder", "writer") -indexing.run(data={"fetcher": {"urls": ["https://mistral.ai/news/la-plateforme/"]}}) \ No newline at end of file +indexing.run(data={"fetcher": {"urls": ["https://mistral.ai/news/la-plateforme/"]}}) diff --git a/integrations/mistral/examples/streaming_chat_with_rag.py b/integrations/mistral/examples/streaming_chat_with_rag.py index df24e1f6f..2e3eeee5a 100644 --- a/integrations/mistral/examples/streaming_chat_with_rag.py +++ b/integrations/mistral/examples/streaming_chat_with_rag.py @@ -1,20 +1,19 @@ # To run this example, you will need an to set a `MISTRAL_API_KEY` environment variable. # This example streams chat replies to the console. 
-from haystack_integrations.components.generators.mistral import MistralChatGenerator -from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder -from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder - from haystack import Pipeline -from haystack.dataclasses import ChatMessage -from haystack.components.generators.utils import print_streaming_chunk -from haystack.components.fetchers import LinkContentFetcher +from haystack.components.builders import DynamicChatPromptBuilder from haystack.components.converters import HTMLToDocument +from haystack.components.fetchers import LinkContentFetcher +from haystack.components.generators.utils import print_streaming_chunk from haystack.components.preprocessors import DocumentSplitter from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever -from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.components.writers import DocumentWriter -from haystack.components.builders import DynamicChatPromptBuilder +from haystack.dataclasses import ChatMessage +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder +from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder +from haystack_integrations.components.generators.mistral import MistralChatGenerator document_store = InMemoryDocumentStore() fetcher = LinkContentFetcher() @@ -58,8 +57,10 @@ question = "What are the available models?" 
-result = rag_pipeline.run({ "text_embedder": {"text": question}, - "prompt_builder": {"template_variables": {"query": question}, - "prompt_source": messages}, - "llm": {"generation_kwargs": {"max_tokens": 165}} - }) \ No newline at end of file +result = rag_pipeline.run( + { + "text_embedder": {"text": question}, + "prompt_builder": {"template_variables": {"query": question}, "prompt_source": messages}, + "llm": {"generation_kwargs": {"max_tokens": 165}}, + } +) diff --git a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/__init__.py b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/__init__.py index e69de29bb..cbbf4e60f 100644 --- a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/__init__.py +++ b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .document_embedder import MistralDocumentEmbedder +from .text_embedder import MistralTextEmbedder + +__all__ = ["MistralDocumentEmbedder", "MistralTextEmbedder"] \ No newline at end of file diff --git a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/document_embedder.py b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/document_embedder.py index 5ea979c8e..40f85b97b 100644 --- a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/document_embedder.py +++ b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/document_embedder.py @@ -4,8 +4,8 @@ from typing import List, Optional from haystack import component -from haystack.utils.auth import Secret from haystack.components.embedders import OpenAIDocumentEmbedder +from haystack.utils.auth import Secret @component @@ -48,9 +48,10 @@ def __init__( Create a MistralDocumentEmbedder component. :param api_key: The Mistral API key. 
:param model: The name of the model to use. - :param dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. + :param dimensions: Not yet supported with `mistral-embed`. Currently this model outputs `1024` dimensions. + For more info, refer to the Mistral [docs](https://docs.mistral.ai/platform/endpoints/#embedding-models) :param api_base_url: The Mistral API Base url, defaults to None. For more details, see Mistral [docs](https://docs.mistral.ai/api/). - :param organization: The Organization ID, defaults to `None`. + :param organization: Not yet supported with Mistral, defaults to `None`. :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. :param batch_size: Number of Documents to encode at once. @@ -69,4 +70,5 @@ def __init__( batch_size, progress_bar, meta_fields_to_embed, - embedding_separator) \ No newline at end of file + embedding_separator, + ) \ No newline at end of file diff --git a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/text_embedder.py b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/text_embedder.py index 2eb61c1d7..0e53e01c4 100644 --- a/integrations/mistral/src/haystack_integrations/components/embedders/mistral/text_embedder.py +++ b/integrations/mistral/src/haystack_integrations/components/embedders/mistral/text_embedder.py @@ -4,30 +4,31 @@ from typing import Optional from haystack import component -from haystack.utils.auth import Secret from haystack.components.embedders import OpenAITextEmbedder +from haystack.utils.auth import Secret @component class MistralTextEmbedder(OpenAITextEmbedder): """ - A component for embedding strings using Mistral models. + A component for embedding strings using Mistral models. 
- Usage example: - ```python - from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder + Usage example: + ```python + from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder - text_to_embed = "I love pizza!" + text_to_embed = "I love pizza!" - text_embedder = MistralTextEmbedder() + text_embedder = MistralTextEmbedder() - print(text_embedder.run(text_to_embed)) + print(text_embedder.run(text_to_embed)) - # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], - # 'meta': {'model': 'text-embedding-ada-002-v2', - # 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}} - ``` + # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], + # 'meta': {'model': 'mistral-embed', + # 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}} + ``` """ + def __init__( self, api_key: Secret = Secret.from_env_var("MISTRAL_API_KEY"), @@ -44,15 +45,12 @@ def __init__( :param api_key: The Misttal API key. :param model: The name of the Mistral embedding models to be used. :param dimensions: Not yet supported with Mistral embedding models - :param organization: The Organization ID, defaults to `None`. - :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`. For more details, see Mistral [docs](https://docs.mistral.ai/api/). + :param organization: The Organization ID, defaults to `None`. + :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`. + For more details, see Mistral [docs](https://docs.mistral.ai/api/). 
:param prefix: Not yet supported with Mistral embedding models :param suffix: Not yet supported with Mistral embedding models """ - super(MistralTextEmbedder, self).__init__(api_key, - model, - dimensions, - api_base_url, - organization, - prefix, - suffix) \ No newline at end of file + super(MistralTextEmbedder, self).__init__( + api_key, model, dimensions, api_base_url, organization, prefix, suffix + ) diff --git a/integrations/mistral/src/haystack_integrations/components/generators/mistral/__init__.py b/integrations/mistral/src/haystack_integrations/components/generators/mistral/__init__.py index 831b5991e..a6320494a 100644 --- a/integrations/mistral/src/haystack_integrations/components/generators/mistral/__init__.py +++ b/integrations/mistral/src/haystack_integrations/components/generators/mistral/__init__.py @@ -1,3 +1,3 @@ from .chat.chat_generator import MistralChatGenerator -__all__ = ["MistralChatGenerator"] \ No newline at end of file +__all__ = ["MistralChatGenerator"] diff --git a/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py b/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py index 8a2ec1188..f2a941733 100644 --- a/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py +++ b/integrations/mistral/src/haystack_integrations/components/generators/mistral/chat/chat_generator.py @@ -1,18 +1,18 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -import asyncio from typing import Any, Callable, Dict, Optional from haystack import component -from haystack.dataclasses import StreamingChunk, ChatMessage -from haystack.utils.auth import Secret from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret + @component class MistralChatGenerator(OpenAIChatGenerator): """ - Enables 
text generation using Mistral's large language models (LLMs). + Enables text generation using Mistral's large language models (LLMs). Currently supports `mistral-tiny`, `mistral-small` and `mistral-medium` models accessed through the chat completions API endpoint. @@ -20,7 +20,7 @@ class MistralChatGenerator(OpenAIChatGenerator): directly to this component via the `**generation_kwargs` parameter in __init__ or the `**generation_kwargs` parameter in `run` method. - For more details on the parameters supported by the Mistral API, refer to the + For more details on the parameters supported by the Mistral API, refer to the [Mistral API Docs](https://docs.mistral.ai/api/). ```python @@ -36,7 +36,7 @@ class MistralChatGenerator(OpenAIChatGenerator): >>{'replies': [ChatMessage(content='Natural Language Processing (NLP) is a branch of artificial intelligence >>that focuses on enabling computers to understand, interpret, and generate human language in a way that is >>meaningful and useful.', role=, name=None, - >>meta={'model': 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', + >>meta={'model': 'mistral-tiny', 'index': 0, 'finish_reason': 'stop', >>'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]} ``` @@ -48,8 +48,9 @@ class MistralChatGenerator(OpenAIChatGenerator): Input and Output Format: - **ChatMessage Format**: This component uses the ChatMessage format for structuring both input and output, - ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the - ChatMessage format can be found at: https://github.com/openai/openai-python/blob/main/chatml.md. + ensuring coherent and contextually relevant responses in chat-based text generation scenarios. + Details on the ChatMessage format can be found at: https://github.com/openai/openai-python/blob/main/chatml.md. + Note that the Mistral API does not accept `system` messages yet. You can use `user` and `assistant` messages. 
""" def __init__( @@ -69,7 +70,8 @@ def __init__( :param model: The name of the Mistral chat completion model to use. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. - :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`. For more details, see Mistral [docs](https://docs.mistral.ai/api/). + :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`. + For more details, see Mistral [docs](https://docs.mistral.ai/api/). :param organization: Not yet supported with Mistral chat completion models :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to the Mistrak endpoint. See [Mistral API docs](https://docs.mistral.ai/api/t) for @@ -81,15 +83,12 @@ def __init__( - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. - - `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent + - `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. - `stop`: One or more sequences after which the LLM should stop generating tokens. - `safe_prompt`: Whether to inject a safety prompt before all conversations. - - `random_seed`: The seed to use for random sampling. If set, different calls will generate deterministic results. - """ - super(MistralChatGenerator, self).__init__(api_key, - model, - streaming_callback, - api_base_url, - organization, - generation_kwargs) \ No newline at end of file + - `random_seed`: The seed to use for random sampling. 
+ """ + super().__init__( + api_key, model, streaming_callback, api_base_url, organization, generation_kwargs + ) diff --git a/integrations/mistral/tests/test_mistral_document_embedder.py b/integrations/mistral/tests/test_mistral_document_embedder.py index 1cdfc6272..0f66df53c 100644 --- a/integrations/mistral/tests/test_mistral_document_embedder.py +++ b/integrations/mistral/tests/test_mistral_document_embedder.py @@ -24,117 +24,109 @@ def test_init_default(self): assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" - # def test_init_with_parameters(self): - # embedder = MistralDocumentEmbedder( - # api_key=Secret.from_token("test-api-key"), - # model="embed-multilingual-v2.0", - # input_type="search_query", - # api_base_url="https://custom-api-base-url.com", - # truncate="START", - # use_async_client=True, - # max_retries=5, - # timeout=60, - # batch_size=64, - # progress_bar=False, - # meta_fields_to_embed=["test_field"], - # embedding_separator="-", - # ) - # assert embedder.api_key == Secret.from_token("test-api-key") - # assert embedder.model == "embed-multilingual-v2.0" - # assert embedder.input_type == "search_query" - # assert embedder.api_base_url == "https://custom-api-base-url.com" - # assert embedder.truncate == "START" - # assert embedder.use_async_client is True - # assert embedder.max_retries == 5 - # assert embedder.timeout == 60 - # assert embedder.batch_size == 64 - # assert embedder.progress_bar is False - # assert embedder.meta_fields_to_embed == ["test_field"] - # assert embedder.embedding_separator == "-" + def test_init_with_parameters(self): + embedder = MistralDocumentEmbedder( + api_key=Secret.from_token("test-api-key"), + model="mistral-embed-v2", + api_base_url="https://custom-api-base-url.com", + prefix="START", + suffix="END", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator="-", + ) + assert embedder.api_key == Secret.from_token("test-api-key") + 
assert embedder.model == "mistral-embed-v2" + assert embedder.api_base_url == "https://custom-api-base-url.com" + assert embedder.prefix == "START" + assert embedder.suffix == "END" + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == "-" - # def test_to_dict(self): - # embedder_component = MistralDocumentEmbedder() - # component_dict = embedder_component.to_dict() - # assert component_dict == { - # "type": "haystack_integrations.components.embedders.cohere.document_embedder.CohereDocumentEmbedder", - # "init_parameters": { - # "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, - # "model": "embed-english-v2.0", - # "input_type": "search_document", - # "api_base_url": COHERE_API_URL, - # "truncate": "END", - # "use_async_client": False, - # "max_retries": 3, - # "timeout": 120, - # "batch_size": 32, - # "progress_bar": True, - # "meta_fields_to_embed": [], - # "embedding_separator": "\n", - # }, - # } + def test_to_dict(self): + embedder_component = MistralDocumentEmbedder() + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "haystack_integrations.components.embedders.mistral.document_embedder.MistralDocumentEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"}, + "model": "mistral-embed", + "dimensions": None, + "api_base_url": "https://api.mistral.ai/v1", + "organization": None, + "prefix": "", + "suffix": "", + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + }, + } - # def test_to_dict_with_custom_init_parameters(self): - # embedder_component = CohereDocumentEmbedder( - # api_key=Secret.from_env_var("ENV_VAR", strict=False), - # model="embed-multilingual-v2.0", - # input_type="search_query", - # api_base_url="https://custom-api-base-url.com", - # 
truncate="START", - # use_async_client=True, - # max_retries=5, - # timeout=60, - # batch_size=64, - # progress_bar=False, - # meta_fields_to_embed=["text_field"], - # embedding_separator="-", - # ) - # component_dict = embedder_component.to_dict() - # assert component_dict == { - # "type": "haystack_integrations.components.embedders.cohere.document_embedder.CohereDocumentEmbedder", - # "init_parameters": { - # "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, - # "model": "embed-multilingual-v2.0", - # "input_type": "search_query", - # "api_base_url": "https://custom-api-base-url.com", - # "truncate": "START", - # "use_async_client": True, - # "max_retries": 5, - # "timeout": 60, - # "batch_size": 64, - # "progress_bar": False, - # "meta_fields_to_embed": ["text_field"], - # "embedding_separator": "-", - # }, - # } + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + monkeypatch.setenv("ENV_VAR", "test-secret-key") + embedder = MistralDocumentEmbedder( + api_key=Secret.from_env_var("ENV_VAR", strict=False), + model="mistral-embed-v2", + api_base_url="https://custom-api-base-url.com", + prefix="START", + suffix="END", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator="-", + ) + component_dict = embedder.to_dict() + assert component_dict == { + "type": "haystack_integrations.components.embedders.mistral.document_embedder.MistralDocumentEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "model": "mistral-embed-v2", + "dimensions": None, + "api_base_url": "https://custom-api-base-url.com", + "organization": None, + "prefix": "START", + "suffix": "END", + "batch_size": 64, + "progress_bar": False, + "meta_fields_to_embed": ["test_field"], + "embedding_separator": "-", + }, + } - # @pytest.mark.skipif( - # not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), - # reason="Export an env var 
called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", - # ) - # @pytest.mark.integration - # def test_run(self): - # embedder = CohereDocumentEmbedder() + @pytest.mark.skipif( + not os.environ.get("MISTRAL_API_KEY", None), + reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.", + ) + @pytest.mark.integration + def test_run(self): + embedder = MistralDocumentEmbedder() + + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] - # docs = [ - # Document(content="I love cheese", meta={"topic": "Cuisine"}), - # Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), - # ] + result = embedder.run(docs) + docs_with_embeddings = result["documents"] - # result = embedder.run(docs) - # docs_with_embeddings = result["documents"] + assert isinstance(docs_with_embeddings, list) + assert len(docs_with_embeddings) == len(docs) + for doc in docs_with_embeddings: + assert isinstance(doc.embedding, list) + assert isinstance(doc.embedding[0], float) - # assert isinstance(docs_with_embeddings, list) - # assert len(docs_with_embeddings) == len(docs) - # for doc in docs_with_embeddings: - # assert isinstance(doc.embedding, list) - # assert isinstance(doc.embedding[0], float) + def test_run_wrong_input_format(self): + embedder = MistralDocumentEmbedder(api_key=Secret.from_token("test-api-key")) - # def test_run_wrong_input_format(self): - # embedder = CohereDocumentEmbedder(api_key=Secret.from_token("test-api-key")) + match_error_msg = "OpenAIDocumentEmbedder expects a list of Documents as input.In case you want to embed a string, please use the OpenAITextEmbedder." 
- # with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): - # embedder.run(documents="text") - # with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): - # embedder.run(documents=[1, 2, 3]) + with pytest.raises(TypeError, match=match_error_msg): + embedder.run(documents="text") + with pytest.raises(TypeError, match=match_error_msg): + embedder.run(documents=[1, 2, 3]) - # assert embedder.run(documents=[]) == {"documents": [], "meta": {}} + assert embedder.run(documents=[]) == {"documents": [], "meta": {}} diff --git a/integrations/mistral/tests/test_mistral_text_embedder.py b/integrations/mistral/tests/test_mistral_text_embedder.py new file mode 100644 index 000000000..ca7d20243 --- /dev/null +++ b/integrations/mistral/tests/test_mistral_text_embedder.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +from haystack import Document +from haystack.utils import Secret +from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder + +pytestmark = pytest.mark.embedders + + +class TestMistralTextEmbedder: + def test_init_default(self): + embedder = MistralTextEmbedder() + assert embedder.api_key == Secret.from_env_var(["MISTRAL_API_KEY"]) + assert embedder.model == "mistral-embed" + assert embedder.prefix == "" + assert embedder.suffix == "" + + def test_init_with_parameters(self): + embedder = MistralTextEmbedder( + api_key=Secret.from_token("test-api-key"), + model="mistral-embed-v2", + prefix="START", + suffix="END", + ) + assert embedder.api_key == Secret.from_token("test-api-key") + assert embedder.model == "mistral-embed-v2" + assert embedder.prefix == "START" + assert embedder.suffix == "END" + + def test_to_dict(self): + embedder_component = MistralTextEmbedder() + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": 
"haystack_integrations.components.embedders.mistral.text_embedder.MistralTextEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["MISTRAL_API_KEY"], "strict": True, "type": "env_var"}, + "model": "mistral-embed", + "dimensions": None, + "organization": None, + "prefix": "", + "suffix": "", + }, + } + + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + monkeypatch.setenv("ENV_VAR", "test-secret-key") + embedder = MistralTextEmbedder( + api_key=Secret.from_env_var("ENV_VAR", strict=False), + model="mistral-embed-v2", + api_base_url="https://custom-api-base-url.com", + prefix="START", + suffix="END", + ) + component_dict = embedder.to_dict() + assert component_dict == { + "type": "haystack_integrations.components.embedders.mistral.text_embedder.MistralTextEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "model": "mistral-embed-v2", + "dimensions": None, + "organization": None, + "prefix": "START", + "suffix": "END", + }, + } + + @pytest.mark.skipif( + not os.environ.get("MISTRAL_API_KEY", None), + reason="Export an env var called MISTRAL_API_KEY containing the Mistral API key to run this test.", + ) + @pytest.mark.integration + def test_run(self): + embedder = MistralTextEmbedder() + text = "The food was delicious" + result = embedder.run(text) + assert all(isinstance(x, float) for x in result["embedding"]) + + + def test_run_wrong_input_format(self): + embedder = MistralTextEmbedder(api_key=Secret.from_token("test-api-key")) + list_integers_input = ["text_snippet_1", "text_snippet_2"] + match_error_msg = "OpenAITextEmbedder expects a string as an input.In case you want to embed a list of Documents, please use the OpenAIDocumentEmbedder." + with pytest.raises(TypeError, match=match_error_msg): + embedder.run(text=list_integers_input) \ No newline at end of file