From 9a9a0dcec0b3b5c9ed21ff847d3897baa49aca7c Mon Sep 17 00:00:00 2001 From: sachinsachdeva <7625278+sachinsachdeva@users.noreply.github.com> Date: Fri, 9 Feb 2024 12:49:40 +0100 Subject: [PATCH] minor refactoring --- .../components/generators/ollama/generator.py | 8 ++++---- integrations/ollama/tests/test_generator.py | 4 +--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 62f7b11c1..321eab9f3 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -103,7 +103,7 @@ def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None "options": generation_kwargs, } - def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, List[Any]]: + def _convert_to_response(self, ollama_response: Response) -> Dict[str, List[Any]]: """ Convert a response from the Ollama API to the required Haystack format. :param ollama_response: A response (requests library) from the Ollama API. @@ -117,7 +117,7 @@ def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, return {"replies": replies, "meta": [meta]} - def _convert_to_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: + def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: """ Convert a list of chunks response required Haystack format. :param chunks: List of StreamingChunks @@ -184,6 +184,6 @@ def run( if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) - return self._convert_to_response(chunks) + return self._convert_to_streaming_response(chunks) - return self._convert_to_haystack_response(response) + return self._convert_to_response(response) diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 5ef857e8c..478c33e6b 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -165,8 +165,6 @@ def test_create_json_payload(self, configuration: dict[str, Any], stream: bool): assert observed == expected - -class TestOllamaStreamingGenerator: @pytest.mark.integration def test_ollama_generator_run_streaming(self): class Callback: @@ -187,4 +185,4 @@ def __call__(self, chunk): assert "Amsterdam" in results["replies"][0] assert len(results["meta"]) == 1 assert callback.responses == results["replies"][0] - assert callback.count_calls > 1 + assert callback.count_calls > 1 \ No newline at end of file