Commit 6c4fd71

…tegrations into weaviate-client-v4

hsm207 committed Feb 28, 2024
2 parents 5b9d2b2 + d31442d
Showing 20 changed files with 654 additions and 198 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/ollama.yml
@@ -22,6 +22,7 @@ env:
PYTHONUNBUFFERED: "1"
FORCE_COLOR: "1"
LLM_FOR_TESTS: "orca-mini"
+ EMBEDDER_FOR_TESTS: "nomic-embed-text"

jobs:
run:
@@ -55,7 +56,10 @@ jobs:
run: hatch run lint:all

- name: Pull the LLM in the Ollama service
run: docker exec ollama ollama pull ${{ env.LLM_FOR_TESTS }}

+ - name: Pull the Embedding Model in the Ollama service
+   run: docker exec ollama ollama pull ${{ env.EMBEDDER_FOR_TESTS }}

- name: Generate docs
if: matrix.python-version == '3.9' && runner.os == 'Linux'
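
To reproduce these CI steps against a local Ollama server, the same pulls can be driven over Ollama's REST API. A minimal sketch in Python, assuming the default `localhost:11434` endpoint and the standard `/api/pull` and `/api/tags` routes; the model names mirror `LLM_FOR_TESTS` and `EMBEDDER_FOR_TESTS` above:

```python
import requests

OLLAMA_URL = "http://localhost:11434"  # assumption: default Ollama port, as in the CI service container

# Pull the LLM and the embedding model used by the tests, mirroring the two workflow steps above.
for model in ("orca-mini", "nomic-embed-text"):
    resp = requests.post(f"{OLLAMA_URL}/api/pull", json={"name": model, "stream": False}, timeout=600)
    resp.raise_for_status()

# Confirm both models are now available locally.
tags = requests.get(f"{OLLAMA_URL}/api/tags", timeout=10).json()
print([m["name"] for m in tags["models"]])
```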
6 changes: 3 additions & 3 deletions .github/workflows/optimum.yml
@@ -52,9 +52,9 @@ jobs:
if: matrix.python-version == '3.9' && runner.os == 'Linux'
run: hatch run lint:all

- - name: Generate docs
-   if: matrix.python-version == '3.9' && runner.os == 'Linux'
-   run: hatch run docs
+ # - name: Generate docs
+ #   if: matrix.python-version == '3.9' && runner.os == 'Linux'
+ #   run: hatch run docs

- name: Run tests
run: hatch run cov
@@ -43,17 +43,27 @@ def __init__(
        embedding_separator: str = "\n",
    ):
        """
-       Create a MistralDocumentEmbedder component.
-       :param api_key: The Mistral API key.
-       :param model: The name of the model to use.
-       :param api_base_url: The Mistral API Base url, defaults to None. For more details, see Mistral [docs](https://docs.mistral.ai/api/).
-       :param prefix: A string to add to the beginning of each text.
-       :param suffix: A string to add to the end of each text.
-       :param batch_size: Number of Documents to encode at once.
-       :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments
-           to keep the logs clean.
-       :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
-       :param embedding_separator: Separator used to concatenate the meta fields to the Document text.
+       Creates a MistralDocumentEmbedder component.
+
+       :param api_key:
+           The Mistral API key.
+       :param model:
+           The name of the model to use.
+       :param api_base_url:
+           The Mistral API Base url. For more details, see Mistral [docs](https://docs.mistral.ai/api/).
+       :param prefix:
+           A string to add to the beginning of each text.
+       :param suffix:
+           A string to add to the end of each text.
+       :param batch_size:
+           Number of Documents to encode at once.
+       :param progress_bar:
+           Whether to show a progress bar or not. Can be helpful to disable in production deployments to keep
+           the logs clean.
+       :param meta_fields_to_embed:
+           List of meta fields that should be embedded along with the Document text.
+       :param embedding_separator:
+           Separator used to concatenate the meta fields to the Document text.
        """
super(MistralDocumentEmbedder, self).__init__( # noqa: UP008
api_key=api_key,
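
Since the new docstring carries no usage snippet, here is a minimal sketch of the component in action. The module path mirrors the text-embedder import shown in the next file, and a `MISTRAL_API_KEY` environment variable is assumed for the default `api_key` secret:

```python
from haystack import Document
from haystack_integrations.components.embedders.mistral.document_embedder import MistralDocumentEmbedder

# Assumes MISTRAL_API_KEY is set in the environment.
embedder = MistralDocumentEmbedder(model="mistral-embed")

docs = [Document(content="I love pizza!", meta={"topic": "food"})]
result = embedder.run(documents=docs)

print(result["documents"][0].embedding[:3])  # first few floats of the embedding
print(result["meta"])                        # model name and token usage
```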
@@ -11,22 +11,21 @@
@component
class MistralTextEmbedder(OpenAITextEmbedder):
    """
    A component for embedding strings using Mistral models.

    Usage example:
    ```python
    from haystack_integrations.components.embedders.mistral.text_embedder import MistralTextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = MistralTextEmbedder()
    print(text_embedder.run(text_to_embed))

-   # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
-   # 'meta': {'model': 'text-embedding-ada-002-v2',
-   #          'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
+   # output:
+   # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
+   # 'meta': {'model': 'mistral-embed',
+   #          'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
    ```
    """

def __init__(
@@ -38,14 +37,19 @@ def __init__(
suffix: str = "",
):
"""
-       Create an MistralTextEmbedder component.
-       :param api_key: The Misttal API key.
-       :param model: The name of the Mistral embedding models to be used.
-       :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`.
-           For more details, see Mistral [docs](https://docs.mistral.ai/api/).
-       :param prefix: A string to add to the beginning of each text.
-       :param suffix: A string to add to the end of each text.
+       Creates an MistralTextEmbedder component.
+
+       :param api_key:
+           The Mistral API key.
+       :param model:
+           The name of the Mistral embedding model to be used.
+       :param api_base_url:
+           The Mistral API Base url.
+           For more details, see Mistral [docs](https://docs.mistral.ai/api/).
+       :param prefix:
+           A string to add to the beginning of each text.
+       :param suffix:
+           A string to add to the end of each text.
"""
super(MistralTextEmbedder, self).__init__( # noqa: UP008
api_key=api_key,
@@ -12,17 +12,27 @@
@component
class MistralChatGenerator(OpenAIChatGenerator):
"""
-   Enables text generation using Mistral's large language models (LLMs).
-   Currently supports `mistral-tiny`, `mistral-small` and `mistral-medium`
-   models accessed through the chat completions API endpoint.
+   Enables text generation using Mistral AI generative models.
+   For supported models, see [Mistral AI docs](https://docs.mistral.ai/platform/endpoints/#operation/listModels).

-   Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method
-   directly to this component via the `**generation_kwargs` parameter in __init__ or the `**generation_kwargs`
+   Users can pass any text generation parameters valid for the Mistral Chat Completion API
+   directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
    parameter in `run` method.

+   Key Features and Compatibility:
+   - **Primary Compatibility**: Designed to work seamlessly with the Mistral API Chat Completion endpoint.
+   - **Streaming Support**: Supports streaming responses from the Mistral API Chat Completion endpoint.
+   - **Customizability**: Supports all parameters supported by the Mistral API Chat Completion endpoint.
+
+   This component uses the ChatMessage format for structuring both input and output,
+   ensuring coherent and contextually relevant responses in chat-based text generation scenarios.
+   Details on the ChatMessage format can be found in the
+   [Haystack docs](https://docs.haystack.deepset.ai/v2.0/docs/data-classes#chatmessage)
+
+   For more details on the parameters supported by the Mistral API, refer to the
+   [Mistral API Docs](https://docs.mistral.ai/api/).
Usage example:
```python
from haystack_integrations.components.generators.mistral import MistralChatGenerator
from haystack.dataclasses import ChatMessage
@@ -38,19 +48,7 @@ class MistralChatGenerator(OpenAIChatGenerator):
>>meaningful and useful.', role=<ChatRole.ASSISTANT: 'assistant'>, name=None,
>>meta={'model': 'mistral-tiny', 'index': 0, 'finish_reason': 'stop',
>>'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]}
```
-   Key Features and Compatibility:
-   - **Primary Compatibility**: Designed to work seamlessly with the Mistral API Chat Completion endpoint.
-   - **Streaming Support**: Supports streaming responses from the Mistral API Chat Completion endpoint.
-   - **Customizability**: Supports all parameters supported by the Mistral API Chat Completion endpoint.
-
-   Input and Output Format:
-   - **ChatMessage Format**: This component uses the ChatMessage format for structuring both input and output,
-     ensuring coherent and contextually relevant responses in chat-based text generation scenarios.
-     Details on the ChatMessage format can be found at: https://github.com/openai/openai-python/blob/main/chatml.md.
-
-   Note that the Mistral API does not accept `system` messages yet. You can use `user` and `assistant` messages.
"""

def __init__(
@@ -65,15 +63,19 @@ def __init__(
Creates an instance of MistralChatGenerator. Unless specified otherwise in the `model`, this is for Mistral's
`mistral-tiny` model.
-       :param api_key: The Mistral API key.
-       :param model: The name of the Mistral chat completion model to use.
-       :param streaming_callback: A callback function that is called when a new token is received from the stream.
+       :param api_key:
+           The Mistral API key.
+       :param model:
+           The name of the Mistral chat completion model to use.
+       :param streaming_callback:
+           A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
-       :param api_base_url: The Mistral API Base url, defaults to `https://api.mistral.ai/v1`.
-           For more details, see Mistral [docs](https://docs.mistral.ai/api/).
-       :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
-           the Mistrak endpoint. See [Mistral API docs](https://docs.mistral.ai/api/t) for
-           more details.
+       :param api_base_url:
+           The Mistral API Base url.
+           For more details, see Mistral [docs](https://docs.mistral.ai/api/).
+       :param generation_kwargs:
+           Other parameters to use for the model. These parameters are all sent directly to
+           the Mistral endpoint. See [Mistral API docs](https://docs.mistral.ai/api/) for more details.
Some of the supported parameters:
- `max_tokens`: The maximum number of tokens the output text can have.
- `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
@@ -83,7 +85,6 @@ def __init__(
            comprising the top 10% probability mass are considered.
        - `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent
            events as they become available, with the stream terminated by a data: [DONE] message.
-       - `stop`: One or more sequences after which the LLM should stop generating tokens.
        - `safe_prompt`: Whether to inject a safety prompt before all conversations.
        - `random_seed`: The seed to use for random sampling.
        """
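
To complement the truncated docstring example above, a short sketch of calling the generator with `generation_kwargs`; it assumes a `MISTRAL_API_KEY` environment variable and uses only parameter names listed in the docstring:

```python
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.mistral import MistralChatGenerator

# Assumes MISTRAL_API_KEY is set in the environment.
generator = MistralChatGenerator(
    model="mistral-tiny",
    generation_kwargs={"max_tokens": 64, "temperature": 0.2, "safe_prompt": True},
)

messages = [ChatMessage.from_user("What's Natural Language Processing?")]
reply = generator.run(messages=messages)

print(reply["replies"][0].content)  # the assistant's answer
print(reply["replies"][0].meta)     # model, finish_reason, token usage
```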
6 changes: 4 additions & 2 deletions integrations/ollama/pydoc/config.yml
@@ -3,7 +3,9 @@ loaders:
search_path: [../src]
modules: [
"haystack_integrations.components.generators.ollama.generator",
"haystack_integrations.components.generators.ollama.chat.chat_generator"
"haystack_integrations.components.generators.ollama.chat.chat_generator",
"haystack_integrations.components.embedders.ollama.document_embedder",
"haystack_integrations.components.embedders.ollama.text_embedder",
]
ignore_when_discovered: ["__init__"]
processors:
@@ -26,4 +28,4 @@ renderer:
descriptive_module_title: true
add_method_class_prefix: true
add_member_class_prefix: false
filename: _readme_ollama.md
@@ -9,7 +9,7 @@
class OllamaDocumentEmbedder:
def __init__(
self,
model: str = "orca-mini",
model: str = "nomic-embed-text",
url: str = "http://localhost:11434/api/embeddings",
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: int = 120,
@@ -21,7 +21,7 @@ def __init__(
):
"""
:param model: The name of the model to use. The model should be available in the running Ollama instance.
Default is "orca-mini".
Default is "nomic-embed-text". "https://ollama.com/library/nomic-embed-text"
:param url: The URL of the chat endpoint of a running Ollama instance.
Default is "http://localhost:11434/api/embeddings".
:param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
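
A small sketch of the new default in use, assuming a local Ollama server that has already pulled `nomic-embed-text` (for example via `ollama pull nomic-embed-text`):

```python
from haystack import Document
from haystack_integrations.components.embedders.ollama.document_embedder import OllamaDocumentEmbedder

# The model now defaults to "nomic-embed-text", so no argument is needed.
embedder = OllamaDocumentEmbedder()

result = embedder.run([Document(content="This is a document containing some text.")])
print(len(result["documents"][0].embedding))  # embedding dimensionality
```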
@@ -8,14 +8,14 @@
class OllamaTextEmbedder:
def __init__(
self,
model: str = "orca-mini",
model: str = "nomic-embed-text",
url: str = "http://localhost:11434/api/embeddings",
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: int = 120,
):
"""
:param model: The name of the model to use. The model should be available in the running Ollama instance.
Default is "orca-mini".
Default is "nomic-embed-text". "https://ollama.com/library/nomic-embed-text"
:param url: The URL of the chat endpoint of a running Ollama instance.
Default is "http://localhost:11434/api/embeddings".
:param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
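
The matching text-embedder sketch, under the same assumption of a running local Ollama instance with the model pulled:

```python
from haystack_integrations.components.embedders.ollama.text_embedder import OllamaTextEmbedder

embedder = OllamaTextEmbedder()  # defaults to "nomic-embed-text"

reply = embedder.run(text="hello")
print(reply["meta"]["model"])   # "nomic-embed-text"
print(len(reply["embedding"]))  # embedding dimensionality
```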
12 changes: 6 additions & 6 deletions integrations/ollama/tests/test_document_embedder.py
@@ -11,11 +11,11 @@ def test_init_defaults(self):
assert embedder.timeout == 120
assert embedder.generation_kwargs == {}
assert embedder.url == "http://localhost:11434/api/embeddings"
assert embedder.model == "orca-mini"
assert embedder.model == "nomic-embed-text"

def test_init(self):
embedder = OllamaDocumentEmbedder(
model="orca-mini",
model="nomic-embed-text",
url="http://my-custom-endpoint:11434/api/embeddings",
generation_kwargs={"temperature": 0.5},
timeout=3000,
@@ -24,7 +24,7 @@ def test_init(self):
assert embedder.timeout == 3000
assert embedder.generation_kwargs == {"temperature": 0.5}
assert embedder.url == "http://my-custom-endpoint:11434/api/embeddings"
assert embedder.model == "orca-mini"
assert embedder.model == "nomic-embed-text"

@pytest.mark.integration
def test_model_not_found(self):
@@ -35,17 +35,17 @@ def test_model_not_found(self):

@pytest.mark.integration
def import_text_in_embedder(self):
embedder = OllamaDocumentEmbedder(model="orca-mini")
embedder = OllamaDocumentEmbedder(model="nomic-embed-text")

with pytest.raises(TypeError):
embedder.run("This is a text string. This should not work.")

@pytest.mark.integration
def test_run(self):
embedder = OllamaDocumentEmbedder(model="orca-mini")
embedder = OllamaDocumentEmbedder(model="nomic-embed-text")
list_of_docs = [Document(content="This is a document containing some text.")]
reply = embedder.run(list_of_docs)

assert isinstance(reply, dict)
assert all(isinstance(element, float) for element in reply["documents"][0].embedding)
assert reply["meta"]["model"] == "orca-mini"
assert reply["meta"]["model"] == "nomic-embed-text"
6 changes: 3 additions & 3 deletions integrations/ollama/tests/test_text_embedder.py
@@ -10,7 +10,7 @@ def test_init_defaults(self):
assert embedder.timeout == 120
assert embedder.generation_kwargs == {}
assert embedder.url == "http://localhost:11434/api/embeddings"
assert embedder.model == "orca-mini"
assert embedder.model == "nomic-embed-text"

def test_init(self):
embedder = OllamaTextEmbedder(
@@ -34,10 +34,10 @@ def test_model_not_found(self):

@pytest.mark.integration
def test_run(self):
embedder = OllamaTextEmbedder(model="orca-mini")
embedder = OllamaTextEmbedder(model="nomic-embed-text")

reply = embedder.run("hello")

assert isinstance(reply, dict)
assert all(isinstance(element, float) for element in reply["embedding"])
assert reply["meta"]["model"] == "orca-mini"
assert reply["meta"]["model"] == "nomic-embed-text"
2 changes: 2 additions & 0 deletions integrations/optimum/pydoc/config.yml
@@ -6,6 +6,8 @@ loaders:
"haystack_integrations.components.embedders.optimum.optimum_document_embedder",
"haystack_integrations.components.embedders.optimum.optimum_text_embedder",
"haystack_integrations.components.embedders.optimum.pooling",
"haystack_integrations.components.embedders.optimum.optimization",
"haystack_integrations.components.embedders.optimum.quantization",
]
ignore_when_discovered: ["__init__"]
processors:
@@ -2,8 +2,18 @@
#
# SPDX-License-Identifier: Apache-2.0

+ from .optimization import OptimumEmbedderOptimizationConfig, OptimumEmbedderOptimizationMode
from .optimum_document_embedder import OptimumDocumentEmbedder
from .optimum_text_embedder import OptimumTextEmbedder
from .pooling import OptimumEmbedderPooling
+ from .quantization import OptimumEmbedderQuantizationConfig, OptimumEmbedderQuantizationMode

__all__ = ["OptimumDocumentEmbedder", "OptimumEmbedderPooling", "OptimumTextEmbedder"]
__all__ = [
"OptimumDocumentEmbedder",
"OptimumEmbedderOptimizationMode",
"OptimumEmbedderOptimizationConfig",
"OptimumEmbedderPooling",
"OptimumEmbedderQuantizationMode",
"OptimumEmbedderQuantizationConfig",
"OptimumTextEmbedder",
]
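
With the expanded `__all__`, the optimization and quantization names can be imported from the package root. A sketch of the import surface; only the names are taken from this diff, and the configs' constructor arguments are not shown here:

```python
from haystack_integrations.components.embedders.optimum import (
    OptimumDocumentEmbedder,
    OptimumEmbedderOptimizationConfig,
    OptimumEmbedderOptimizationMode,
    OptimumEmbedderPooling,
    OptimumEmbedderQuantizationConfig,
    OptimumEmbedderQuantizationMode,
    OptimumTextEmbedder,
)

# All public symbols now resolve from a single namespace, matching the updated __all__.
print(OptimumEmbedderOptimizationMode, OptimumEmbedderQuantizationMode, OptimumEmbedderPooling)
```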