Skip to content

Commit

Permalink
lint with black
Browse files Browse the repository at this point in the history
  • Loading branch information
AlistairLR112 committed Dec 23, 2023
1 parent 7e85a8c commit 2b25fb1
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
Empty file added integrations/ollama/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions integrations/ollama/src/ollama_haystack/__about__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
__version__ = "0.0.1"
21 changes: 12 additions & 9 deletions integrations/ollama/src/ollama_haystack/generator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from dataclasses import dataclass, asdict
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Any, Dict, List, Union
from typing import Optional, Any, Dict, List
from urllib.parse import urljoin

import requests
from haystack import component


@dataclass
class OllamaResponse:
model: str
Expand All @@ -21,12 +22,13 @@ class OllamaResponse:
eval_duration: int

def __post_init__(self):
self.metadata = {key: value for key, value in self.__dict__ if key != 'response'}
self.metadata = {
key: value for key, value in self.__dict__.items() if key != "response"
}

def as_haystack_generator_response(self) -> Dict[str, List]:
"""Returns replies and metadata in the format required by haystack"""
return {'replies': [self.response],
'metadata': [self.metadata]}
return {"replies": [self.response], "metadata": [self.metadata]}


@component
Expand All @@ -53,10 +55,12 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
return {"model": self.model_name}

def _post_args(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]]):
def _post_args(self, prompt: str, generation_kwargs=None):
if generation_kwargs is None:
generation_kwargs = {}
return {
'url': urljoin(self.url, '/api/generate'),
'json': {
"url": urljoin(self.url, "/api/generate"),
"json": {
"prompt": prompt,
"model": self.model_name,
"stream": False,
Expand All @@ -67,7 +71,6 @@ def _post_args(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]]):

@component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):

generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

response = requests.post(**self._post_args(prompt, generation_kwargs))
Expand Down

0 comments on commit 2b25fb1

Please sign in to comment.