From 57d068500a906d0ab01f6e73941f81294ddf2071 Mon Sep 17 00:00:00 2001 From: Alistair Rogers Date: Wed, 3 Jan 2024 15:14:55 +0000 Subject: [PATCH] refactor post_args to json payload only --- .../ollama/src/ollama_haystack/generator.py | 27 +++++++++---------- integrations/ollama/tests/test_generator.py | 4 +-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/integrations/ollama/src/ollama_haystack/generator.py b/integrations/ollama/src/ollama_haystack/generator.py index a7f3c5a97..2f56e2971 100644 --- a/integrations/ollama/src/ollama_haystack/generator.py +++ b/integrations/ollama/src/ollama_haystack/generator.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import requests from haystack import component @@ -56,9 +56,9 @@ def _get_telemetry_data(self) -> Dict[str, str]: """ return {"model": self.model_name} - def _post_args(self, prompt: str, generation_kwargs=None) -> Dict[str, Union[str, dict]]: + def _json_payload(self, prompt: str, generation_kwargs=None) -> Dict[str, Any]: """ - Returns A dictionary of arguments for a POST request to an Ollama service + Returns A dictionary of JSON arguments for a POST request to an Ollama service :param prompt: the prompt to generate a response for :param generation_kwargs: :return: A dictionary of arguments for a POST request to an Ollama service @@ -66,16 +66,13 @@ def _post_args(self, prompt: str, generation_kwargs=None) -> Dict[str, Union[str if generation_kwargs is None: generation_kwargs = {} return { - "url": self.url, - "json": { - "prompt": prompt, - "model": self.model_name, - "stream": False, - "raw": self.raw, - "template": self.template, - "system": self.system_prompt, - "options": generation_kwargs, - }, + "prompt": prompt, + "model": self.model_name, + "stream": False, + "raw": self.raw, + "template": self.template, + "system": self.system_prompt, + "options": generation_kwargs, } @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) @@ -93,9 +90,9 @@ def run( """ generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - post_arguments = self._post_args(prompt, generation_kwargs) + json_payload = self._json_payload(prompt, generation_kwargs) - response = requests.post(url=post_arguments["url"], json=post_arguments["json"], timeout=self.timeout) + response = requests.post(url=self.url, json=json_payload, timeout=self.timeout) # Throw error on unsuccessful response response.raise_for_status() diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index fa15b8360..46e880d93 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -69,10 +69,10 @@ def test_init_default(self): ), ], ) - def test__post_args(self, configuration, prompt): + def test__json_payload(self, configuration, prompt): component = OllamaGenerator(**configuration) - observed = component._post_args(prompt=prompt) + observed = component._json_payload(prompt=prompt) expected = { "url": "https://localhost:11434/api/generate",