feat: Add streaming support to OllamaChatGenerator (#757)
* Add streaming support to OllamaChatGenerator

* Clean imports, update docstring

* Organize imports

* Fix test

* Reformat code

* Optimize imports

* Add test for streaming callback

---------

Co-authored-by: Silvano Cerza <[email protected]>
etiennellipse and silvanocerza authored May 31, 2024
1 parent 1457515 commit f799674
Showing 2 changed files with 82 additions and 7 deletions.
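For context, a minimal usage sketch of the streaming support this commit adds, placed here ahead of the diff. It is not part of the commit itself; the import path, model name, and URL are assumptions for illustration, and it presumes a local Ollama server with the model already pulled.

from haystack.dataclasses import ChatMessage, StreamingChunk

# Assumed import path for the ollama-haystack integration.
from haystack_integrations.components.generators.ollama import OllamaChatGenerator


def print_chunk(chunk: StreamingChunk) -> None:
    # Called once per streamed fragment as it arrives from Ollama.
    print(chunk.content, end="", flush=True)


# Passing a streaming_callback switches the generator into streaming mode.
generator = OllamaChatGenerator(
    model="orca-mini",  # illustrative model name
    url="http://localhost:11434/api/chat",  # assumed default Ollama chat endpoint
    streaming_callback=print_chunk,
)

result = generator.run([ChatMessage.from_user("Why is the sky blue?")])
print()
print(result["replies"][0].content)  # the fully assembled reply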
@@ -1,8 +1,9 @@
-from typing import Any, Dict, List, Optional
+import json
+from typing import Any, Callable, Dict, List, Optional

import requests
from haystack import component
-from haystack.dataclasses import ChatMessage
+from haystack.dataclasses import ChatMessage, StreamingChunk
from requests import Response


@@ -38,6 +39,7 @@ def __init__(
generation_kwargs: Optional[Dict[str, Any]] = None,
template: Optional[str] = None,
timeout: int = 120,
+streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
:param model:
@@ -52,26 +54,30 @@ def __init__(
The full prompt template (overrides what is defined in the Ollama Modelfile).
:param timeout:
The number of seconds before throwing a timeout error from the Ollama API.
+:param streaming_callback:
+A callback function that is called when a new token is received from the stream.
+The callback function accepts a StreamingChunk as an argument.
"""

self.timeout = timeout
self.template = template
self.generation_kwargs = generation_kwargs or {}
self.url = url
self.model = model
+self.streaming_callback = streaming_callback

def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
return {"role": message.role.value, "content": message.content}

-def _create_json_payload(self, messages: List[ChatMessage], generation_kwargs=None) -> Dict[str, Any]:
+def _create_json_payload(self, messages: List[ChatMessage], stream=False, generation_kwargs=None) -> Dict[str, Any]:
"""
Returns a dictionary of JSON arguments for a POST request to an Ollama service.
"""
generation_kwargs = generation_kwargs or {}
return {
"messages": [self._message_to_dict(message) for message in messages],
"model": self.model,
"stream": False,
"stream": stream,
"template": self.template,
"options": generation_kwargs,
}
@@ -85,6 +91,41 @@ def _build_message_from_ollama_response(self, ollama_response: Response) -> Chat
message.meta.update({key: value for key, value in json_content.items() if key != "message"})
return message

+def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
+"""
+Converts a list of chunks into the response format required by Haystack.
+"""
+
+replies = [ChatMessage.from_assistant("".join([c.content for c in chunks]))]
+meta = {key: value for key, value in chunks[0].meta.items() if key != "message"}
+
+return {"replies": replies, "meta": [meta]}
+
+def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
+"""
+Converts the response from the Ollama API to a StreamingChunk.
+"""
+decoded_chunk = json.loads(chunk_response.decode("utf-8"))
+
+content = decoded_chunk["message"]["content"]
+meta = {key: value for key, value in decoded_chunk.items() if key != "message"}
+meta["role"] = decoded_chunk["message"]["role"]
+
+chunk_message = StreamingChunk(content, meta)
+return chunk_message
+
+def _handle_streaming_response(self, response) -> List[StreamingChunk]:
+"""
+Handles a streaming response from the Ollama API.
+"""
+chunks: List[StreamingChunk] = []
+for chunk in response.iter_lines():
+chunk_delta: StreamingChunk = self._build_chunk(chunk)
+chunks.append(chunk_delta)
+if self.streaming_callback is not None:
+self.streaming_callback(chunk_delta)
+return chunks
+
@component.output_types(replies=List[ChatMessage])
def run(
self,
@@ -100,16 +141,24 @@ def run(
Optional arguments to pass to the Ollama generation endpoint, such as temperature,
top_p, etc. See the
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+:param streaming_callback:
+A callback function that will be called with each response chunk in streaming mode.
:returns: A dictionary with the following keys:
- `replies`: The responses from the model
"""
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

-json_payload = self._create_json_payload(messages, generation_kwargs)
+stream = self.streaming_callback is not None
+
+json_payload = self._create_json_payload(messages, stream, generation_kwargs)

-response = requests.post(url=self.url, json=json_payload, timeout=self.timeout)
+response = requests.post(url=self.url, json=json_payload, timeout=self.timeout, stream=stream)

# throw error on unsuccessful response
response.raise_for_status()

+if stream:
+chunks: List[StreamingChunk] = self._handle_streaming_response(response)
+return self._convert_to_streaming_response(chunks)
+
return {"replies": [self._build_message_from_ollama_response(response)]}
28 changes: 27 additions & 1 deletion integrations/ollama/tests/test_chat_generator.py
Expand Up @@ -41,7 +41,9 @@ def test_init(self):
assert component.timeout == 5

def test_create_json_payload(self, chat_messages):
observed = OllamaChatGenerator(model="some_model")._create_json_payload(chat_messages, {"temperature": 0.1})
observed = OllamaChatGenerator(model="some_model")._create_json_payload(
chat_messages, False, {"temperature": 0.1}
)
expected = {
"messages": [
{"role": "user", "content": "Tell me about why Super Mario is the greatest superhero"},
@@ -125,3 +127,27 @@ def test_run_model_unavailable(self):
"Based on your infinite wisdom, can you tell me why Alistair and Stefano are so great?"
)
component.run([message])

+@pytest.mark.integration
+def test_run_with_streaming(self):
+streaming_callback = Mock()
+chat_generator = OllamaChatGenerator(streaming_callback=streaming_callback)
+
+chat_history = [
+{"role": "user", "content": "What is the largest city in the United Kingdom by population?"},
+{"role": "assistant", "content": "London is the largest city in the United Kingdom by population"},
+{"role": "user", "content": "And what is the second largest?"},
+]
+
+chat_messages = [
+ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None)
+for message in chat_history
+]
+
+response = chat_generator.run(chat_messages)
+
+streaming_callback.assert_called()
+
+assert isinstance(response, dict)
+assert isinstance(response["replies"], list)
+assert "Manchester" in response["replies"][-1].content or "Glasgow" in response["replies"][-1].content

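As a usage-side counterpart to the integration test above, a sketch of a callback that collects chunks for later inspection. The import path, model name, and prompt are assumptions for illustration; a local Ollama server with the model pulled is required.

from typing import List

from haystack.dataclasses import ChatMessage, StreamingChunk

# Assumed import path for the ollama-haystack integration.
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

collected: List[StreamingChunk] = []


def collect(chunk: StreamingChunk) -> None:
    # Keep every chunk so the stream can be inspected after run() returns.
    collected.append(chunk)


generator = OllamaChatGenerator(model="orca-mini", streaming_callback=collect)
response = generator.run([ChatMessage.from_user("What is the second largest city in the UK?")])

# In the streaming path, run() returns the assembled replies plus per-call metadata.
assert "replies" in response and "meta" in response
print(len(collected), "chunks received")
print(response["replies"][0].content)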