Skip to content

Commit

Permalink
Implemented serialization/deserialization for OllamaGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
sachinsachdeva committed Feb 6, 2024
1 parent 2c7a724 commit a6a50e5
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Any, Callable, Dict, List, Optional

import requests
from haystack import component
from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import StreamingChunk
from requests import Response
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler


@component
Expand Down Expand Up @@ -51,6 +52,37 @@ def __init__(
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
"""
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
timeout=self.timeout,
raw=self.raw,
template=self.template,
system_prompt=self.system_prompt,
model=self.model,
url=self.url,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
return default_from_dict(cls, data)

def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]:
"""
Returns a dictionary of JSON arguments for a POST request to an Ollama service.
Expand Down
63 changes: 62 additions & 1 deletion integrations/ollama/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from _pytest.monkeypatch import MonkeyPatch
import pytest
from haystack.dataclasses import StreamingChunk
from haystack_integrations.components.generators.ollama import OllamaGenerator
from haystack.components.generators.utils import print_streaming_chunk
from requests import HTTPError


Expand Down Expand Up @@ -66,6 +69,64 @@ def callback(x: StreamingChunk):
assert component.timeout == 5
assert component.streaming_callback == callback

component = OllamaGenerator()
data = component.to_dict()
assert data == {
"type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator",
"init_parameters": {
"timeout": 120,
"raw": False,
"template": None,
"system_prompt": None,
"model": "orca-mini",
"url": "http://localhost:11434/api/generate",
"streaming_callback": None,
"generation_kwargs": {},
},
}

def test_to_dict_with_parameters(self):
component = OllamaGenerator(
model="llama2",
streaming_callback=print_streaming_chunk,
url="going_to_51_pegasi_b_for_weekend",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
data = component.to_dict()
assert data == {
"type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator",
"init_parameters": {
"timeout": 120,
"raw": False,
"template": None,
"system_prompt": None,
"model": "llama2",
"url": "going_to_51_pegasi_b_for_weekend",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}

def test_from_dict(self):
data = {
"type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator",
"init_parameters": {
"timeout": 120,
"raw": False,
"template": None,
"system_prompt": None,
"model": "llama2",
"url": "going_to_51_pegasi_b_for_weekend",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = OllamaGenerator.from_dict(data)
assert component.model == "llama2"
assert component.streaming_callback is print_streaming_chunk
assert component.url == "going_to_51_pegasi_b_for_weekend"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

@pytest.mark.parametrize(
"configuration",
[
Expand All @@ -86,7 +147,7 @@ def callback(x: StreamingChunk):
],
)
@pytest.mark.parametrize("stream", [True, False])
def test_create_json_payload(self, configuration, stream):
def test_create_json_payload(self, configuration: dict[str, Any], stream: bool):
prompt = "hello"
component = OllamaGenerator(**configuration)

Expand Down

0 comments on commit a6a50e5

Please sign in to comment.