diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 55bd65d8a..321eab9f3 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -1,7 +1,10 @@ -from typing import Any, Dict, List, Optional +import json +from typing import Any, Callable, Dict, List, Optional import requests -from haystack import component +from haystack import component, default_from_dict, default_to_dict +from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler +from haystack.dataclasses import StreamingChunk from requests import Response @@ -21,6 +24,7 @@ def __init__( template: Optional[str] = None, raw: bool = False, timeout: int = 120, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ :param model: The name of the model to use. The model should be available in the running Ollama instance. @@ -36,6 +40,8 @@ def __init__( if you are specifying a full templated prompt in your API request. :param timeout: The number of seconds before throwing a timeout error from the Ollama API. Default is 120 seconds. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. """ self.timeout = timeout self.raw = raw @@ -44,8 +50,40 @@ def __init__( self.model = model self.url = url self.generation_kwargs = generation_kwargs or {} + self.streaming_callback = streaming_callback - def _create_json_payload(self, prompt: str, generation_kwargs=None) -> Dict[str, Any]: + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + :return: The serialized component as a dictionary. + """ + callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + timeout=self.timeout, + raw=self.raw, + template=self.template, + system_prompt=self.system_prompt, + model=self.model, + url=self.url, + generation_kwargs=self.generation_kwargs, + streaming_callback=callback_name, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + return default_from_dict(cls, data) + + def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]: """ Returns a dictionary of JSON arguments for a POST request to an Ollama service. :param prompt: The prompt to generate a response for. @@ -58,19 +96,20 @@ def _create_json_payload(self, prompt: str, generation_kwargs=None) -> Dict[str, return { "prompt": prompt, "model": self.model, - "stream": False, + "stream": stream, "raw": self.raw, "template": self.template, "system": self.system_prompt, "options": generation_kwargs, } - def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, List[Any]]: + def _convert_to_response(self, ollama_response: Response) -> Dict[str, List[Any]]: """ Convert a response from the Ollama API to the required Haystack format. :param ollama_response: A response (requests library) from the Ollama API. :return: A dictionary of the returned responses and metadata. """ + resp_dict = ollama_response.json() replies = [resp_dict["response"]] @@ -78,6 +117,46 @@ def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, return {"replies": replies, "meta": [meta]} + def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: + """ + Convert a list of chunks response required Haystack format. + :param chunks: List of StreamingChunks + :return: A dictionary of the returned responses and metadata. + """ + + replies = ["".join([c.content for c in chunks])] + meta = {key: value for key, value in chunks[0].meta.items() if key != "response"} + + return {"replies": replies, "meta": [meta]} + + def _handle_streaming_response(self, response) -> List[StreamingChunk]: + """Handles Streaming response case + + :param response: streaming response from ollama api. + :return: The List[StreamingChunk]. + """ + chunks: List[StreamingChunk] = [] + for chunk in response.iter_lines(): + chunk_delta: StreamingChunk = self._build_chunk(chunk) + chunks.append(chunk_delta) + if self.streaming_callback is not None: + self.streaming_callback(chunk_delta) + return chunks + + def _build_chunk(self, chunk_response: Any) -> StreamingChunk: + """ + Converts the response from the Ollama API to a StreamingChunk. + :param chunk: The chunk returned by the Ollama API. + :return: The StreamingChunk. + """ + decoded_chunk = json.loads(chunk_response.decode("utf-8")) + + content = decoded_chunk["response"] + meta = {key: value for key, value in decoded_chunk.items() if key != "response"} + + chunk_message = StreamingChunk(content, meta) + return chunk_message + @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) def run( self, @@ -94,11 +173,17 @@ def run( """ generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - json_payload = self._create_json_payload(prompt, generation_kwargs) + stream = self.streaming_callback is not None + + json_payload = self._create_json_payload(prompt, stream, generation_kwargs) - response = requests.post(url=self.url, json=json_payload, timeout=self.timeout) + response = requests.post(url=self.url, json=json_payload, timeout=self.timeout, stream=stream) # throw error on unsuccessful response response.raise_for_status() - return self._convert_to_haystack_response(response) + if stream: + chunks: List[StreamingChunk] = self._handle_streaming_response(response) + return self._convert_to_streaming_response(chunks) + + return self._convert_to_response(response) diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 0403516ce..4af2bdb82 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -2,7 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Any + import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.ollama import OllamaGenerator from requests import HTTPError @@ -42,16 +46,20 @@ def test_init_default(self): assert component.template is None assert component.raw is False assert component.timeout == 120 + assert component.streaming_callback is None def test_init(self): + def callback(x: StreamingChunk): + x.content = "" + component = OllamaGenerator( model="llama2", url="http://my-custom-endpoint:11434/api/generate", generation_kwargs={"temperature": 0.5}, system_prompt="You are Luigi from Super Mario Bros.", timeout=5, + streaming_callback=callback, ) - assert component.model == "llama2" assert component.url == "http://my-custom-endpoint:11434/api/generate" assert component.generation_kwargs == {"temperature": 0.5} @@ -59,6 +67,65 @@ def test_init(self): assert component.template is None assert component.raw is False assert component.timeout == 5 + assert component.streaming_callback == callback + + component = OllamaGenerator() + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator", + "init_parameters": { + "timeout": 120, + "raw": False, + "template": None, + "system_prompt": None, + "model": "orca-mini", + "url": "http://localhost:11434/api/generate", + "streaming_callback": None, + "generation_kwargs": {}, + }, + } + + def test_to_dict_with_parameters(self): + component = OllamaGenerator( + model="llama2", + streaming_callback=print_streaming_chunk, + url="going_to_51_pegasi_b_for_weekend", + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator", + "init_parameters": { + "timeout": 120, + "raw": False, + "template": None, + "system_prompt": None, + "model": "llama2", + "url": "going_to_51_pegasi_b_for_weekend", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + + def test_from_dict(self): + data = { + "type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator", + "init_parameters": { + "timeout": 120, + "raw": False, + "template": None, + "system_prompt": None, + "model": "llama2", + "url": "going_to_51_pegasi_b_for_weekend", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + component = OllamaGenerator.from_dict(data) + assert component.model == "llama2" + assert component.streaming_callback is print_streaming_chunk + assert component.url == "going_to_51_pegasi_b_for_weekend" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} @pytest.mark.parametrize( "configuration", @@ -79,16 +146,17 @@ def test_init(self): }, ], ) - def test_create_json_payload(self, configuration): + @pytest.mark.parametrize("stream", [True, False]) + def test_create_json_payload(self, configuration: dict[str, Any], stream: bool): prompt = "hello" component = OllamaGenerator(**configuration) - observed = component._create_json_payload(prompt=prompt) + observed = component._create_json_payload(prompt=prompt, stream=stream) expected = { "prompt": prompt, "model": configuration["model"], - "stream": False, + "stream": stream, "system": configuration["system_prompt"], "raw": configuration["raw"], "template": configuration["template"], @@ -96,3 +164,25 @@ def test_create_json_payload(self, configuration): } assert observed == expected + + @pytest.mark.integration + def test_ollama_generator_run_streaming(self): + class Callback: + def __init__(self): + self.responses = "" + self.count_calls = 0 + + def __call__(self, chunk): + self.responses += chunk.content + self.count_calls += 1 + return chunk + + callback = Callback() + component = OllamaGenerator(streaming_callback=callback) + results = component.run(prompt="What's the capital of Netherlands?") + + assert len(results["replies"]) == 1 + assert "Amsterdam" in results["replies"][0] + assert len(results["meta"]) == 1 + assert callback.responses == results["replies"][0] + assert callback.count_calls > 1