From 025e5bd45d01da91058d50c063bece2acaff4960 Mon Sep 17 00:00:00 2001 From: Sermet Pekin Date: Wed, 21 Aug 2024 23:28:38 +0300 Subject: [PATCH] refactored --- .gitignore | 3 ++ evdschat/model/chatters.py | 58 +++++++++++++++++-------- tests/test_chat.py | 20 ++++----- tests/test_req.py | 87 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 28 deletions(-) create mode 100644 tests/test_req.py diff --git a/.gitignore b/.gitignore index e17a96e..2d5b7e6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,9 @@ __pycache__/ *.env !example.env +test_pri_*.* + + poetry.lock *.xlsx *.docx diff --git a/evdschat/model/chatters.py b/evdschat/model/chatters.py index 9bb97d8..50a1834 100644 --- a/evdschat/model/chatters.py +++ b/evdschat/model/chatters.py @@ -54,7 +54,7 @@ class ModelAbstract(ABC): test = False def parse(self, prompt) -> dict[str, str]: - # assert len(str(self.api_key)) > 15 + return {"prompt": prompt, "model": self.model, "openai_api_key": self.api_key} def defaultOptions(self) -> str: @@ -91,37 +91,54 @@ def __str__(self): def mock_req(self, prompt) -> dict[str, str]: return global_mock() - def eval(self, kw: dict, permitted=None) -> Union[Tuple[Any, str], None]: - # print(kw.keys()) - nkw = {} + def check_permitted(self, key: str, permitted=None) -> bool: if not permitted: permitted = ["start_date", "aggregate", "frequency" "cache"] + return key in permitted + + def permitted_dict(self, kw: dict, permitted: None) -> Tuple[dict, str]: + new_dict = {} + for k, v in kw.items(): - if k in permitted: - nkw[k] = v + if self.check_permitted(k, permitted): + new_dict[k] = v - notes = "" - if "notes" in kw: - notes = kw["notes"] - del kw["notes"] + return new_dict + + def eval_real(self, kw, permitted=None) -> tuple[ResultChat, str]: + """eval_real""" + notes = kw.get("notes", "") try: - result = self.retrieve_fnc(kw["index"], **nkw) - res = create_result(result, status=Status.success) - return res, notes + result = self.retrieve_fnc( + kw["index"], **self.permitted_dict(kw, permitted) + ) + return create_result(result, status=Status.success), notes except Exception: traceback.print_exc() - return create_result(None, status=Status.failed, reason="Eval failed"), str("") + return self.failed_result() + + def failed_result(self): + """failed_result""" + return create_result(None, status=Status.failed, reason="Eval failed") + + def eval(self, kw: dict, permitted=None) -> Tuple[Any, str]: + """eval""" + index = kw.get("index", None) + + if not index: + return self.failed_result(), str("") + return self.eval_real(kw, permitted) def decide_caller(self): + """decide_caller""" if self.test: return self.mock_req if callable(c_caller_main): return self.post_c return self.post - def __call__(self, prompt, **kwargs) -> Union[Tuple[Any, str], None]: - import platform + def __call__(self, prompt, **kwargs) -> Union[Tuple[Any, str], bool]: if self.debug: return str(self) @@ -159,10 +176,15 @@ def _raise(self, *args): class OpenAI(ModelAbstract): """OpenAI""" - def post_c(self, p) -> dict[str, str]: - resp = c_caller_main(p, get_openai_key(), self.defaultOptions()) + def post_c(self, prompt: str , caller = c_caller_main ) -> dict[str, str]: + resp = caller(prompt, get_openai_key(), self.defaultOptions()) result_dict = json.loads(resp) r = result_dict["result"] res = json.loads(r) res["cache"] = False return res + + +@dataclass +class TestAI(OpenAI): + """TestAI""" diff --git a/tests/test_chat.py b/tests/test_chat.py index a497a23..7e658de 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -18,14 +18,14 @@ def test_chat(capsys): + with capsys.disabled() : + prompt = """ - prompt = """ - - Can I get reserves data please ? Aylık frekans istiyorum. ortalama olarak toplulaştırır mısın? - - - """ - with capsys.disabled(): - res, notes = chat(prompt, debug=False, force=False ) - print(res) - assert isinstance(res.data, pd.DataFrame) + Can I get reserves data please ? Aylık frekans istiyorum. ortalama olarak toplulaştırır mısın? + + + """ + with capsys.disabled(): + res, notes = chat(prompt, debug=False, force=False ) + print(res) + assert isinstance(res.data, pd.DataFrame) diff --git a/tests/test_req.py b/tests/test_req.py new file mode 100644 index 0000000..22cc95c --- /dev/null +++ b/tests/test_req.py @@ -0,0 +1,87 @@ +import json +import pytest +import ctypes +import os +import platform +from pathlib import Path +from importlib import resources +from typing import Union +import platform + +from evdschat.model.chatters import get_myapi_url, get_openai_key +from evdschat.model.chatters import TestAI, get_myapi_url + + +def get_exec_file(test=False) -> Path: + + executable_name = "libpost_request.so" + if platform.system() == "Windows": + executable_name = "libpost_request.dll" + return Path(".") / executable_name + + +def get_chatter(): + return TestAI() + + +@pytest.mark.skipif(not get_exec_file().exists(), reason="only tests locally") +def test_post(): + t = get_chatter() + resp = t.post(prompt="test") + # print(resp) + assert resp + + +@pytest.mark.skipif(not get_exec_file().exists(), reason="requires C executable") +def test_post_c(): + # t = get_chatter() + caller = get_c_fnc() + resp = caller(prompt="Loan data", api_key=get_openai_key(), url=get_myapi_url()) + result_dict = json.loads(resp) + r = result_dict["result"] + res = json.loads(r) + res["cache"] = False + assert res["index"] + + +def get_c_fnc(): + + class PostParams(ctypes.Structure): + _fields_ = [ + ("url", ctypes.c_char_p), + ("prompt", ctypes.c_char_p), + ("api_key", ctypes.c_char_p), + ("proxy_url", ctypes.c_char_p), + ] + + lib_path = get_exec_file() # check_c_executable() + if lib_path.exists(): + lib = ctypes.CDLL(lib_path) + + lib.post_request.argtypes = [ctypes.POINTER(PostParams)] + lib.post_request.restype = ctypes.c_char_p + + lib.free_memory.argtypes = [ctypes.c_void_p] + lib.free_memory.restype = None + + def c_caller(params): + response = lib.post_request(ctypes.byref(params)) + result = ctypes.string_at(response).decode("utf-8") + return result + + def c_caller_main(prompt, api_key, url, proxy=None): + prompt = prompt.replace("\n", " ") + + params = PostParams( + url=url.encode("utf-8"), + prompt=prompt.encode("utf-8"), + api_key=api_key.encode("utf-8"), + proxy_url=proxy.encode("utf-8") if proxy else None, + ) + + return c_caller(params) + + return c_caller_main + + +# test_post_c()