diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 5da725e87..55c6aa7b7 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/ollama-v1.1.0] - 2024-10-11 + +### 🚀 Features + +- Add `keep_alive` parameter to Ollama Generators (#1131) + +### ⚙️ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) + ## [integrations/ollama-v1.0.1] - 2024-09-26 ### 🐛 Bug Fixes diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 9502a187e..558fd593e 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk @@ -38,6 +38,7 @@ def __init__( url: str = "http://localhost:11434", generation_kwargs: Optional[Dict[str, Any]] = None, timeout: int = 120, + keep_alive: Optional[Union[float, str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -54,12 +55,21 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param keep_alive: + The option that controls how long the model will stay loaded into memory following the request. + If not set, it will use the default value from the Ollama (5 minutes). + The value can be set to: + - a duration string (such as "10m" or "24h") + - a number in seconds (such as 3600) + - any negative number which will keep the model loaded in memory (e.g. -1 or "-1m") + - '0' which will unload the model immediately after generating a response. """ self.timeout = timeout self.generation_kwargs = generation_kwargs or {} self.url = url self.model = model + self.keep_alive = keep_alive self.streaming_callback = streaming_callback self._client = Client(host=self.url, timeout=self.timeout) @@ -76,6 +86,7 @@ def to_dict(self) -> Dict[str, Any]: self, model=self.model, url=self.url, + keep_alive=self.keep_alive, generation_kwargs=self.generation_kwargs, timeout=self.timeout, streaming_callback=callback_name, @@ -165,7 +176,9 @@ def run( stream = self.streaming_callback is not None messages = [self._message_to_dict(message) for message in messages] - response = self._client.chat(model=self.model, messages=messages, stream=stream, options=generation_kwargs) + response = self._client.chat( + model=self.model, messages=messages, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs + ) if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index d92932c3e..058948e8a 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk @@ -36,6 +36,7 @@ def __init__( template: Optional[str] = None, raw: bool = False, timeout: int = 120, + keep_alive: Optional[Union[float, str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -59,6 +60,14 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param keep_alive: + The option that controls how long the model will stay loaded into memory following the request. + If not set, it will use the default value from the Ollama (5 minutes). + The value can be set to: + - a duration string (such as "10m" or "24h") + - a number in seconds (such as 3600) + - any negative number which will keep the model loaded in memory (e.g. -1 or "-1m") + - '0' which will unload the model immediately after generating a response. """ self.timeout = timeout self.raw = raw @@ -66,6 +75,7 @@ def __init__( self.system_prompt = system_prompt self.model = model self.url = url + self.keep_alive = keep_alive self.generation_kwargs = generation_kwargs or {} self.streaming_callback = streaming_callback @@ -87,6 +97,7 @@ def to_dict(self) -> Dict[str, Any]: system_prompt=self.system_prompt, model=self.model, url=self.url, + keep_alive=self.keep_alive, generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, ) @@ -172,7 +183,9 @@ def run( stream = self.streaming_callback is not None - response = self._client.generate(model=self.model, prompt=prompt, stream=stream, options=generation_kwargs) + response = self._client.generate( + model=self.model, prompt=prompt, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs + ) if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index a46758df3..5ac9289aa 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -26,12 +26,14 @@ def test_init_default(self): assert component.url == "http://localhost:11434" assert component.generation_kwargs == {} assert component.timeout == 120 + assert component.keep_alive is None def test_init(self): component = OllamaChatGenerator( model="llama2", url="http://my-custom-endpoint:11434", generation_kwargs={"temperature": 0.5}, + keep_alive="10m", timeout=5, ) @@ -39,6 +41,7 @@ def test_init(self): assert component.url == "http://my-custom-endpoint:11434" assert component.generation_kwargs == {"temperature": 0.5} assert component.timeout == 5 + assert component.keep_alive == "10m" def test_to_dict(self): component = OllamaChatGenerator( @@ -46,6 +49,7 @@ def test_to_dict(self): streaming_callback=print_streaming_chunk, url="custom_url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + keep_alive="5m", ) data = component.to_dict() assert data == { @@ -53,6 +57,7 @@ def test_to_dict(self): "init_parameters": { "timeout": 120, "model": "llama2", + "keep_alive": "5m", "url": "custom_url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -66,6 +71,7 @@ def test_from_dict(self): "timeout": 120, "model": "llama2", "url": "custom_url", + "keep_alive": "5m", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, @@ -75,6 +81,7 @@ def test_from_dict(self): assert component.streaming_callback is print_streaming_chunk assert component.url == "custom_url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.keep_alive == "5m" def test_build_message_from_ollama_response(self): model = "some_model" diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index c4c6906db..b02370234 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -45,6 +45,7 @@ def test_init_default(self): assert component.template is None assert component.raw is False assert component.timeout == 120 + assert component.keep_alive is None assert component.streaming_callback is None def test_init(self): @@ -57,6 +58,7 @@ def callback(x: StreamingChunk): generation_kwargs={"temperature": 0.5}, system_prompt="You are Luigi from Super Mario Bros.", timeout=5, + keep_alive="10m", streaming_callback=callback, ) assert component.model == "llama2" @@ -66,6 +68,7 @@ def callback(x: StreamingChunk): assert component.template is None assert component.raw is False assert component.timeout == 5 + assert component.keep_alive == "10m" assert component.streaming_callback == callback component = OllamaGenerator() @@ -80,6 +83,7 @@ def callback(x: StreamingChunk): "model": "orca-mini", "url": "http://localhost:11434", "streaming_callback": None, + "keep_alive": None, "generation_kwargs": {}, }, } @@ -89,6 +93,7 @@ def test_to_dict_with_parameters(self): model="llama2", streaming_callback=print_streaming_chunk, url="going_to_51_pegasi_b_for_weekend", + keep_alive="10m", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) data = component.to_dict() @@ -100,6 +105,7 @@ def test_to_dict_with_parameters(self): "template": None, "system_prompt": None, "model": "llama2", + "keep_alive": "10m", "url": "going_to_51_pegasi_b_for_weekend", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -115,6 +121,7 @@ def test_from_dict(self): "template": None, "system_prompt": None, "model": "llama2", + "keep_alive": "5m", "url": "going_to_51_pegasi_b_for_weekend", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -125,6 +132,7 @@ def test_from_dict(self): assert component.streaming_callback is print_streaming_chunk assert component.url == "going_to_51_pegasi_b_for_weekend" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.keep_alive == "5m" @pytest.mark.integration def test_ollama_generator_run_streaming(self):