
Commit 266f30e
adding tests and some fixes
TuanaCelik committed Feb 15, 2024
1 parent 83b97cf commit 266f30e
Showing 9 changed files with 254 additions and 168 deletions.
9 changes: 4 additions & 5 deletions integrations/mistral/examples/indexing_pipeline.py
@@ -1,14 +1,13 @@
 # To run this example, you will need to set a `MISTRAL_API_KEY` environment variable.
 # This example indexes the contents of a URL into an InMemoryDocumentStore.
 
-from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder
-
 from haystack import Pipeline
-from haystack.components.fetchers import LinkContentFetcher
 from haystack.components.converters import HTMLToDocument
+from haystack.components.fetchers import LinkContentFetcher
 from haystack.components.preprocessors import DocumentSplitter
-from haystack.document_stores.in_memory import InMemoryDocumentStore
 from haystack.components.writers import DocumentWriter
+from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder
 
 document_store = InMemoryDocumentStore()
 fetcher = LinkContentFetcher()

@@ -30,4 +29,4 @@
 indexing.connect("chunker", "embedder")
 indexing.connect("embedder", "writer")
 
-indexing.run(data={"fetcher": {"urls": ["https://mistral.ai/news/la-plateforme/"]}})
+indexing.run(data={"fetcher": {"urls": ["https://mistral.ai/news/la-plateforme/"]}})
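Once the run completes, the store can be sanity-checked directly. A minimal sketch, assuming the script above has just been executed in the same session (`count_documents()` and `filter_documents()` are standard `InMemoryDocumentStore` methods):

```python
# Quick verification that indexing worked: count the written chunks and
# confirm the first one carries an embedding.
print(document_store.count_documents())
first_chunk = document_store.filter_documents()[0]
print(len(first_chunk.embedding))  # 1024 for mistral-embed, per the embedder docstring below
```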
29 changes: 15 additions & 14 deletions integrations/mistral/examples/streaming_chat_with_rag.py
@@ -1,20 +1,19 @@
 # To run this example, you will need to set a `MISTRAL_API_KEY` environment variable.
 # This example streams chat replies to the console.
 
-from haystack_integrations.components.generators.mistral import MistralChatGenerator
-from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder
-from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder
-
 from haystack import Pipeline
-from haystack.dataclasses import ChatMessage
-from haystack.components.generators.utils import print_streaming_chunk
-from haystack.components.fetchers import LinkContentFetcher
+from haystack.components.builders import DynamicChatPromptBuilder
 from haystack.components.converters import HTMLToDocument
+from haystack.components.fetchers import LinkContentFetcher
+from haystack.components.generators.utils import print_streaming_chunk
 from haystack.components.preprocessors import DocumentSplitter
 from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
-from haystack.document_stores.in_memory import InMemoryDocumentStore
 from haystack.components.writers import DocumentWriter
-from haystack.components.builders import DynamicChatPromptBuilder
+from haystack.dataclasses import ChatMessage
+from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder
+from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder
+from haystack_integrations.components.generators.mistral import MistralChatGenerator
 
 document_store = InMemoryDocumentStore()
 fetcher = LinkContentFetcher()
@@ -58,8 +57,10 @@
 
 question = "What are the available models?"
 
-result = rag_pipeline.run({ "text_embedder": {"text": question},
-                            "prompt_builder": {"template_variables": {"query": question},
-                                               "prompt_source": messages},
-                            "llm": {"generation_kwargs": {"max_tokens": 165}}
-                          })
+result = rag_pipeline.run(
+    {
+        "text_embedder": {"text": question},
+        "prompt_builder": {"template_variables": {"query": question}, "prompt_source": messages},
+        "llm": {"generation_kwargs": {"max_tokens": 165}},
+    }
+)
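The `messages` object passed as `prompt_source` is defined in the collapsed portion of the example. A plausible shape, given that `DynamicChatPromptBuilder` fills Jinja2 template variables into `ChatMessage` templates and that the Mistral API does not yet accept `system` messages (see the generator docstring below), might be:

```python
from haystack.dataclasses import ChatMessage

# Hypothetical prompt template: the example's actual `messages` is collapsed
# in the diff above. The instruction lives in a `user` message because Mistral
# does not accept `system` messages yet; `documents` and `query` are filled in
# by DynamicChatPromptBuilder at run time.
messages = [
    ChatMessage.from_user(
        "Answer the question using the provided context.\n"
        "Context:\n"
        "{% for document in documents %}{{ document.content }}{% endfor %}\n"
        "Question: {{query}}"
    )
]
```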
haystack_integrations/components/embedders/mistral/__init__.py (new file)
@@ -0,0 +1,7 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
+#
+# SPDX-License-Identifier: Apache-2.0
+from .document_embedder import MistralDocumentEmbedder
+from .text_embedder import MistralTextEmbedder
+
+__all__ = ["MistralDocumentEmbedder", "MistralTextEmbedder"]
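With these re-exports in place, both embedders can be imported from the package root instead of their defining modules:

```python
# Shorter import path enabled by the new __init__.py above; the
# module-level imports used in the examples continue to work as well.
from haystack_integrations.components.embedders.mistral import (
    MistralDocumentEmbedder,
    MistralTextEmbedder,
)
```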
haystack_integrations/components/embedders/mistral/document_embedder.py
@@ -4,8 +4,8 @@
 from typing import List, Optional
 
 from haystack import component
-from haystack.utils.auth import Secret
 from haystack.components.embedders import OpenAIDocumentEmbedder
+from haystack.utils.auth import Secret
 
 
 @component
@@ -48,9 +48,10 @@ def __init__(
         Create a MistralDocumentEmbedder component.
         :param api_key: The Mistral API key.
         :param model: The name of the model to use.
-        :param dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
+        :param dimensions: Not yet supported with `mistral-embed`. Currently this model outputs `1024` dimensions.
+            For more info, refer to the Mistral [docs](https://docs.mistral.ai/platform/endpoints/#embedding-models)
         :param api_base_url: The Mistral API Base url, defaults to None. For more details, see Mistral [docs](https://docs.mistral.ai/api/).
-        :param organization: The Organization ID, defaults to `None`.
+        :param organization: Not yet supported with Mistral, defaults to `None`.
         :param prefix: A string to add to the beginning of each text.
         :param suffix: A string to add to the end of each text.
         :param batch_size: Number of Documents to encode at once.
@@ -69,4 +70,5 @@ def __init__(
             batch_size,
             progress_bar,
             meta_fields_to_embed,
-            embedding_separator)
+            embedding_separator,
+        )
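A minimal standalone sketch of the document embedder, assuming `MISTRAL_API_KEY` is set in the environment as in the examples above (the output fields follow the `OpenAIDocumentEmbedder` base class):

```python
from haystack import Document
from haystack_integrations.components.embedders.mistral import MistralDocumentEmbedder

embedder = MistralDocumentEmbedder()  # reads MISTRAL_API_KEY from the environment by default

docs = [Document(content="I love pizza!")]
result = embedder.run(docs)

# Each returned Document now carries an embedding; per the docstring above,
# mistral-embed produces 1024-dimensional vectors.
print(len(result["documents"][0].embedding))
```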
haystack_integrations/components/embedders/mistral/text_embedder.py
@@ -4,30 +4,31 @@
 from typing import Optional
 
 from haystack import component
-from haystack.utils.auth import Secret
 from haystack.components.embedders import OpenAITextEmbedder
+from haystack.utils.auth import Secret
 
 
 @component
 class MistralTextEmbedder(OpenAITextEmbedder):
     """
-    A component for embedding strings using Mistral models.
+    A component for embedding strings using Mistral models.
-    Usage example:
-    ```python
-    from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder
+    Usage example:
+    ```python
+    from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder
 
-    text_to_embed = "I love pizza!"
+    text_to_embed = "I love pizza!"
 
-    text_embedder = MistralTextEmbedder()
+    text_embedder = MistralTextEmbedder()
 
-    print(text_embedder.run(text_to_embed))
+    print(text_embedder.run(text_to_embed))
 
-    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
-    # 'meta': {'model': 'text-embedding-ada-002-v2',
-    #          'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
-    ```
+    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
+    # 'meta': {'model': 'text-embedding-ada-002-v2',
+    #          'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
+    ```
     """
 
     def __init__(
         self,
         api_key: Secret = Secret.from_env_var("MISTRAL_API_KEY"),
@@ -44,15 +45,12 @@ def __init__(
         :param api_key: The Mistral API key.
         :param model: The name of the Mistral embedding models to be used.
         :param dimensions: Not yet supported with Mistral embedding models
-        :param organization: The Organization ID, defaults to `None`.
-        :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`. For more details, see Mistral [docs](https://docs.mistral.ai/api/).
+        :param organization: The Organization ID, defaults to `None`.
+        :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`.
+            For more details, see Mistral [docs](https://docs.mistral.ai/api/).
         :param prefix: Not yet supported with Mistral embedding models
         :param suffix: Not yet supported with Mistral embedding models
         """
-        super(MistralTextEmbedder, self).__init__(api_key,
-                                                  model,
-                                                  dimensions,
-                                                  api_base_url,
-                                                  organization,
-                                                  prefix,
-                                                  suffix)
+        super(MistralTextEmbedder, self).__init__(
+            api_key, model, dimensions, api_base_url, organization, prefix, suffix
+        )
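Both embedders (and the chat generator below) use the same thin-wrapper pattern: subclass the corresponding OpenAI component and forward the constructor arguments, overriding only the defaults that point at Mistral. A simplified sketch of the idea; the class name here is illustrative, not part of the integration:

```python
from haystack import component
from haystack.components.embedders import OpenAITextEmbedder
from haystack.utils.auth import Secret


@component
class ExampleMistralTextEmbedder(OpenAITextEmbedder):
    """Illustrative wrapper: all embedding logic lives in OpenAITextEmbedder."""

    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("MISTRAL_API_KEY"),
        model: str = "mistral-embed",
        api_base_url: str = "https://api.mistral.ai/v1",
    ):
        # Only the credential source, model name, and endpoint change;
        # request handling and the output format are inherited unchanged.
        super().__init__(api_key=api_key, model=model, api_base_url=api_base_url)
```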
haystack_integrations/components/generators/mistral/__init__.py
@@ -1,3 +1,3 @@
 from .chat.chat_generator import MistralChatGenerator
 
-__all__ = ["MistralChatGenerator"]
+__all__ = ["MistralChatGenerator"]
haystack_integrations/components/generators/mistral/chat/chat_generator.py
@@ -1,26 +1,26 @@
 # SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
 #
 # SPDX-License-Identifier: Apache-2.0
-import asyncio
 from typing import Any, Callable, Dict, Optional
 
 from haystack import component
-from haystack.dataclasses import StreamingChunk, ChatMessage
-from haystack.utils.auth import Secret
 from haystack.components.generators.chat import OpenAIChatGenerator
+from haystack.dataclasses import StreamingChunk
+from haystack.utils.auth import Secret
 
 
 @component
 class MistralChatGenerator(OpenAIChatGenerator):
     """
-    Enables text generation using Mistral's large language models (LLMs).
+    Enables text generation using Mistral's large language models (LLMs).
     Currently supports `mistral-tiny`, `mistral-small` and `mistral-medium`
     models accessed through the chat completions API endpoint.
 
     Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method
     directly to this component via the `**generation_kwargs` parameter in __init__ or the `**generation_kwargs`
     parameter in `run` method.
-    For more details on the parameters supported by the Mistral API, refer to the
+
+    For more details on the parameters supported by the Mistral API, refer to the
     [Mistral API Docs](https://docs.mistral.ai/api/).
 
     ```python
@@ -36,7 +36,7 @@ class MistralChatGenerator(OpenAIChatGenerator):
 
     >>{'replies': [ChatMessage(content='Natural Language Processing (NLP) is a branch of artificial intelligence
     >>that focuses on enabling computers to understand, interpret, and generate human language in a way that is
     >>meaningful and useful.', role=<ChatRole.ASSISTANT: 'assistant'>, name=None,
-    >>meta={'model': 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop',
+    >>meta={'model': 'mistral-tiny', 'index': 0, 'finish_reason': 'stop',
     >>'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]}
     ```
@@ -48,8 +48,9 @@ class MistralChatGenerator(OpenAIChatGenerator):
 
     Input and Output Format:
     - **ChatMessage Format**: This component uses the ChatMessage format for structuring both input and output,
-      ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the
-      ChatMessage format can be found at: https://github.com/openai/openai-python/blob/main/chatml.md.
+      ensuring coherent and contextually relevant responses in chat-based text generation scenarios.
+      Details on the ChatMessage format can be found at: https://github.com/openai/openai-python/blob/main/chatml.md.
+      Note that the Mistral API does not accept `system` messages yet. You can use `user` and `assistant` messages.
     """
 
     def __init__(
@@ -69,7 +70,8 @@ def __init__(
         :param model: The name of the Mistral chat completion model to use.
         :param streaming_callback: A callback function that is called when a new token is received from the stream.
             The callback function accepts StreamingChunk as an argument.
-        :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`. For more details, see Mistral [docs](https://docs.mistral.ai/api/).
+        :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`.
+            For more details, see Mistral [docs](https://docs.mistral.ai/api/).
         :param organization: Not yet supported with Mistral chat completion models
         :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
             the Mistral endpoint. See [Mistral API docs](https://docs.mistral.ai/api/) for
@@ -81,15 +83,12 @@ def __init__(
             - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
               considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
               comprising the top 10% probability mass are considered.
-            - `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent
+            - `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent
               events as they become available, with the stream terminated by a data: [DONE] message.
             - `stop`: One or more sequences after which the LLM should stop generating tokens.
             - `safe_prompt`: Whether to inject a safety prompt before all conversations.
-            - `random_seed`: The seed to use for random sampling. If set, different calls will generate deterministic results.
-        """
-        super(MistralChatGenerator, self).__init__(api_key,
-                                                   model,
-                                                   streaming_callback,
-                                                   api_base_url,
-                                                   organization,
-                                                   generation_kwargs)
+            - `random_seed`: The seed to use for random sampling.
+        """
+        super().__init__(
+            api_key, model, streaming_callback, api_base_url, organization, generation_kwargs
+        )
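A minimal streaming sketch of the generator in use, based on the docstring above (the diff does not show the default model, so none is assumed here):

```python
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.mistral import MistralChatGenerator

# Streams tokens to the console as they arrive. Per the docstring above, the
# Mistral API does not accept `system` messages yet, so the prompt goes in a
# `user` message.
generator = MistralChatGenerator(streaming_callback=print_streaming_chunk)
result = generator.run(
    messages=[ChatMessage.from_user("What's Natural Language Processing?")],
    generation_kwargs={"max_tokens": 165, "temperature": 0.5},
)
print(result["replies"][0].meta)  # e.g. {'model': 'mistral-tiny', 'finish_reason': 'stop', ...}
```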
[Diffs for the remaining 2 changed files did not load and are not shown.]
