Skip to content

Commit

Permalink
refactor post_args to json payload only
Browse files Browse the repository at this point in the history
  • Loading branch information
AlistairLR112 committed Jan 3, 2024
1 parent 0764f80 commit 57d0685
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
27 changes: 12 additions & 15 deletions integrations/ollama/src/ollama_haystack/generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

import requests
from haystack import component
Expand Down Expand Up @@ -56,26 +56,23 @@ def _get_telemetry_data(self) -> Dict[str, str]:
"""
return {"model": self.model_name}

def _post_args(self, prompt: str, generation_kwargs=None) -> Dict[str, Union[str, dict]]:
def _json_payload(self, prompt: str, generation_kwargs=None) -> Dict[str, Any]:
"""
Returns A dictionary of arguments for a POST request to an Ollama service
Returns A dictionary of JSON arguments for a POST request to an Ollama service
:param prompt: the prompt to generate a response for
:param generation_kwargs:
:return: A dictionary of arguments for a POST request to an Ollama service
"""
if generation_kwargs is None:
generation_kwargs = {}
return {
"url": self.url,
"json": {
"prompt": prompt,
"model": self.model_name,
"stream": False,
"raw": self.raw,
"template": self.template,
"system": self.system_prompt,
"options": generation_kwargs,
},
"prompt": prompt,
"model": self.model_name,
"stream": False,
"raw": self.raw,
"template": self.template,
"system": self.system_prompt,
"options": generation_kwargs,
}

@component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
Expand All @@ -93,9 +90,9 @@ def run(
"""
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

post_arguments = self._post_args(prompt, generation_kwargs)
json_payload = self._json_payload(prompt, generation_kwargs)

response = requests.post(url=post_arguments["url"], json=post_arguments["json"], timeout=self.timeout)
response = requests.post(url=self.url, json=json_payload, timeout=self.timeout)

# Throw error on unsuccessful response
response.raise_for_status()
Expand Down
4 changes: 2 additions & 2 deletions integrations/ollama/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def test_init_default(self):
),
],
)
def test__post_args(self, configuration, prompt):
def test__json_payload(self, configuration, prompt):
component = OllamaGenerator(**configuration)

observed = component._post_args(prompt=prompt)
observed = component._json_payload(prompt=prompt)

expected = {
"url": "https://localhost:11434/api/generate",
Expand Down

0 comments on commit 57d0685

Please sign in to comment.