Skip to content

Commit

Permalink
refactored
Browse files Browse the repository at this point in the history
  • Loading branch information
SermetPekin committed Aug 21, 2024
1 parent 3569cdf commit 025e5bd
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 28 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ __pycache__/
*.env
!example.env

test_pri_*.*


poetry.lock
*.xlsx
*.docx
Expand Down
58 changes: 40 additions & 18 deletions evdschat/model/chatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ModelAbstract(ABC):
test = False

def parse(self, prompt) -> dict[str, str]:
# assert len(str(self.api_key)) > 15

return {"prompt": prompt, "model": self.model, "openai_api_key": self.api_key}

def defaultOptions(self) -> str:
Expand Down Expand Up @@ -91,37 +91,54 @@ def __str__(self):
def mock_req(self, prompt) -> dict[str, str]:
return global_mock()

def eval(self, kw: dict, permitted=None) -> Union[Tuple[Any, str], None]:
# print(kw.keys())
nkw = {}
def check_permitted(self, key: str, permitted=None) -> bool:
if not permitted:
permitted = ["start_date", "aggregate", "frequency" "cache"]
return key in permitted

def permitted_dict(self, kw: dict, permitted: None) -> Tuple[dict, str]:
new_dict = {}

for k, v in kw.items():
if k in permitted:
nkw[k] = v
if self.check_permitted(k, permitted):
new_dict[k] = v

notes = ""
if "notes" in kw:
notes = kw["notes"]
del kw["notes"]
return new_dict

def eval_real(self, kw, permitted=None) -> tuple[ResultChat, str]:
"""eval_real"""
notes = kw.get("notes", "")
try:
result = self.retrieve_fnc(kw["index"], **nkw)
res = create_result(result, status=Status.success)
return res, notes
result = self.retrieve_fnc(
kw["index"], **self.permitted_dict(kw, permitted)
)
return create_result(result, status=Status.success), notes
except Exception:
traceback.print_exc()

return create_result(None, status=Status.failed, reason="Eval failed"), str("")
return self.failed_result()

def failed_result(self):
"""failed_result"""
return create_result(None, status=Status.failed, reason="Eval failed")

def eval(self, kw: dict, permitted=None) -> Tuple[Any, str]:
"""eval"""
index = kw.get("index", None)

if not index:
return self.failed_result(), str("")
return self.eval_real(kw, permitted)

def decide_caller(self):
"""decide_caller"""
if self.test:
return self.mock_req
if callable(c_caller_main):
return self.post_c
return self.post

def __call__(self, prompt, **kwargs) -> Union[Tuple[Any, str], None]:
import platform
def __call__(self, prompt, **kwargs) -> Union[Tuple[Any, str], bool]:

if self.debug:
return str(self)
Expand Down Expand Up @@ -159,10 +176,15 @@ def _raise(self, *args):
class OpenAI(ModelAbstract):
"""OpenAI"""

def post_c(self, p) -> dict[str, str]:
resp = c_caller_main(p, get_openai_key(), self.defaultOptions())
def post_c(self, prompt: str , caller = c_caller_main ) -> dict[str, str]:
resp = caller(prompt, get_openai_key(), self.defaultOptions())
result_dict = json.loads(resp)
r = result_dict["result"]
res = json.loads(r)
res["cache"] = False
return res


@dataclass
class TestAI(OpenAI):
"""TestAI"""
20 changes: 10 additions & 10 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@


def test_chat(capsys):
with capsys.disabled() :
prompt = """
prompt = """
Can I get reserves data please ? Aylık frekans istiyorum. ortalama olarak toplulaştırır mısın?
"""
with capsys.disabled():
res, notes = chat(prompt, debug=False, force=False )
print(res)
assert isinstance(res.data, pd.DataFrame)
Can I get reserves data please ? Aylık frekans istiyorum. ortalama olarak toplulaştırır mısın?
"""
with capsys.disabled():
res, notes = chat(prompt, debug=False, force=False )
print(res)
assert isinstance(res.data, pd.DataFrame)
87 changes: 87 additions & 0 deletions tests/test_req.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import json
import pytest
import ctypes
import os
import platform
from pathlib import Path
from importlib import resources
from typing import Union
import platform

from evdschat.model.chatters import get_myapi_url, get_openai_key
from evdschat.model.chatters import TestAI, get_myapi_url


def get_exec_file(test=False) -> Path:

executable_name = "libpost_request.so"
if platform.system() == "Windows":
executable_name = "libpost_request.dll"
return Path(".") / executable_name


def get_chatter():
return TestAI()


@pytest.mark.skipif(not get_exec_file().exists(), reason="only tests locally")
def test_post():
t = get_chatter()
resp = t.post(prompt="test")
# print(resp)
assert resp


@pytest.mark.skipif(not get_exec_file().exists(), reason="requires C executable")
def test_post_c():
# t = get_chatter()
caller = get_c_fnc()
resp = caller(prompt="Loan data", api_key=get_openai_key(), url=get_myapi_url())
result_dict = json.loads(resp)
r = result_dict["result"]
res = json.loads(r)
res["cache"] = False
assert res["index"]


def get_c_fnc():

class PostParams(ctypes.Structure):
_fields_ = [
("url", ctypes.c_char_p),
("prompt", ctypes.c_char_p),
("api_key", ctypes.c_char_p),
("proxy_url", ctypes.c_char_p),
]

lib_path = get_exec_file() # check_c_executable()
if lib_path.exists():
lib = ctypes.CDLL(lib_path)

lib.post_request.argtypes = [ctypes.POINTER(PostParams)]
lib.post_request.restype = ctypes.c_char_p

lib.free_memory.argtypes = [ctypes.c_void_p]
lib.free_memory.restype = None

def c_caller(params):
response = lib.post_request(ctypes.byref(params))
result = ctypes.string_at(response).decode("utf-8")
return result

def c_caller_main(prompt, api_key, url, proxy=None):
prompt = prompt.replace("\n", " ")

params = PostParams(
url=url.encode("utf-8"),
prompt=prompt.encode("utf-8"),
api_key=api_key.encode("utf-8"),
proxy_url=proxy.encode("utf-8") if proxy else None,
)

return c_caller(params)

return c_caller_main


# test_post_c()

0 comments on commit 025e5bd

Please sign in to comment.