Skip to content

Commit

Permalink
minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
sachinsachdeva committed Feb 9, 2024
1 parent 379b12b commit 9a9a0dc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None
"options": generation_kwargs,
}

def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, List[Any]]:
def _convert_to_response(self, ollama_response: Response) -> Dict[str, List[Any]]:
"""
Convert a response from the Ollama API to the required Haystack format.
:param ollama_response: A response (requests library) from the Ollama API.
Expand All @@ -117,7 +117,7 @@ def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str,

return {"replies": replies, "meta": [meta]}

def _convert_to_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
"""
Convert a list of chunks response required Haystack format.
:param chunks: List of StreamingChunks
Expand Down Expand Up @@ -184,6 +184,6 @@ def run(

if stream:
chunks: List[StreamingChunk] = self._handle_streaming_response(response)
return self._convert_to_response(chunks)
return self._convert_to_streaming_response(chunks)

return self._convert_to_haystack_response(response)
return self._convert_to_response(response)
4 changes: 1 addition & 3 deletions integrations/ollama/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,6 @@ def test_create_json_payload(self, configuration: dict[str, Any], stream: bool):

assert observed == expected


class TestOllamaStreamingGenerator:
@pytest.mark.integration
def test_ollama_generator_run_streaming(self):
class Callback:
Expand All @@ -187,4 +185,4 @@ def __call__(self, chunk):
assert "Amsterdam" in results["replies"][0]
assert len(results["meta"]) == 1
assert callback.responses == results["replies"][0]
assert callback.count_calls > 1
assert callback.count_calls > 1

0 comments on commit 9a9a0dc

Please sign in to comment.