From 2b25fb1b6c23b9cf5e0ed8b0e5dff6fed194a254 Mon Sep 17 00:00:00 2001 From: Alistair Rogers Date: Sat, 23 Dec 2023 17:55:04 +0000 Subject: [PATCH] lint with black --- integrations/ollama/__init__.py | 0 .../ollama/src/ollama_haystack/__about__.py | 4 ++++ .../ollama/src/ollama_haystack/generator.py | 21 +++++++++++-------- 3 files changed, 16 insertions(+), 9 deletions(-) create mode 100644 integrations/ollama/__init__.py create mode 100644 integrations/ollama/src/ollama_haystack/__about__.py diff --git a/integrations/ollama/__init__.py b/integrations/ollama/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/integrations/ollama/src/ollama_haystack/__about__.py b/integrations/ollama/src/ollama_haystack/__about__.py new file mode 100644 index 000000000..0e4fa27cf --- /dev/null +++ b/integrations/ollama/src/ollama_haystack/__about__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +__version__ = "0.0.1" diff --git a/integrations/ollama/src/ollama_haystack/generator.py b/integrations/ollama/src/ollama_haystack/generator.py index 0f6db4564..f748118e6 100644 --- a/integrations/ollama/src/ollama_haystack/generator.py +++ b/integrations/ollama/src/ollama_haystack/generator.py @@ -1,11 +1,12 @@ -from dataclasses import dataclass, asdict +from dataclasses import dataclass from datetime import datetime -from typing import Optional, Any, Dict, List, Union +from typing import Optional, Any, Dict, List from urllib.parse import urljoin import requests from haystack import component + @dataclass class OllamaResponse: model: str @@ -21,12 +22,13 @@ class OllamaResponse: eval_duration: int def __post_init__(self): - self.metadata = {key: value for key, value in self.__dict__ if key != 'response'} + self.metadata = { + key: value for key, value in self.__dict__.items() if key != "response" + } def as_haystack_generator_response(self) -> Dict[str, List]: """Returns replies and metadata in the format required by haystack""" - return {'replies': [self.response], - 'metadata': [self.metadata]} + return {"replies": [self.response], "metadata": [self.metadata]} @component @@ -53,10 +55,12 @@ def _get_telemetry_data(self) -> Dict[str, Any]: """ return {"model": self.model_name} - def _post_args(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]]): + def _post_args(self, prompt: str, generation_kwargs=None): + if generation_kwargs is None: + generation_kwargs = {} return { - 'url': urljoin(self.url, '/api/generate'), - 'json': { + "url": urljoin(self.url, "/api/generate"), + "json": { "prompt": prompt, "model": self.model_name, "stream": False, @@ -67,7 +71,6 @@ def _post_args(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]]): @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): - generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} response = requests.post(**self._post_args(prompt, generation_kwargs))