Skip to content

Commit

Permalink
add tests for init method
Browse files Browse the repository at this point in the history
  • Loading branch information
AlistairLR112 committed Jan 6, 2024
1 parent 657bfc3 commit 9fab3e2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
3 changes: 1 addition & 2 deletions integrations/ollama/src/ollama_haystack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from ollama_haystack.generator import OllamaGenerator
from ollama_haystack.chat.chat_generator import OllamaChatGenerator

from ollama_haystack.generator import OllamaGenerator

__all__ = ["OllamaGenerator", "OllamaChatGenerator"]
25 changes: 25 additions & 0 deletions integrations/ollama/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,29 @@ def list_of_chat_messages(user_chat_message, assistant_chat_message):


class TestOllamaChatGenerator:
def test_init_default(self):
component = OllamaChatGenerator()
assert component.model == "orca-mini"
assert component.url == "http://localhost:11434/api/chat"
assert component.generation_kwargs == {}
assert component.template is None
assert component.timeout == 30
assert component.streaming_callback is None

def test_init(self):
component = OllamaChatGenerator(
model="llama2",
url="http://my-custom-endpoint:11434/api/chat",
generation_kwargs={"temperature": 0.5},
timeout=5,
)

assert component.model == "llama2"
assert component.url == "http://my-custom-endpoint:11434/api/chat"
assert component.generation_kwargs == {"temperature": 0.5}
assert component.template is None
assert component.timeout == 5

def test_user_message_to_dict(self, user_chat_message):
observed = OllamaChatGenerator()._message_to_dict(user_chat_message)
expected = {"role": "user", "content": "Tell me about why Super Mario is the greatest superhero"}
Expand Down Expand Up @@ -101,4 +124,6 @@ def test_run(self):

response = chat_generator.run([message])

assert isinstance(response, dict)
assert isinstance(response["replies"], list)
assert answer in response["replies"][0].content

0 comments on commit 9fab3e2

Please sign in to comment.