From 9fab3e282788560676ee9a47758d5f01a5334a27 Mon Sep 17 00:00:00 2001 From: Alistair Rogers Date: Sat, 6 Jan 2024 16:22:10 +0000 Subject: [PATCH] add tests for init method --- .../ollama/src/ollama_haystack/__init__.py | 3 +-- .../ollama/tests/test_chat_generator.py | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/integrations/ollama/src/ollama_haystack/__init__.py b/integrations/ollama/src/ollama_haystack/__init__.py index 2405cf10e..10bc38121 100644 --- a/integrations/ollama/src/ollama_haystack/__init__.py +++ b/integrations/ollama/src/ollama_haystack/__init__.py @@ -2,8 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from ollama_haystack.generator import OllamaGenerator from ollama_haystack.chat.chat_generator import OllamaChatGenerator - +from ollama_haystack.generator import OllamaGenerator __all__ = ["OllamaGenerator", "OllamaChatGenerator"] diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index aa3751495..4827933cd 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -26,6 +26,29 @@ def list_of_chat_messages(user_chat_message, assistant_chat_message): class TestOllamaChatGenerator: + def test_init_default(self): + component = OllamaChatGenerator() + assert component.model == "orca-mini" + assert component.url == "http://localhost:11434/api/chat" + assert component.generation_kwargs == {} + assert component.template is None + assert component.timeout == 30 + assert component.streaming_callback is None + + def test_init(self): + component = OllamaChatGenerator( + model="llama2", + url="http://my-custom-endpoint:11434/api/chat", + generation_kwargs={"temperature": 0.5}, + timeout=5, + ) + + assert component.model == "llama2" + assert component.url == "http://my-custom-endpoint:11434/api/chat" + assert component.generation_kwargs == {"temperature": 0.5} + assert component.template is None + assert component.timeout == 5 + def test_user_message_to_dict(self, user_chat_message): observed = OllamaChatGenerator()._message_to_dict(user_chat_message) expected = {"role": "user", "content": "Tell me about why Super Mario is the greatest superhero"} @@ -101,4 +124,6 @@ def test_run(self): response = chat_generator.run([message]) + assert isinstance(response, dict) + assert isinstance(response["replies"], list) assert answer in response["replies"][0].content