From a6a78088a5fd53995fd743c918285244b8cdd0e4 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Fri, 22 Nov 2024 13:09:30 +0100
Subject: [PATCH] fix: adapt to Ollama client 0.4.0 (#1209)

* adapt to Ollama client 0.4.0
* remove explicit support for python 3.8
* fix linting
---
 integrations/ollama/pyproject.toml             |  5 ++--
 .../components/embedders/ollama/__init__.py    |  2 +-
 .../embedders/ollama/document_embedder.py      |  4 +--
 .../embedders/ollama/text_embedder.py          |  2 +-
 .../components/generators/ollama/__init__.py   |  2 +-
 .../generators/ollama/chat/chat_generator.py   | 17 +++++++-----
 .../components/generators/ollama/generator.py  | 16 ++++++------
 .../ollama/tests/test_chat_generator.py        | 26 +++++++++----------
 8 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml
index 598d1d214..c9fc22f3d 100644
--- a/integrations/ollama/pyproject.toml
+++ b/integrations/ollama/pyproject.toml
@@ -19,7 +19,6 @@ classifiers = [
   "License :: OSI Approved :: Apache Software License",
   "Development Status :: 4 - Beta",
   "Programming Language :: Python",
-  "Programming Language :: Python :: 3.8",
   "Programming Language :: Python :: 3.9",
   "Programming Language :: Python :: 3.10",
   "Programming Language :: Python :: 3.11",
@@ -27,7 +26,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai", "ollama"]
+dependencies = ["haystack-ai", "ollama>=0.4.0"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"
@@ -63,7 +62,7 @@ cov-retry = ["test-cov-retry", "cov-report"]
 docs = ["pydoc-markdown pydoc/config.yml"]
 
 [[tool.hatch.envs.all.matrix]]
-python = ["3.8", "3.9", "3.10", "3.11", "3.12"]
+python = ["3.9", "3.10", "3.11", "3.12"]
 
 
 [tool.hatch.envs.lint]
diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py
index 46042a1c9..822b3d0aa 100644
--- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py
+++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py
@@ -1,4 +1,4 @@
 from .document_embedder import OllamaDocumentEmbedder
 from .text_embedder import OllamaTextEmbedder
 
-__all__ = ["OllamaTextEmbedder", "OllamaDocumentEmbedder"]
+__all__ = ["OllamaDocumentEmbedder", "OllamaTextEmbedder"]
diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py
index ac8f38f35..2fab6c72f 100644
--- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py
+++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py
@@ -100,7 +100,7 @@ def _embed_batch(
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = texts_to_embed[i]  # Single batch only
-            result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs)
+            result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs).model_dump()
             all_embeddings.append(result["embedding"])
 
         meta["model"] = self.model
@@ -122,7 +122,7 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
         - `documents`: Documents with embedding information attached
         - `meta`: The metadata collected during the embedding process
         """
-        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
+        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
             msg = (
                 "OllamaDocumentEmbedder expects a list of Documents as input."
                 "In case you want to embed a list of strings, please use the OllamaTextEmbedder."
diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py
index 7779c6d6e..b08b8bef3 100644
--- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py
+++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py
@@ -62,7 +62,7 @@ def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         - `embedding`: The computed embeddings
         - `meta`: The metadata collected during the embedding process
         """
-        result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs)
+        result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs).model_dump()
         result["meta"] = {"model": self.model}
 
         return result
diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py
index 41a02d0ac..24e4d2edb 100644
--- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py
+++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py
@@ -1,4 +1,4 @@
 from .chat.chat_generator import OllamaChatGenerator
 from .generator import OllamaGenerator
 
-__all__ = ["OllamaGenerator", "OllamaChatGenerator"]
+__all__ = ["OllamaChatGenerator", "OllamaGenerator"]
diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
index 558fd593e..b1be7a2db 100644
--- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
+++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
@@ -4,7 +4,7 @@
 from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 
-from ollama import Client
+from ollama import ChatResponse, Client
 
 
 @component
@@ -111,12 +111,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
     def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
         return {"role": message.role.value, "content": message.content}
 
-    def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
+    def _build_message_from_ollama_response(self, ollama_response: ChatResponse) -> ChatMessage:
         """
         Converts the non-streaming response from the Ollama API to a ChatMessage.
         """
-        message = ChatMessage.from_assistant(content=ollama_response["message"]["content"])
-        message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
+        response_dict = ollama_response.model_dump()
+        message = ChatMessage.from_assistant(content=response_dict["message"]["content"])
+        message.meta.update({key: value for key, value in response_dict.items() if key != "message"})
         return message
 
     def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
@@ -133,9 +134,11 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
         """
         Converts the response from the Ollama API to a StreamingChunk.
         """
-        content = chunk_response["message"]["content"]
-        meta = {key: value for key, value in chunk_response.items() if key != "message"}
-        meta["role"] = chunk_response["message"]["role"]
+        chunk_response_dict = chunk_response.model_dump()
+
+        content = chunk_response_dict["message"]["content"]
+        meta = {key: value for key, value in chunk_response_dict.items() if key != "message"}
+        meta["role"] = chunk_response_dict["message"]["role"]
 
         chunk_message = StreamingChunk(content, meta)
         return chunk_message
diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py
index 058948e8a..dad671c94 100644
--- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py
+++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py
@@ -4,7 +4,7 @@
 from haystack.dataclasses import StreamingChunk
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 
-from ollama import Client
+from ollama import Client, GenerateResponse
 
 
 @component
@@ -118,15 +118,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator":
             data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
         return default_from_dict(cls, data)
 
-    def _convert_to_response(self, ollama_response: Dict[str, Any]) -> Dict[str, List[Any]]:
+    def _convert_to_response(self, ollama_response: GenerateResponse) -> Dict[str, List[Any]]:
         """
         Converts a response from the Ollama API to the required Haystack format.
         """
+        reply = ollama_response.response
+        meta = {key: value for key, value in ollama_response.model_dump().items() if key != "response"}
 
-        replies = [ollama_response["response"]]
-        meta = {key: value for key, value in ollama_response.items() if key != "response"}
-
-        return {"replies": replies, "meta": [meta]}
+        return {"replies": [reply], "meta": [meta]}
 
     def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
         """
@@ -154,8 +153,9 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
         """
         Converts the response from the Ollama API to a StreamingChunk.
         """
-        content = chunk_response["response"]
-        meta = {key: value for key, value in chunk_response.items() if key != "response"}
+        chunk_response_dict = chunk_response.model_dump()
+        content = chunk_response_dict["response"]
+        meta = {key: value for key, value in chunk_response_dict.items() if key != "response"}
 
         chunk_message = StreamingChunk(content, meta)
         return chunk_message
diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py
index 5ac9289aa..b2b3fd927 100644
--- a/integrations/ollama/tests/test_chat_generator.py
+++ b/integrations/ollama/tests/test_chat_generator.py
@@ -4,7 +4,7 @@
 import pytest
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ChatRole
-from ollama._types import ResponseError
+from ollama._types import ChatResponse, ResponseError
 
 from haystack_integrations.components.generators.ollama import OllamaChatGenerator
 
@@ -86,18 +86,18 @@ def test_from_dict(self):
     def test_build_message_from_ollama_response(self):
         model = "some_model"
 
-        ollama_response = {
-            "model": model,
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {"role": "assistant", "content": "Hello! How are you today?"},
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+        ollama_response = ChatResponse(
+            model=model,
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={"role": "assistant", "content": "Hello! How are you today?"},
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)