Commit: clean implementation
Kostis Gourgoulias committed Jun 20, 2024
1 parent 66e5db2 commit 2d579ee
Showing 1 changed file with 11 additions and 69 deletions.
80 changes: 11 additions & 69 deletions textgrad/engine/local_model_openai_api.py
@@ -1,15 +1,11 @@
 import os
-import platformdirs
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    wait_random_exponential,
-)
-import json
-from .base import EngineLM, CachedEngine
+import logging
+from .openai import ChatOpenAI
 
+logger = logging.getLogger(__name__)
 
-class ChatExternalClient(EngineLM, CachedEngine):
+
+class ChatExternalClient(ChatOpenAI):
     """
     This is the same as engine.openai.ChatOpenAI, but we pass the
     client explicitly to the constructor.
@@ -30,66 +26,12 @@ def __init__(
         :param model_string:
         :param system_prompt:
         """
-        self.client = client
-        root = platformdirs.user_cache_dir("textgrad")
-        cache_path = os.path.join(root, f"cache_local_model_{model_string}.db")
-
-        super().__init__(cache_path=cache_path)
-
-        self.system_prompt = system_prompt
-        self.model_string = model_string
-
-    def generate(
-        self, prompt, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
-    ):
-        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
-
-        cache_or_none = self._check_cache(sys_prompt_arg + prompt)
-        if cache_or_none is not None:
-            return cache_or_none
+        if os.getenv("OPENAI_API_KEY") is None:
+            logger.warning("OPENAI_API_KEY not set. Setting it from client.")
+            os.environ["OPENAI_API_KEY"] = client.api_key
 
-        response = self.client.chat.completions.create(
-            model=self.model_string,
-            messages=[
-                {"role": "system", "content": sys_prompt_arg},
-                {"role": "user", "content": prompt},
-            ],
-            frequency_penalty=0,
-            presence_penalty=0,
-            stop=None,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            top_p=top_p,
+        super().__init__(
+            model_string=model_string, system_prompt=system_prompt, **kwargs
         )
-
-        response = response.choices[0].message.content
-        self._save_cache(sys_prompt_arg + prompt, response)
-        return response
-
-    def generate_with_messages(
-        self, messages, temperature=0, max_tokens=2000, top_p=0.99
-    ):
-        prompt = json.dumps(messages)
-
-        cache_or_none = self._check_cache(prompt)
-        if cache_or_none is not None:
-            return cache_or_none
-
-        response = self.client.chat.completions.create(
-            model=self.model_string,
-            messages=messages,
-            frequency_penalty=0,
-            presence_penalty=0,
-            stop=None,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            top_p=top_p,
-        )
-
-        response = response.choices[0].message.content
-        self._save_cache(prompt, response)
-        return response
-
-    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
-    def __call__(self, prompt, **kwargs):
-        return self.generate(prompt, **kwargs)
+        self.client = client
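
For reference, a minimal usage sketch of the refactored class, under stated assumptions: the base URL, API key, and model name are placeholders for any OpenAI-compatible local endpoint, and generate() is assumed to be inherited from ChatOpenAI, whose behavior the deleted override duplicated.

# Usage sketch: endpoint, API key, and model name below are hypothetical.
from openai import OpenAI
from textgrad.engine.local_model_openai_api import ChatExternalClient

# Point an OpenAI-compatible client at a locally served model.
client = OpenAI(base_url="http://localhost:1234/v1", api_key="placeholder")

# After this commit, ChatExternalClient defers caching and generation to
# ChatOpenAI; the subclass only wires in the explicit client and its API key.
engine = ChatExternalClient(client=client, model_string="local-model")
print(engine.generate("What is the capital of France?"))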
