fix: adapt to Ollama client 0.4.0 (#1209)
* adapt to Ollama client 0.4.0

* remove explicit support for python 3.8

* fix linting
anakin87 authored Nov 22, 2024
1 parent 8d49172 commit a6a7808
Showing 8 changed files with 38 additions and 36 deletions.
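The common thread across these files: starting with ollama-python 0.4.0, client calls return typed Pydantic models (e.g. ChatResponse, GenerateResponse) instead of plain dicts, so the integration either reads attributes directly or calls .model_dump() where it still needs dict-style access. A minimal sketch of the new pattern, assuming a locally running Ollama server and an illustrative model name:

from ollama import Client, GenerateResponse

client = Client(host="http://localhost:11434")  # default local server, shown for clarity

# "llama3.2" is only an illustrative model name; any locally pulled model works.
response: GenerateResponse = client.generate(model="llama3.2", prompt="Say hello")

print(response.response)         # 0.4.0: attribute access on the typed response object
as_dict = response.model_dump()  # or convert back to a plain dict when needed
print(as_dict["response"])       # dict-style access, as the pre-0.4.0 code expected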
5 changes: 2 additions & 3 deletions integrations/ollama/pyproject.toml
@@ -19,15 +19,14 @@ classifiers = [
   "License :: OSI Approved :: Apache Software License",
   "Development Status :: 4 - Beta",
   "Programming Language :: Python",
-  "Programming Language :: Python :: 3.8",
   "Programming Language :: Python :: 3.9",
   "Programming Language :: Python :: 3.10",
   "Programming Language :: Python :: 3.11",
   "Programming Language :: Python :: 3.12",
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai", "ollama"]
+dependencies = ["haystack-ai", "ollama>=0.4.0"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"
@@ -63,7 +62,7 @@ cov-retry = ["test-cov-retry", "cov-report"]
 docs = ["pydoc-markdown pydoc/config.yml"]
 
 [[tool.hatch.envs.all.matrix]]
-python = ["3.8", "3.9", "3.10", "3.11", "3.12"]
+python = ["3.9", "3.10", "3.11", "3.12"]
 
 
 [tool.hatch.envs.lint]
haystack_integrations/components/embedders/ollama/__init__.py
@@ -1,4 +1,4 @@
 from .document_embedder import OllamaDocumentEmbedder
 from .text_embedder import OllamaTextEmbedder
 
-__all__ = ["OllamaTextEmbedder", "OllamaDocumentEmbedder"]
+__all__ = ["OllamaDocumentEmbedder", "OllamaTextEmbedder"]
haystack_integrations/components/embedders/ollama/document_embedder.py
@@ -100,7 +100,7 @@ def _embed_batch(
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = texts_to_embed[i]  # Single batch only
-            result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs)
+            result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs).model_dump()
             all_embeddings.append(result["embedding"])
 
         meta["model"] = self.model
@@ -122,7 +122,7 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
         - `documents`: Documents with embedding information attached
         - `meta`: The metadata collected during the embedding process
         """
-        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
+        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
             msg = (
                 "OllamaDocumentEmbedder expects a list of Documents as input."
                 "In case you want to embed a list of strings, please use the OllamaTextEmbedder."
haystack_integrations/components/embedders/ollama/text_embedder.py
@@ -62,7 +62,7 @@ def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         - `embedding`: The computed embeddings
         - `meta`: The metadata collected during the embedding process
         """
-        result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs)
+        result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs).model_dump()
         result["meta"] = {"model": self.model}
 
         return result
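For the embedders, the change keeps the dict-based bookkeeping by converting the typed response with .model_dump(). A minimal sketch of the call pattern, with an illustrative host and embedding model name:

from ollama import Client

client = Client(host="http://localhost:11434")  # illustrative host

# With ollama >= 0.4.0, embeddings() returns a typed response object; model_dump()
# restores the dict shape the embedders above expect, including the "embedding" key.
result = client.embeddings(model="nomic-embed-text", prompt="Hello world").model_dump()
print(len(result["embedding"]))  # length of the embedding vector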
haystack_integrations/components/generators/ollama/__init__.py
@@ -1,4 +1,4 @@
 from .chat.chat_generator import OllamaChatGenerator
 from .generator import OllamaGenerator
 
-__all__ = ["OllamaGenerator", "OllamaChatGenerator"]
+__all__ = ["OllamaChatGenerator", "OllamaGenerator"]
haystack_integrations/components/generators/ollama/chat/chat_generator.py
@@ -4,7 +4,7 @@
 from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 
-from ollama import Client
+from ollama import ChatResponse, Client
 
 
 @component
@@ -111,12 +111,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
     def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
         return {"role": message.role.value, "content": message.content}
 
-    def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
+    def _build_message_from_ollama_response(self, ollama_response: ChatResponse) -> ChatMessage:
         """
         Converts the non-streaming response from the Ollama API to a ChatMessage.
         """
-        message = ChatMessage.from_assistant(content=ollama_response["message"]["content"])
-        message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
+        response_dict = ollama_response.model_dump()
+        message = ChatMessage.from_assistant(content=response_dict["message"]["content"])
+        message.meta.update({key: value for key, value in response_dict.items() if key != "message"})
         return message
 
     def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
@@ -133,9 +134,11 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
         """
         Converts the response from the Ollama API to a StreamingChunk.
         """
-        content = chunk_response["message"]["content"]
-        meta = {key: value for key, value in chunk_response.items() if key != "message"}
-        meta["role"] = chunk_response["message"]["role"]
+        chunk_response_dict = chunk_response.model_dump()
+
+        content = chunk_response_dict["message"]["content"]
+        meta = {key: value for key, value in chunk_response_dict.items() if key != "message"}
+        meta["role"] = chunk_response_dict["message"]["role"]
 
         chunk_message = StreamingChunk(content, meta)
         return chunk_message
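In the chat generator, the response is now a ChatResponse; the code above dumps it to a dict before building the ChatMessage. A rough sketch of what such a response looks like when it comes from the client (host and model name are illustrative):

from ollama import ChatResponse, Client

client = Client(host="http://localhost:11434")  # illustrative host
response: ChatResponse = client.chat(
    model="llama3.2",  # illustrative model name
    messages=[{"role": "user", "content": "Hello!"}],
)

print(response.message.content)        # attribute access on the typed response
response_dict = response.model_dump()  # dict shape consumed by _build_message_from_ollama_response
print(response_dict["message"]["content"])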
haystack_integrations/components/generators/ollama/generator.py
@@ -4,7 +4,7 @@
 from haystack.dataclasses import StreamingChunk
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 
-from ollama import Client
+from ollama import Client, GenerateResponse
 
 
 @component
@@ -118,15 +118,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator":
         data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
         return default_from_dict(cls, data)
 
-    def _convert_to_response(self, ollama_response: Dict[str, Any]) -> Dict[str, List[Any]]:
+    def _convert_to_response(self, ollama_response: GenerateResponse) -> Dict[str, List[Any]]:
         """
         Converts a response from the Ollama API to the required Haystack format.
         """
+        reply = ollama_response.response
+        meta = {key: value for key, value in ollama_response.model_dump().items() if key != "response"}
 
-        replies = [ollama_response["response"]]
-        meta = {key: value for key, value in ollama_response.items() if key != "response"}
-
-        return {"replies": replies, "meta": [meta]}
+        return {"replies": [reply], "meta": [meta]}
 
     def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
         """
@@ -154,8 +153,9 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
         """
         Converts the response from the Ollama API to a StreamingChunk.
         """
-        content = chunk_response["response"]
-        meta = {key: value for key, value in chunk_response.items() if key != "response"}
+        chunk_response_dict = chunk_response.model_dump()
+        content = chunk_response_dict["response"]
+        meta = {key: value for key, value in chunk_response_dict.items() if key != "response"}
 
         chunk_message = StreamingChunk(content, meta)
         return chunk_message
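When streaming, each chunk is likewise a typed object in 0.4.0, which is why _build_chunk now calls model_dump() first. A minimal sketch of iterating a streaming generate call (model name is illustrative):

from ollama import Client

client = Client()  # defaults to the local Ollama server

# With stream=True the client yields one typed chunk at a time; model_dump()
# turns each chunk into the dict consumed by _build_chunk above.
for chunk in client.generate(model="llama3.2", prompt="Count to three", stream=True):
    chunk_dict = chunk.model_dump()
    print(chunk_dict["response"], end="", flush=True)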
26 changes: 13 additions & 13 deletions integrations/ollama/tests/test_chat_generator.py
@@ -4,7 +4,7 @@
 import pytest
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ChatRole
-from ollama._types import ResponseError
+from ollama._types import ChatResponse, ResponseError
 
 from haystack_integrations.components.generators.ollama import OllamaChatGenerator
 
@@ -86,18 +86,18 @@ def test_from_dict(self):
     def test_build_message_from_ollama_response(self):
         model = "some_model"
 
-        ollama_response = {
-            "model": model,
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {"role": "assistant", "content": "Hello! How are you today?"},
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+        ollama_response = ChatResponse(
+            model=model,
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={"role": "assistant", "content": "Hello! How are you today?"},
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
 
