diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index f9906c3fb..2582cf9c8 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -2,9 +2,10 @@ from typing import Any, Callable, Dict, List, Optional import requests -from haystack import component +from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk from requests import Response +from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler @component @@ -51,6 +52,37 @@ def __init__( self.generation_kwargs = generation_kwargs or {} self.streaming_callback = streaming_callback + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + :return: The serialized component as a dictionary. + """ + callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + timeout=self.timeout, + raw=self.raw, + template=self.template, + system_prompt=self.system_prompt, + model=self.model, + url=self.url, + generation_kwargs=self.generation_kwargs, + streaming_callback=callback_name, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + return default_from_dict(cls, data) + def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]: """ Returns a dictionary of JSON arguments for a POST request to an Ollama service. diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 59d778333..4bae7eb67 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -2,9 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Any +from _pytest.monkeypatch import MonkeyPatch import pytest from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.ollama import OllamaGenerator +from haystack.components.generators.utils import print_streaming_chunk from requests import HTTPError @@ -66,6 +69,64 @@ def callback(x: StreamingChunk): assert component.timeout == 5 assert component.streaming_callback == callback + component = OllamaGenerator() + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator", + "init_parameters": { + "timeout": 120, + "raw": False, + "template": None, + "system_prompt": None, + "model": "orca-mini", + "url": "http://localhost:11434/api/generate", + "streaming_callback": None, + "generation_kwargs": {}, + }, + } + + def test_to_dict_with_parameters(self): + component = OllamaGenerator( + model="llama2", + streaming_callback=print_streaming_chunk, + url="going_to_51_pegasi_b_for_weekend", + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator", + "init_parameters": { + "timeout": 120, + "raw": False, + "template": None, + "system_prompt": None, + "model": "llama2", + "url": "going_to_51_pegasi_b_for_weekend", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + + def test_from_dict(self): + data = { + "type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator", + "init_parameters": { + "timeout": 120, + "raw": False, + "template": None, + "system_prompt": None, + "model": "llama2", + "url": "going_to_51_pegasi_b_for_weekend", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + component = OllamaGenerator.from_dict(data) + assert component.model == "llama2" + assert component.streaming_callback is print_streaming_chunk + assert component.url == "going_to_51_pegasi_b_for_weekend" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + @pytest.mark.parametrize( "configuration", [ @@ -86,7 +147,7 @@ def callback(x: StreamingChunk): ], ) @pytest.mark.parametrize("stream", [True, False]) - def test_create_json_payload(self, configuration, stream): + def test_create_json_payload(self, configuration: dict[str, Any], stream: bool): prompt = "hello" component = OllamaGenerator(**configuration)