feat: add keep_alive parameter to Ollama Generators (#1131)
* feat: add keep_alive parameter to Ollama integrations

* style: run linter

* fix: serialize keep_alive parameters

* test: include keep_alive parameter in tests

* docs: add keep_alive usage to the docstring

* style: I keep forgetting to lint

* style: update docs

* small fixes to docstrings

---------

Co-authored-by: anakin87 <[email protected]>
emso-c and anakin87 authored Oct 11, 2024
1 parent 9438634 commit 9f1fb94
Showing 4 changed files with 45 additions and 4 deletions.
17 changes: 15 additions & 2 deletions integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict
 from haystack.dataclasses import ChatMessage, StreamingChunk
@@ -38,6 +38,7 @@ def __init__(
         url: str = "http://localhost:11434",
         generation_kwargs: Optional[Dict[str, Any]] = None,
         timeout: int = 120,
+        keep_alive: Optional[Union[float, str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
@@ -54,12 +55,21 @@
         :param streaming_callback:
             A callback function that is called when a new token is received from the stream.
             The callback function accepts StreamingChunk as an argument.
+        :param keep_alive:
+            Controls how long the model stays loaded in memory after the request.
+            If not set, the Ollama default (5 minutes) is used.
+            The value can be:
+            - a duration string (such as "10m" or "24h")
+            - a number of seconds (such as 3600)
+            - any negative number, which keeps the model loaded in memory (e.g. -1 or "-1m")
+            - '0', which unloads the model immediately after the response is generated.
         """
 
         self.timeout = timeout
         self.generation_kwargs = generation_kwargs or {}
         self.url = url
         self.model = model
+        self.keep_alive = keep_alive
         self.streaming_callback = streaming_callback
 
         self._client = Client(host=self.url, timeout=self.timeout)
@@ -76,6 +86,7 @@ def to_dict(self) -> Dict[str, Any]:
             self,
             model=self.model,
             url=self.url,
+            keep_alive=self.keep_alive,
             generation_kwargs=self.generation_kwargs,
             timeout=self.timeout,
             streaming_callback=callback_name,
@@ -165,7 +176,9 @@ def run(
 
         stream = self.streaming_callback is not None
         messages = [self._message_to_dict(message) for message in messages]
-        response = self._client.chat(model=self.model, messages=messages, stream=stream, options=generation_kwargs)
+        response = self._client.chat(
+            model=self.model, messages=messages, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs
+        )
 
         if stream:
             chunks: List[StreamingChunk] = self._handle_streaming_response(response)
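To illustrate the new parameter in the chat generator, here is a minimal usage sketch (not part of the commit). It assumes a local Ollama server at the default URL with the llama2 model already pulled, and uses the public import path implied by the type string serialized above.

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

# Keep the model loaded in memory for 10 minutes after each request.
chat_generator = OllamaChatGenerator(model="llama2", keep_alive="10m")
result = chat_generator.run(messages=[ChatMessage.from_user("Why is the sky blue?")])
print(result["replies"][0])  # the assistant's ChatMessage reply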
17 changes: 15 additions & 2 deletions integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict
 from haystack.dataclasses import StreamingChunk
@@ -36,6 +36,7 @@ def __init__(
         template: Optional[str] = None,
         raw: bool = False,
         timeout: int = 120,
+        keep_alive: Optional[Union[float, str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
@@ -59,13 +60,22 @@
         :param streaming_callback:
             A callback function that is called when a new token is received from the stream.
             The callback function accepts StreamingChunk as an argument.
+        :param keep_alive:
+            Controls how long the model stays loaded in memory after the request.
+            If not set, the Ollama default (5 minutes) is used.
+            The value can be:
+            - a duration string (such as "10m" or "24h")
+            - a number of seconds (such as 3600)
+            - any negative number, which keeps the model loaded in memory (e.g. -1 or "-1m")
+            - '0', which unloads the model immediately after the response is generated.
         """
         self.timeout = timeout
         self.raw = raw
         self.template = template
         self.system_prompt = system_prompt
         self.model = model
         self.url = url
+        self.keep_alive = keep_alive
         self.generation_kwargs = generation_kwargs or {}
         self.streaming_callback = streaming_callback
 
@@ -87,6 +97,7 @@ def to_dict(self) -> Dict[str, Any]:
             system_prompt=self.system_prompt,
             model=self.model,
             url=self.url,
+            keep_alive=self.keep_alive,
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
         )
@@ -172,7 +183,9 @@ def run(
 
         stream = self.streaming_callback is not None
 
-        response = self._client.generate(model=self.model, prompt=prompt, stream=stream, options=generation_kwargs)
+        response = self._client.generate(
+            model=self.model, prompt=prompt, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs
+        )
 
         if stream:
             chunks: List[StreamingChunk] = self._handle_streaming_response(response)
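The plain generator accepts the same values; a hedged sketch showing the numeric form (again assuming a running local Ollama instance, with the model name as a placeholder):

from haystack_integrations.components.generators.ollama import OllamaGenerator

# keep_alive=3600 keeps the model in memory for an hour after each call;
# -1 would keep it loaded indefinitely, and 0 would unload it immediately.
generator = OllamaGenerator(model="orca-mini", keep_alive=3600)
result = generator.run(prompt="Why is the sky blue?")
print(result["replies"][0])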
7 changes: 7 additions & 0 deletions integrations/ollama/tests/test_chat_generator.py
@@ -26,33 +26,38 @@ def test_init_default(self):
         assert component.url == "http://localhost:11434"
         assert component.generation_kwargs == {}
         assert component.timeout == 120
+        assert component.keep_alive is None
 
     def test_init(self):
         component = OllamaChatGenerator(
             model="llama2",
             url="http://my-custom-endpoint:11434",
             generation_kwargs={"temperature": 0.5},
+            keep_alive="10m",
             timeout=5,
         )
 
         assert component.model == "llama2"
         assert component.url == "http://my-custom-endpoint:11434"
         assert component.generation_kwargs == {"temperature": 0.5}
         assert component.timeout == 5
+        assert component.keep_alive == "10m"
 
     def test_to_dict(self):
         component = OllamaChatGenerator(
             model="llama2",
             streaming_callback=print_streaming_chunk,
             url="custom_url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+            keep_alive="5m",
         )
         data = component.to_dict()
         assert data == {
             "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
             "init_parameters": {
                 "timeout": 120,
                 "model": "llama2",
+                "keep_alive": "5m",
                 "url": "custom_url",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -66,6 +71,7 @@ def test_from_dict(self):
                 "timeout": 120,
                 "model": "llama2",
                 "url": "custom_url",
+                "keep_alive": "5m",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
             },
@@ -75,6 +81,7 @@
         assert component.streaming_callback is print_streaming_chunk
         assert component.url == "custom_url"
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
+        assert component.keep_alive == "5m"
 
     def test_build_message_from_ollama_response(self):
         model = "some_model"
8 changes: 8 additions & 0 deletions integrations/ollama/tests/test_generator.py
@@ -45,6 +45,7 @@ def test_init_default(self):
         assert component.template is None
         assert component.raw is False
         assert component.timeout == 120
+        assert component.keep_alive is None
         assert component.streaming_callback is None
 
     def test_init(self):
@@ -57,6 +58,7 @@ def callback(x: StreamingChunk):
             generation_kwargs={"temperature": 0.5},
             system_prompt="You are Luigi from Super Mario Bros.",
             timeout=5,
+            keep_alive="10m",
             streaming_callback=callback,
         )
         assert component.model == "llama2"
@@ -66,6 +68,7 @@ def callback(x: StreamingChunk):
         assert component.template is None
         assert component.raw is False
         assert component.timeout == 5
+        assert component.keep_alive == "10m"
         assert component.streaming_callback == callback
 
         component = OllamaGenerator()
@@ -80,6 +83,7 @@ def callback(x: StreamingChunk):
                 "model": "orca-mini",
                 "url": "http://localhost:11434",
                 "streaming_callback": None,
+                "keep_alive": None,
                 "generation_kwargs": {},
             },
         }
@@ -89,6 +93,7 @@ def test_to_dict_with_parameters(self):
             model="llama2",
             streaming_callback=print_streaming_chunk,
             url="going_to_51_pegasi_b_for_weekend",
+            keep_alive="10m",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
         data = component.to_dict()
@@ -100,6 +105,7 @@
                 "template": None,
                 "system_prompt": None,
                 "model": "llama2",
+                "keep_alive": "10m",
                 "url": "going_to_51_pegasi_b_for_weekend",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -115,6 +121,7 @@ def test_from_dict(self):
                 "template": None,
                 "system_prompt": None,
                 "model": "llama2",
+                "keep_alive": "5m",
                 "url": "going_to_51_pegasi_b_for_weekend",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -125,6 +132,7 @@ def test_from_dict(self):
         assert component.streaming_callback is print_streaming_chunk
         assert component.url == "going_to_51_pegasi_b_for_weekend"
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
+        assert component.keep_alive == "5m"
 
     @pytest.mark.integration
     def test_ollama_generator_run_streaming(self):
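The to_dict/from_dict tests above exercise the "serialize keep_alive" fix from the commit message; a minimal round-trip sketch of what they verify (assuming the same public import path as in the earlier examples):

from haystack_integrations.components.generators.ollama import OllamaGenerator

generator = OllamaGenerator(model="llama2", keep_alive="5m")
data = generator.to_dict()
assert data["init_parameters"]["keep_alive"] == "5m"

# from_dict restores the component with the same keep_alive setting.
restored = OllamaGenerator.from_dict(data)
assert restored.keep_alive == "5m"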
