From 98c92bfb2d5019a04cac609c492cb12f38d5037e Mon Sep 17 00:00:00 2001 From: sachinsachdeva <7625278+sachinsachdeva@users.noreply.github.com> Date: Sat, 27 Jan 2024 20:35:20 +0100 Subject: [PATCH 1/7] added support for streaming ollama generator --- .../components/generators/ollama/generator.py | 62 +++++++++++++++++-- integrations/ollama/tests/test_generator.py | 44 +++++++++++-- 2 files changed, 95 insertions(+), 11 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 55bd65d8a..99de73ff6 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -1,8 +1,10 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import requests from haystack import component +from haystack.dataclasses import StreamingChunk from requests import Response +import json @component @@ -21,6 +23,7 @@ def __init__( template: Optional[str] = None, raw: bool = False, timeout: int = 120, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ :param model: The name of the model to use. The model should be available in the running Ollama instance. @@ -36,6 +39,8 @@ def __init__( if you are specifying a full templated prompt in your API request. :param timeout: The number of seconds before throwing a timeout error from the Ollama API. Default is 120 seconds. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. """ self.timeout = timeout self.raw = raw @@ -44,8 +49,9 @@ def __init__( self.model = model self.url = url self.generation_kwargs = generation_kwargs or {} + self.streaming_callback = streaming_callback - def _create_json_payload(self, prompt: str, generation_kwargs=None) -> Dict[str, Any]: + def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]: """ Returns a dictionary of JSON arguments for a POST request to an Ollama service. :param prompt: The prompt to generate a response for. @@ -58,7 +64,7 @@ def _create_json_payload(self, prompt: str, generation_kwargs=None) -> Dict[str, return { "prompt": prompt, "model": self.model, - "stream": False, + "stream": stream, "raw": self.raw, "template": self.template, "system": self.system_prompt, @@ -71,12 +77,52 @@ def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, :param ollama_response: A response (requests library) from the Ollama API. :return: A dictionary of the returned responses and metadata. """ + resp_dict = ollama_response.json() replies = [resp_dict["response"]] meta = {key: value for key, value in resp_dict.items() if key != "response"} return {"replies": replies, "meta": [meta]} + + def _convert_to_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: + """ + Convert a list of chunks response required Haystack format. + :param chunks: List of StreamingChunks + :return: A dictionary of the returned responses and metadata. 
+ """ + + replies = ["".join([c.content for c in chunks])] + meta = {key: value for key, value in chunks[0].meta.items() if key != "response"} + + return {"replies": replies, "meta": [meta]} + + def _handle_streaming_response(self, response) -> List[StreamingChunk]: + """ Handles Streaming response case + + :param response: streaming response from ollama api. + :return: The List[StreamingChunk]. + """ + chunks: List[StreamingChunk] = [] + for chunk in response.iter_lines(): + chunk_delta: StreamingChunk = self._build_chunk(chunk) + chunks.append(chunk_delta) + self.streaming_callback(chunk_delta) + return chunks + + def _build_chunk(self, chunk_response: Any) -> StreamingChunk: + """ + Converts the response from the Ollama API to a StreamingChunk. + :param chunk: The chunk returned by the Ollama API. + :return: The StreamingChunk. + """ + decoded_chunk = json.loads(chunk_response.decode("utf-8")) + + content = decoded_chunk['response'] + meta = {key: value for key, value in decoded_chunk.items() if key != "response"} + + chunk_message = StreamingChunk(content, meta) + return chunk_message @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) def run( @@ -94,11 +140,17 @@ def run( """ generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - json_payload = self._create_json_payload(prompt, generation_kwargs) + stream = self.streaming_callback is not None + + json_payload = self._create_json_payload(prompt, stream, generation_kwargs) - response = requests.post(url=self.url, json=json_payload, timeout=self.timeout) + response = requests.post(url=self.url, json=json_payload, timeout=self.timeout, stream=stream) # throw error on unsuccessful response response.raise_for_status() + if stream: + chunks: List[StreamingChunk] = self._handle_streaming_response(response) + return self._convert_to_response(chunks) + return self._convert_to_haystack_response(response) diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 0403516ce..9a964f359 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -5,6 +5,8 @@ import pytest from haystack_integrations.components.generators.ollama import OllamaGenerator from requests import HTTPError +from haystack.dataclasses import StreamingChunk +import requests_mock class TestOllamaGenerator: @@ -42,16 +44,20 @@ def test_init_default(self): assert component.template is None assert component.raw is False assert component.timeout == 120 + assert component.streaming_callback is None def test_init(self): + def callback(x: StreamingChunk): + pass + component = OllamaGenerator( model="llama2", url="http://my-custom-endpoint:11434/api/generate", generation_kwargs={"temperature": 0.5}, system_prompt="You are Luigi from Super Mario Bros.", timeout=5, + streaming_callback = callback ) - assert component.model == "llama2" assert component.url == "http://my-custom-endpoint:11434/api/generate" assert component.generation_kwargs == {"temperature": 0.5} @@ -59,6 +65,7 @@ def test_init(self): assert component.template is None assert component.raw is False assert component.timeout == 5 + assert component.streaming_callback == callback @pytest.mark.parametrize( "configuration", @@ -68,27 +75,28 @@ def test_init(self): "url": "https://localhost:11434/api/generate", "raw": True, "system_prompt": "You are mario from Super Mario Bros.", - "template": None, + "template": None }, { "model": "some_model2", "url": 
"https://localhost:11434/api/generate", "raw": False, "system_prompt": None, - "template": "some template", + "template": "some template" }, ], ) - def test_create_json_payload(self, configuration): + @pytest.mark.parametrize("stream", [True, False]) + def test_create_json_payload(self, configuration, stream): prompt = "hello" component = OllamaGenerator(**configuration) - observed = component._create_json_payload(prompt=prompt) + observed = component._create_json_payload(prompt=prompt, stream=stream) expected = { "prompt": prompt, "model": configuration["model"], - "stream": False, + "stream": stream, "system": configuration["system_prompt"], "raw": configuration["raw"], "template": configuration["template"], @@ -96,3 +104,27 @@ def test_create_json_payload(self, configuration): } assert observed == expected + +class TestOllamaStreamingGenerator: + @pytest.mark.integration + def test_ollama_generator_run_streaming(self): + class Callback: + def __init__(self): + self.responses = "" + self.count_calls = 0 + + def __call__(self, chunk): + self.responses += chunk.content + self.count_calls += 1 + print(f'{self.count_calls} = {chunk}') + return chunk + + callback = Callback() + component = OllamaGenerator(streaming_callback=callback) + results = component.run(prompt="What's the capital of Netherlands?") + + assert len(results["replies"]) == 1 + assert "Amsterdam" in results["replies"][0] + assert len(results["meta"]) == 1 + assert callback.responses == results["replies"][0] + assert callback.count_calls > 1 \ No newline at end of file From 642e7f23e9d0f71ef75f1102ebcabd61bfe28f1f Mon Sep 17 00:00:00 2001 From: sachinsachdeva <7625278+sachinsachdeva@users.noreply.github.com> Date: Sat, 27 Jan 2024 20:46:11 +0100 Subject: [PATCH 2/7] fixed linting errors --- .../components/generators/ollama/generator.py | 10 +++++----- integrations/ollama/tests/test_generator.py | 8 +++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 99de73ff6..b603a9b14 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -1,10 +1,10 @@ +import json from typing import Any, Callable, Dict, List, Optional import requests from haystack import component from haystack.dataclasses import StreamingChunk from requests import Response -import json @component @@ -77,21 +77,21 @@ def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, :param ollama_response: A response (requests library) from the Ollama API. :return: A dictionary of the returned responses and metadata. """ - + resp_dict = ollama_response.json() replies = [resp_dict["response"]] meta = {key: value for key, value in resp_dict.items() if key != "response"} return {"replies": replies, "meta": [meta]} - + def _convert_to_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: """ Convert a list of chunks response required Haystack format. :param chunks: List of StreamingChunks :return: A dictionary of the returned responses and metadata. 
""" - + replies = ["".join([c.content for c in chunks])] meta = {key: value for key, value in chunks[0].meta.items() if key != "response"} @@ -118,7 +118,7 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk: """ decoded_chunk = json.loads(chunk_response.decode("utf-8")) - content = decoded_chunk['response'] + content = decoded_chunk["response"] meta = {key: value for key, value in decoded_chunk.items() if key != "response"} chunk_message = StreamingChunk(content, meta) diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 9a964f359..55173276c 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -3,10 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.ollama import OllamaGenerator from requests import HTTPError -from haystack.dataclasses import StreamingChunk -import requests_mock class TestOllamaGenerator: @@ -48,7 +47,7 @@ def test_init_default(self): def test_init(self): def callback(x: StreamingChunk): - pass + x.content = "" component = OllamaGenerator( model="llama2", @@ -116,7 +115,6 @@ def __init__(self): def __call__(self, chunk): self.responses += chunk.content self.count_calls += 1 - print(f'{self.count_calls} = {chunk}') return chunk callback = Callback() @@ -127,4 +125,4 @@ def __call__(self, chunk): assert "Amsterdam" in results["replies"][0] assert len(results["meta"]) == 1 assert callback.responses == results["replies"][0] - assert callback.count_calls > 1 \ No newline at end of file + assert callback.count_calls > 1 From 81f0b90947dea050d76b2fcb38e7bbbf901eb941 Mon Sep 17 00:00:00 2001 From: sachinsachdeva <7625278+sachinsachdeva@users.noreply.github.com> Date: Sat, 27 Jan 2024 21:58:49 +0100 Subject: [PATCH 3/7] more linting issues --- .../components/generators/ollama/generator.py | 5 +++-- integrations/ollama/tests/test_generator.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index b603a9b14..f9906c3fb 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -98,7 +98,7 @@ def _convert_to_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[A return {"replies": replies, "meta": [meta]} def _handle_streaming_response(self, response) -> List[StreamingChunk]: - """ Handles Streaming response case + """Handles Streaming response case :param response: streaming response from ollama api. :return: The List[StreamingChunk]. 
@@ -107,7 +107,8 @@ def _handle_streaming_response(self, response) -> List[StreamingChunk]: for chunk in response.iter_lines(): chunk_delta: StreamingChunk = self._build_chunk(chunk) chunks.append(chunk_delta) - self.streaming_callback(chunk_delta) + if self.streaming_callback is not None: + self.streaming_callback(chunk_delta) return chunks def _build_chunk(self, chunk_response: Any) -> StreamingChunk: diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 55173276c..59d778333 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -55,7 +55,7 @@ def callback(x: StreamingChunk): generation_kwargs={"temperature": 0.5}, system_prompt="You are Luigi from Super Mario Bros.", timeout=5, - streaming_callback = callback + streaming_callback=callback, ) assert component.model == "llama2" assert component.url == "http://my-custom-endpoint:11434/api/generate" @@ -74,14 +74,14 @@ def callback(x: StreamingChunk): "url": "https://localhost:11434/api/generate", "raw": True, "system_prompt": "You are mario from Super Mario Bros.", - "template": None + "template": None, }, { "model": "some_model2", "url": "https://localhost:11434/api/generate", "raw": False, "system_prompt": None, - "template": "some template" + "template": "some template", }, ], ) @@ -104,6 +104,7 @@ def test_create_json_payload(self, configuration, stream): assert observed == expected + class TestOllamaStreamingGenerator: @pytest.mark.integration def test_ollama_generator_run_streaming(self): From a6a50e5a12e3129531b4b7ff2a0d981ebbf6fda2 Mon Sep 17 00:00:00 2001 From: sachinsachdeva <7625278+sachinsachdeva@users.noreply.github.com> Date: Tue, 6 Feb 2024 18:30:13 +0100 Subject: [PATCH 4/7] implemented serializing/deserializing for OllamaGenerator --- .../components/generators/ollama/generator.py | 34 +++++++++- integrations/ollama/tests/test_generator.py | 63 ++++++++++++++++++- 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index f9906c3fb..2582cf9c8 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -2,9 +2,10 @@ from typing import Any, Callable, Dict, List, Optional import requests -from haystack import component +from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk from requests import Response +from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler @component @@ -51,6 +52,37 @@ def __init__( self.generation_kwargs = generation_kwargs or {} self.streaming_callback = streaming_callback + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + :return: The serialized component as a dictionary. 
+ """ + callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + timeout=self.timeout, + raw=self.raw, + template=self.template, + system_prompt=self.system_prompt, + model=self.model, + url=self.url, + generation_kwargs=self.generation_kwargs, + streaming_callback=callback_name, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + return default_from_dict(cls, data) + def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]: """ Returns a dictionary of JSON arguments for a POST request to an Ollama service. diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 59d778333..4bae7eb67 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -2,9 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Any +from _pytest.monkeypatch import MonkeyPatch import pytest from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.ollama import OllamaGenerator +from haystack.components.generators.utils import print_streaming_chunk from requests import HTTPError @@ -66,6 +69,64 @@ def callback(x: StreamingChunk): assert component.timeout == 5 assert component.streaming_callback == callback + component = OllamaGenerator() + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator", + "init_parameters": { + "timeout": 120, + "raw": False, + "template": None, + "system_prompt": None, + "model": "orca-mini", + "url": "http://localhost:11434/api/generate", + "streaming_callback": None, + "generation_kwargs": {}, + }, + } + + def test_to_dict_with_parameters(self): + component = OllamaGenerator( + model="llama2", + streaming_callback=print_streaming_chunk, + url="going_to_51_pegasi_b_for_weekend", + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator", + "init_parameters": { + "timeout": 120, + "raw": False, + "template": None, + "system_prompt": None, + "model": "llama2", + "url": "going_to_51_pegasi_b_for_weekend", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + + def test_from_dict(self): + data = { + "type": "haystack_integrations.components.generators.ollama.generator.OllamaGenerator", + "init_parameters": { + "timeout": 120, + "raw": False, + "template": None, + "system_prompt": None, + "model": "llama2", + "url": "going_to_51_pegasi_b_for_weekend", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + component = OllamaGenerator.from_dict(data) + assert component.model == "llama2" + assert 
component.streaming_callback is print_streaming_chunk + assert component.url == "going_to_51_pegasi_b_for_weekend" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + @pytest.mark.parametrize( "configuration", [ @@ -86,7 +147,7 @@ def callback(x: StreamingChunk): ], ) @pytest.mark.parametrize("stream", [True, False]) - def test_create_json_payload(self, configuration, stream): + def test_create_json_payload(self, configuration: dict[str, Any], stream: bool): prompt = "hello" component = OllamaGenerator(**configuration) From 2d7aa9e93256c30eff7cb2d526d4a04e7cdc7067 Mon Sep 17 00:00:00 2001 From: sachinsachdeva <7625278+sachinsachdeva@users.noreply.github.com> Date: Tue, 6 Feb 2024 18:37:06 +0100 Subject: [PATCH 5/7] linting changes --- .../components/generators/ollama/generator.py | 4 ++-- integrations/ollama/tests/test_generator.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 2582cf9c8..62f7b11c1 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -3,9 +3,9 @@ import requests from haystack import component, default_from_dict, default_to_dict +from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler from haystack.dataclasses import StreamingChunk from requests import Response -from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler @component @@ -82,7 +82,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator": if serialized_callback_handler: data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) return default_from_dict(cls, data) - + def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]: """ Returns a dictionary of JSON arguments for a POST request to an Ollama service. 
diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 4bae7eb67..5ef857e8c 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -3,11 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any -from _pytest.monkeypatch import MonkeyPatch + import pytest +from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.ollama import OllamaGenerator -from haystack.components.generators.utils import print_streaming_chunk from requests import HTTPError From 9a9a0dcec0b3b5c9ed21ff847d3897baa49aca7c Mon Sep 17 00:00:00 2001 From: sachinsachdeva <7625278+sachinsachdeva@users.noreply.github.com> Date: Fri, 9 Feb 2024 12:49:40 +0100 Subject: [PATCH 6/7] minor refactoring --- .../components/generators/ollama/generator.py | 8 ++++---- integrations/ollama/tests/test_generator.py | 4 +--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 62f7b11c1..321eab9f3 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -103,7 +103,7 @@ def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None "options": generation_kwargs, } - def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, List[Any]]: + def _convert_to_response(self, ollama_response: Response) -> Dict[str, List[Any]]: """ Convert a response from the Ollama API to the required Haystack format. :param ollama_response: A response (requests library) from the Ollama API. @@ -117,7 +117,7 @@ def _convert_to_haystack_response(self, ollama_response: Response) -> Dict[str, return {"replies": replies, "meta": [meta]} - def _convert_to_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: + def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: """ Convert a list of chunks response required Haystack format. 
:param chunks: List of StreamingChunks @@ -184,6 +184,6 @@ def run( if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) - return self._convert_to_response(chunks) + return self._convert_to_streaming_response(chunks) - return self._convert_to_haystack_response(response) + return self._convert_to_response(response) diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 5ef857e8c..478c33e6b 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -165,8 +165,6 @@ def test_create_json_payload(self, configuration: dict[str, Any], stream: bool): assert observed == expected - -class TestOllamaStreamingGenerator: @pytest.mark.integration def test_ollama_generator_run_streaming(self): class Callback: @@ -187,4 +185,4 @@ def __call__(self, chunk): assert "Amsterdam" in results["replies"][0] assert len(results["meta"]) == 1 assert callback.responses == results["replies"][0] - assert callback.count_calls > 1 + assert callback.count_calls > 1 \ No newline at end of file From c763f17df2b1a7aad566af4f187638613f579dea Mon Sep 17 00:00:00 2001 From: sachinsachdeva <7625278+sachinsachdeva@users.noreply.github.com> Date: Fri, 9 Feb 2024 13:01:26 +0100 Subject: [PATCH 7/7] formating --- integrations/ollama/tests/test_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 478c33e6b..4af2bdb82 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -185,4 +185,4 @@ def __call__(self, chunk): assert "Amsterdam" in results["replies"][0] assert len(results["meta"]) == 1 assert callback.responses == results["replies"][0] - assert callback.count_calls > 1 \ No newline at end of file + assert callback.count_calls > 1
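
A minimal usage sketch of the streaming behaviour this series adds, assuming a local Ollama server reachable at the default http://localhost:11434/api/generate endpoint with the default "orca-mini" model already pulled; the print_chunk callback below is illustrative, not part of the patches.

    from haystack.dataclasses import StreamingChunk
    from haystack_integrations.components.generators.ollama import OllamaGenerator

    def print_chunk(chunk: StreamingChunk) -> None:
        # Each chunk carries one decoded line of the Ollama streaming response:
        # the generated text in chunk.content, every other field in chunk.meta.
        print(chunk.content, end="", flush=True)

    # Passing streaming_callback switches the request payload to "stream": True;
    # leaving it as None keeps the original non-streaming behaviour.
    generator = OllamaGenerator(streaming_callback=print_chunk)
    result = generator.run(prompt="What's the capital of the Netherlands?")

    print()
    print(result["replies"][0])  # full reply, joined from the streamed chunks
    print(result["meta"][0])     # metadata taken from the first chunk, minus "response"

The run() output keeps the same shape as the non-streaming path, so downstream components do not need to know whether a callback was set.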
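
Patch 4 also makes the generator serializable. A sketch of the round-trip, assuming the callback is a plain importable function such as Haystack's print_streaming_chunk (callables without an importable dotted path, e.g. lambdas, generally cannot be restored this way):

    from haystack.components.generators.utils import print_streaming_chunk
    from haystack_integrations.components.generators.ollama import OllamaGenerator

    generator = OllamaGenerator(model="llama2", streaming_callback=print_streaming_chunk)

    data = generator.to_dict()
    # The callback is stored by its dotted import path:
    # "haystack.components.generators.utils.print_streaming_chunk"
    print(data["init_parameters"]["streaming_callback"])

    restored = OllamaGenerator.from_dict(data)
    assert restored.streaming_callback is print_streaming_chunk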