Support for streaming ollama generator (#280)
* added support for streaming ollama generator

* fixed linting errors

* more linting issues

* implemented serializing/deserializing for OllamaGenerator

* linting changes

* minor refactoring

* formatting
sachinsachdeva authored Feb 12, 2024
1 parent 1a1f5a2 commit 09831a6
Showing 2 changed files with 187 additions and 12 deletions.
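For reference, here is a minimal usage sketch of the streaming behavior this commit adds. It is not part of the diff and assumes a local Ollama server at the default URL (http://localhost:11434/api/generate) with the default "orca-mini" model pulled.

```python
from haystack.dataclasses import StreamingChunk
from haystack_integrations.components.generators.ollama import OllamaGenerator


def on_chunk(chunk: StreamingChunk) -> None:
    # Invoked once per streamed chunk as tokens arrive from the Ollama API.
    print(chunk.content, end="", flush=True)


# Passing a streaming_callback switches the generator into streaming mode.
generator = OllamaGenerator(model="orca-mini", streaming_callback=on_chunk)
result = generator.run(prompt="Why is the sky blue?")
print(result["replies"][0])
```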
@@ -1,7 +1,10 @@
from typing import Any, Dict, List, Optional
import json
from typing import Any, Callable, Dict, List, Optional

import requests
from haystack import component
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler
from haystack.dataclasses import StreamingChunk
from requests import Response


@@ -21,6 +24,7 @@ def __init__(
template: Optional[str] = None,
raw: bool = False,
timeout: int = 120,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
:param model: The name of the model to use. The model should be available in the running Ollama instance.
@@ -36,6 +40,8 @@
if you are specifying a full templated prompt in your API request.
:param timeout: The number of seconds before throwing a timeout error from the Ollama API.
Default is 120 seconds.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
"""
self.timeout = timeout
self.raw = raw
@@ -44,8 +50,40 @@ def __init__(
self.model = model
self.url = url
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback

def _create_json_payload(self, prompt: str, generation_kwargs=None) -> Dict[str, Any]:
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
"""
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
timeout=self.timeout,
raw=self.raw,
template=self.template,
system_prompt=self.system_prompt,
model=self.model,
url=self.url,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
return default_from_dict(cls, data)

def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]:
"""
Returns a dictionary of JSON arguments for a POST request to an Ollama service.
:param prompt: The prompt to generate a response for.
@@ -58,26 +96,67 @@ def _create_json_payload(self, prompt: str, generation_kwargs=None) -> Dict[str,
return {
"prompt": prompt,
"model": self.model,
"stream": False,
"stream": stream,
"raw": self.raw,
"template": self.template,
"system": self.system_prompt,
"options": generation_kwargs,
}

def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, List[Any]]:
def _convert_to_response(self, ollama_response: Response) -> Dict[str, List[Any]]:
"""
Convert a response from the Ollama API to the required Haystack format.
:param ollama_response: A response (requests library) from the Ollama API.
:return: A dictionary of the returned responses and metadata.
"""

resp_dict = ollama_response.json()

replies = [resp_dict["response"]]
meta = {key: value for key, value in resp_dict.items() if key != "response"}

return {"replies": replies, "meta": [meta]}

def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
"""
Converts a list of streaming chunks into the required Haystack response format.
:param chunks: List of StreamingChunks
:return: A dictionary of the returned responses and metadata.
"""

replies = ["".join([c.content for c in chunks])]
meta = {key: value for key, value in chunks[0].meta.items() if key != "response"}

return {"replies": replies, "meta": [meta]}

def _handle_streaming_response(self, response) -> List[StreamingChunk]:
"""Handles Streaming response case
:param response: streaming response from ollama api.
:return: The List[StreamingChunk].
"""
chunks: List[StreamingChunk] = []
for chunk in response.iter_lines():
chunk_delta: StreamingChunk = self._build_chunk(chunk)
chunks.append(chunk_delta)
if self.streaming_callback is not None:
self.streaming_callback(chunk_delta)
return chunks

def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
"""
Converts the response from the Ollama API to a StreamingChunk.
:param chunk_response: The chunk returned by the Ollama API.
:return: The StreamingChunk.
"""
decoded_chunk = json.loads(chunk_response.decode("utf-8"))

content = decoded_chunk["response"]
meta = {key: value for key, value in decoded_chunk.items() if key != "response"}

chunk_message = StreamingChunk(content, meta)
return chunk_message

@component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
def run(
self,
@@ -94,11 +173,17 @@ def run(
"""
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

json_payload = self._create_json_payload(prompt, generation_kwargs)
stream = self.streaming_callback is not None

json_payload = self._create_json_payload(prompt, stream, generation_kwargs)

response = requests.post(url=self.url, json=json_payload, timeout=self.timeout)
response = requests.post(url=self.url, json=json_payload, timeout=self.timeout, stream=stream)

# throw error on unsuccessful response
response.raise_for_status()

return self._convert_to_haystack_response(response)
if stream:
chunks: List[StreamingChunk] = self._handle_streaming_response(response)
return self._convert_to_streaming_response(chunks)

return self._convert_to_response(response)
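
A hedged sketch of the (de)serialization support added above, mirroring what the tests below exercise; the dotted callback path is what serialize_callback_handler produces.

```python
from haystack.components.generators.utils import print_streaming_chunk
from haystack_integrations.components.generators.ollama import OllamaGenerator

generator = OllamaGenerator(model="llama2", streaming_callback=print_streaming_chunk)
data = generator.to_dict()
# The callback is serialized as its import path, e.g.
# "haystack.components.generators.utils.print_streaming_chunk".
restored = OllamaGenerator.from_dict(data)
assert restored.streaming_callback is print_streaming_chunk
```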
98 changes: 94 additions & 4 deletions integrations/ollama/tests/test_generator.py
@@ -2,7 +2,11 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import pytest
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import StreamingChunk
from haystack_integrations.components.generators.ollama import OllamaGenerator
from requests import HTTPError

@@ -42,23 +46,86 @@ def test_init_default(self):
assert component.template is None
assert component.raw is False
assert component.timeout == 120
assert component.streaming_callback is None

def test_init(self):
def callback(x: StreamingChunk):
x.content = ""

component = OllamaGenerator(
model="llama2",
url="http://my-custom-endpoint:11434/api/generate",
generation_kwargs={"temperature": 0.5},
system_prompt="You are Luigi from Super Mario Bros.",
timeout=5,
streaming_callback=callback,
)

assert component.model == "llama2"
assert component.url == "http://my-custom-endpoint:11434/api/generate"
assert component.generation_kwargs == {"temperature": 0.5}
assert component.system_prompt == "You are Luigi from Super Mario Bros."
assert component.template is None
assert component.raw is False
assert component.timeout == 5
assert component.streaming_callback == callback

def test_to_dict(self):
component = OllamaGenerator()
data = component.to_dict()
assert data == {
"type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator",
"init_parameters": {
"timeout": 120,
"raw": False,
"template": None,
"system_prompt": None,
"model": "orca-mini",
"url": "http://localhost:11434/api/generate",
"streaming_callback": None,
"generation_kwargs": {},
},
}

def test_to_dict_with_parameters(self):
component = OllamaGenerator(
model="llama2",
streaming_callback=print_streaming_chunk,
url="going_to_51_pegasi_b_for_weekend",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
data = component.to_dict()
assert data == {
"type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator",
"init_parameters": {
"timeout": 120,
"raw": False,
"template": None,
"system_prompt": None,
"model": "llama2",
"url": "going_to_51_pegasi_b_for_weekend",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}

def test_from_dict(self):
data = {
"type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator",
"init_parameters": {
"timeout": 120,
"raw": False,
"template": None,
"system_prompt": None,
"model": "llama2",
"url": "going_to_51_pegasi_b_for_weekend",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = OllamaGenerator.from_dict(data)
assert component.model == "llama2"
assert component.streaming_callback is print_streaming_chunk
assert component.url == "going_to_51_pegasi_b_for_weekend"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

@pytest.mark.parametrize(
"configuration",
@@ -79,20 +146,43 @@ def test_init(self):
},
],
)
def test_create_json_payload(self, configuration):
@pytest.mark.parametrize("stream", [True, False])
def test_create_json_payload(self, configuration: dict[str, Any], stream: bool):
prompt = "hello"
component = OllamaGenerator(**configuration)

observed = component._create_json_payload(prompt=prompt)
observed = component._create_json_payload(prompt=prompt, stream=stream)

expected = {
"prompt": prompt,
"model": configuration["model"],
"stream": False,
"stream": stream,
"system": configuration["system_prompt"],
"raw": configuration["raw"],
"template": configuration["template"],
"options": {},
}

assert observed == expected

@pytest.mark.integration
def test_ollama_generator_run_streaming(self):
class Callback:
def __init__(self):
self.responses = ""
self.count_calls = 0

def __call__(self, chunk):
self.responses += chunk.content
self.count_calls += 1
return chunk

callback = Callback()
component = OllamaGenerator(streaming_callback=callback)
results = component.run(prompt="What's the capital of Netherlands?")

assert len(results["replies"]) == 1
assert "Amsterdam" in results["replies"][0]
assert len(results["meta"]) == 1
assert callback.responses == results["replies"][0]
assert callback.count_calls > 1
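
To illustrate how streamed lines become StreamingChunk objects, here is a small sketch of the parsing done in _build_chunk; the JSON fields shown are an assumption modeled on the Ollama generate API and the code above, not output captured from a real server.

```python
import json

from haystack.dataclasses import StreamingChunk

# One line of Ollama's newline-delimited JSON stream (hypothetical content).
raw_line = b'{"model": "orca-mini", "response": "Amsterdam", "done": false}'

decoded = json.loads(raw_line.decode("utf-8"))
chunk = StreamingChunk(
    content=decoded["response"],
    meta={key: value for key, value in decoded.items() if key != "response"},
)
assert chunk.content == "Amsterdam"
```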
