Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: adapt to Ollama client 0.4.0 #1209

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions integrations/ollama/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai", "ollama"]
dependencies = ["haystack-ai", "ollama>=0.4.0"]
anakin87 marked this conversation as resolved.
Show resolved Hide resolved

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"
Expand Down Expand Up @@ -63,7 +62,7 @@ cov-retry = ["test-cov-retry", "cov-report"]
docs = ["pydoc-markdown pydoc/config.yml"]

[[tool.hatch.envs.all.matrix]]
python = ["3.8", "3.9", "3.10", "3.11", "3.12"]
python = ["3.9", "3.10", "3.11", "3.12"]


[tool.hatch.envs.lint]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .document_embedder import OllamaDocumentEmbedder
from .text_embedder import OllamaTextEmbedder

__all__ = ["OllamaTextEmbedder", "OllamaDocumentEmbedder"]
__all__ = ["OllamaDocumentEmbedder", "OllamaTextEmbedder"]
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _embed_batch(
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i] # Single batch only
result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs)
result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs).model_dump()
all_embeddings.append(result["embedding"])

meta["model"] = self.model
Expand All @@ -122,7 +122,7 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
- `documents`: Documents with embedding information attached
- `meta`: The metadata collected during the embedding process
"""
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
msg = (
"OllamaDocumentEmbedder expects a list of Documents as input."
"In case you want to embed a list of strings, please use the OllamaTextEmbedder."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None):
- `embedding`: The computed embeddings
- `meta`: The metadata collected during the embedding process
"""
result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs)
result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs).model_dump()
result["meta"] = {"model": self.model}

return result
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .chat.chat_generator import OllamaChatGenerator
from .generator import OllamaGenerator

__all__ = ["OllamaGenerator", "OllamaChatGenerator"]
__all__ = ["OllamaChatGenerator", "OllamaGenerator"]
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

from ollama import Client
from ollama import ChatResponse, Client


@component
Expand Down Expand Up @@ -111,12 +111,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
return {"role": message.role.value, "content": message.content}

def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
def _build_message_from_ollama_response(self, ollama_response: ChatResponse) -> ChatMessage:
"""
Converts the non-streaming response from the Ollama API to a ChatMessage.
"""
message = ChatMessage.from_assistant(content=ollama_response["message"]["content"])
message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
response_dict = ollama_response.model_dump()
message = ChatMessage.from_assistant(content=response_dict["message"]["content"])
message.meta.update({key: value for key, value in response_dict.items() if key != "message"})
return message

def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
Expand All @@ -133,9 +134,11 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
"""
Converts the response from the Ollama API to a StreamingChunk.
"""
content = chunk_response["message"]["content"]
meta = {key: value for key, value in chunk_response.items() if key != "message"}
meta["role"] = chunk_response["message"]["role"]
chunk_response_dict = chunk_response.model_dump()

content = chunk_response_dict["message"]["content"]
meta = {key: value for key, value in chunk_response_dict.items() if key != "message"}
meta["role"] = chunk_response_dict["message"]["role"]

chunk_message = StreamingChunk(content, meta)
return chunk_message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from haystack.dataclasses import StreamingChunk
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

from ollama import Client
from ollama import Client, GenerateResponse


@component
Expand Down Expand Up @@ -118,15 +118,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator":
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)

def _convert_to_response(self, ollama_response: Dict[str, Any]) -> Dict[str, List[Any]]:
def _convert_to_response(self, ollama_response: GenerateResponse) -> Dict[str, List[Any]]:
"""
Converts a response from the Ollama API to the required Haystack format.
"""
reply = ollama_response.response
meta = {key: value for key, value in ollama_response.model_dump().items() if key != "response"}

replies = [ollama_response["response"]]
meta = {key: value for key, value in ollama_response.items() if key != "response"}

return {"replies": replies, "meta": [meta]}
return {"replies": [reply], "meta": [meta]}

def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
"""
Expand Down Expand Up @@ -154,8 +153,9 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
"""
Converts the response from the Ollama API to a StreamingChunk.
"""
content = chunk_response["response"]
meta = {key: value for key, value in chunk_response.items() if key != "response"}
chunk_response_dict = chunk_response.model_dump()
content = chunk_response_dict["response"]
meta = {key: value for key, value in chunk_response_dict.items() if key != "response"}

chunk_message = StreamingChunk(content, meta)
return chunk_message
Expand Down
26 changes: 13 additions & 13 deletions integrations/ollama/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ChatRole
from ollama._types import ResponseError
from ollama._types import ChatResponse, ResponseError

from haystack_integrations.components.generators.ollama import OllamaChatGenerator

Expand Down Expand Up @@ -86,18 +86,18 @@ def test_from_dict(self):
def test_build_message_from_ollama_response(self):
model = "some_model"

ollama_response = {
"model": model,
"created_at": "2023-12-12T14:13:43.416799Z",
"message": {"role": "assistant", "content": "Hello! How are you today?"},
"done": True,
"total_duration": 5191566416,
"load_duration": 2154458,
"prompt_eval_count": 26,
"prompt_eval_duration": 383809000,
"eval_count": 298,
"eval_duration": 4799921000,
}
ollama_response = ChatResponse(
model=model,
created_at="2023-12-12T14:13:43.416799Z",
message={"role": "assistant", "content": "Hello! How are you today?"},
done=True,
total_duration=5191566416,
load_duration=2154458,
prompt_eval_count=26,
prompt_eval_duration=383809000,
eval_count=298,
eval_duration=4799921000,
)

observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)

Expand Down