fix: adapt to Ollama client 0.4.0 #139

Merged: 1 commit, Nov 22, 2024
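
The ollama Python client changed its return types in 0.4.0: Client.chat() now returns a typed ChatResponse (a Pydantic model) instead of a plain dict, which breaks code that treats the response as a dict (for example iterating over it with .items(), as the metadata handling in this component does). This PR adapts the experimental OllamaChatGenerator by converting the response back to a dict with model_dump(), bumps the ollama-haystack dependency to >=2.0, and updates the tests to construct ChatResponse objects. A minimal sketch of the conversion pattern the diff relies on (not part of the PR; the field values are illustrative):

# Sketch only: turning an ollama >= 0.4.0 ChatResponse back into the
# dict shape that pre-0.4.0 code expected.
from ollama import ChatResponse

response = ChatResponse(
    model="llama3.2",
    created_at="2023-12-12T14:13:43.416799Z",
    message={"role": "assistant", "content": "Hello!"},
    done=True,
)

response_dict = response.model_dump()  # plain dict, like the old payload
meta = {key: value for key, value in response_dict.items() if key != "message"}
print(response_dict["message"]["content"])  # -> "Hello!"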
@@ -16,6 +16,8 @@
 # pylint: disable=import-error
 from haystack_integrations.components.generators.ollama import OllamaChatGenerator as OllamaChatGeneratorBase
 
+from ollama import ChatResponse
+
 
 # The following code block ensures that:
 # - we reuse existing code where possible
@@ -175,11 +177,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
 
         return default_from_dict(cls, data)
 
-    def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
+    def _build_message_from_ollama_response(self, ollama_response: "ChatResponse") -> ChatMessage:
         """
         Converts the non-streaming response from the Ollama API to a ChatMessage.
         """
-        ollama_message = ollama_response["message"]
+        response_dict = ollama_response.model_dump()
+
+        ollama_message = response_dict["message"]
 
         text = ollama_message["content"]
 
@@ -192,7 +196,7 @@ def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
 
         message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
 
-        message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
+        message.meta.update({key: value for key, value in response_dict.items() if key != "message"})
         return message
 
     def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
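
For orientation, a small usage sketch (not part of the diff) showing the reworked method end to end; it mirrors the unit tests further down, and OllamaChatGenerator here is the experimental subclass this repo defines (import omitted):

# Hypothetical sketch mirroring the tests below.
from ollama import ChatResponse

response = ChatResponse(
    model="some_model",
    created_at="2023-12-12T14:13:43.416799Z",
    message={"role": "assistant", "content": "Hello! How are you today?"},
    done=True,
)

generator = OllamaChatGenerator(model="some_model")
chat_message = generator._build_message_from_ollama_response(response)
# Everything except the "message" key ends up in the metadata.
assert chat_message.meta["done"] is True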
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -64,7 +64,7 @@ extra-dependencies = [
 "fastapi",
 # Tools support
 "jsonschema",
-"ollama-haystack>=1.1.0",
+"ollama-haystack>=2.0",
 # Async
 "opensearch-haystack",
 "opensearch-py[async]",
test/components/generators/ollama/test_chat_generator.py (110 changes: 55 additions & 55 deletions)
@@ -5,7 +5,7 @@
 
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import StreamingChunk
-from ollama._types import ResponseError
+from ollama._types import ResponseError, ChatResponse
 
 from haystack_experimental.dataclasses import (
     ChatMessage,
@@ -225,18 +225,18 @@ def test_from_dict(self):
     def test_build_message_from_ollama_response(self):
         model = "some_model"
 
-        ollama_response = {
-            "model": model,
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {"role": "assistant", "content": "Hello! How are you today?"},
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+        ollama_response = ChatResponse(
+            model=model,
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={"role": "assistant", "content": "Hello! How are you today?"},
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
 
@@ -246,10 +246,10 @@ def test_build_message_from_ollama_response(self):
     def test_build_message_from_ollama_response_with_tools(self):
         model = "some_model"
 
-        ollama_response = {
-            "model": model,
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {
+        ollama_response = ChatResponse(
+            model=model,
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={
                 "role": "assistant",
                 "content": "",
                 "tool_calls": [
@@ -261,14 +261,14 @@ def test_build_message_from_ollama_response_with_tools(self):
                     }
                 ],
             },
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
 
@@ -283,21 +283,21 @@ def test_build_message_from_ollama_response_with_tools(self):
     def test_run(self, mock_client):
         generator = OllamaChatGenerator()
 
-        mock_response = {
-            "model": "llama3.2",
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {
+        mock_response = ChatResponse(
+            model="llama3.2",
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={
                 "role": "assistant",
                 "content": "Fine. How can I help you today?",
             },
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         mock_client_instance = mock_client.return_value
         mock_client_instance.chat.return_value = mock_response
@@ -330,24 +330,24 @@ def streaming_callback(chunk: StreamingChunk) -> None:
 
         mock_response = iter(
             [
-                {
-                    "model": "llama3.2",
-                    "created_at": "2023-12-12T14:13:43.416799Z",
-                    "message": {"role": "assistant", "content": "first chunk "},
-                    "done": False,
-                },
-                {
-                    "model": "llama3.2",
-                    "created_at": "2023-12-12T14:13:43.416799Z",
-                    "message": {"role": "assistant", "content": "second chunk"},
-                    "done": True,
-                    "total_duration": 4883583458,
-                    "load_duration": 1334875,
-                    "prompt_eval_count": 26,
-                    "prompt_eval_duration": 342546000,
-                    "eval_count": 282,
-                    "eval_duration": 4535599000,
-                },
+                ChatResponse(
+                    model="llama3.2",
+                    created_at="2023-12-12T14:13:43.416799Z",
+                    message={"role": "assistant", "content": "first chunk "},
+                    done=False,
+                ),
+                ChatResponse(
+                    model="llama3.2",
+                    created_at="2023-12-12T14:13:43.416799Z",
+                    message={"role": "assistant", "content": "second chunk"},
+                    done=True,
+                    total_duration=4883583458,
+                    load_duration=1334875,
+                    prompt_eval_count=26,
+                    prompt_eval_duration=342546000,
+                    eval_count=282,
+                    eval_duration=4535599000,
+                ),
             ]
         )
 