
Commit 0ba66ad

trying to make engine implementation independent from cache
vinid committed Jul 14, 2024
1 parent a4d4824 commit 0ba66ad
Showing 2 changed files with 157 additions and 1 deletion.
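
For orientation, a rough sketch of the layering this diff appears to introduce (class and method names are taken from the diff below; placing the _check_cache/_save_cache helpers on CachedEngine is an assumption, since their definitions are not part of this commit):

from abc import ABC

class EngineLM(ABC):                        # provider-agnostic interface
    def generate(self, content, system_prompt=None, **kwargs): ...

class CachedEngine:                         # assumed owner of the diskcache helpers
    def _check_cache(self, key): ...
    def _save_cache(self, key, value): ...

class CachedLLM(CachedEngine, EngineLM):    # added here: dispatch plus optional caching
    def _generate_from_single_prompt(self, prompt, system_prompt=None, **kwargs): ...
    def _generate_from_multiple_input(self, content, system_prompt=None, **kwargs): ...

class OpenAIWithCachedLLM(CachedLLM):       # added here: only the provider-specific API calls
    ...
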
70 changes: 70 additions & 0 deletions textgrad/engine/base.py
@@ -1,6 +1,8 @@
import hashlib
import diskcache as dc
from abc import ABC, abstractmethod
from typing import Union, List
import json

class EngineLM(ABC):
system_prompt: str = "You are a helpful, creative, and smart assistant."
@@ -41,3 +43,71 @@ def __setstate__(self, state):
# Restore the cache after unpickling
self.__dict__.update(state)
self.cache = dc.Cache(self.cache_path)

# Used to build the default on-disk cache location for CachedLLM.
import platformdirs
import os

class CachedLLM(CachedEngine, EngineLM):
    """Adds optional disk caching and prompt dispatch on top of EngineLM; subclasses only implement the _generate_* hooks."""
    def __init__(self, model_string, is_multimodal=False, do_cache=False):
root = platformdirs.user_cache_dir("textgrad")
cache_path = os.path.join(root, f"cache_openai_{model_string}.db")

super().__init__(cache_path=cache_path)
self.model_string = model_string
self.is_multimodal = is_multimodal
self.do_cache = do_cache

def __call__(self, prompt, **kwargs):
return self.generate(prompt, **kwargs)

@abstractmethod
def _generate_from_single_prompt(self, prompt: str, system_prompt: str=None, **kwargs):
pass

@abstractmethod
def _generate_from_multiple_input(self, content: List[Union[str, bytes]], system_prompt: str=None, **kwargs):
pass

def single_prompt_generate(self, prompt: str, system_prompt: str=None, **kwargs):
sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

if self.do_cache:
cache_or_none = self._check_cache(sys_prompt_arg + prompt)
if cache_or_none is not None:
return cache_or_none

response = self._generate_from_single_prompt(prompt, system_prompt=sys_prompt_arg, **kwargs)

if self.do_cache:
self._save_cache(sys_prompt_arg + prompt, response)
return response

def multimodal_generate(self, content: List[Union[str, bytes]], system_prompt: str = None, **kwargs):

sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
if self.do_cache:
key = "".join([str(k) for k in content])

cache_key = sys_prompt_arg + key
cache_or_none = self._check_cache(cache_key)
if cache_or_none is not None:
return cache_or_none

response = self._generate_from_multiple_input(content, system_prompt=sys_prompt_arg, **kwargs)

if self.do_cache:
self._save_cache(cache_key, response)

return response

def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str = None, **kwargs):
if isinstance(content, str):
return self.single_prompt_generate(content, system_prompt=system_prompt, **kwargs)

elif isinstance(content, list):
has_multimodal_input = any(isinstance(item, bytes) for item in content)
if has_multimodal_input and not self.is_multimodal:
                raise NotImplementedError("Multimodal generation is not supported by this engine (is_multimodal=False).")

return self.multimodal_generate(content, system_prompt=system_prompt, **kwargs)
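
With this split, a provider engine only implements the two _generate_* hooks; CachedLLM dispatches between them and, when do_cache is set, wraps them with the inherited _check_cache/_save_cache helpers. A minimal sketch under those assumptions; EchoEngine and its echo behaviour are invented for illustration:

from textgrad.engine.base import CachedLLM  # module path as in this diff

class EchoEngine(CachedLLM):
    """Hypothetical provider that echoes its input instead of calling a real API."""

    def _generate_from_single_prompt(self, prompt, system_prompt=None, **kwargs):
        # A real engine would call its provider's chat endpoint here.
        return f"[{system_prompt}] {prompt}"

    def _generate_from_multiple_input(self, content, system_prompt=None, **kwargs):
        # A real engine would format the text/image parts for its provider here.
        return f"[{system_prompt}] " + " | ".join(str(item) for item in content)

engine = EchoEngine("echo-model", do_cache=True)
print(engine("Hello"))  # generated, then written to the cache
print(engine("Hello"))  # served from the cache on the second call
# Note: the current code still names the cache file cache_openai_echo-model.db,
# regardless of provider.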

88 changes: 87 additions & 1 deletion textgrad/engine/openai.py
@@ -14,7 +14,8 @@
)
from typing import List, Union

from .base import EngineLM, CachedEngine

from .base import EngineLM, CachedEngine, CachedLLM
from .engine_utils import get_image_type_from_bytes

# Default base URL for OLLAMA
@@ -158,6 +159,91 @@ def _generate_from_multiple_input(
self._save_cache(cache_key, response_text)
return response_text


class OpenAIWithCachedLLM(CachedLLM):
DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

def __init__(self, model_string, is_multimodal=False, system_prompt: str = DEFAULT_SYSTEM_PROMPT, do_cache=False):
        """
        :param model_string: name of the OpenAI model to call
        :param is_multimodal: whether the model accepts image inputs
        :param system_prompt: default system prompt used when a call does not supply one
        :param do_cache: whether responses are read from and written to the on-disk cache
        """
        super().__init__(model_string=model_string, is_multimodal=is_multimodal, do_cache=do_cache)

self.system_prompt = system_prompt

if os.getenv("OPENAI_API_KEY") is None:
raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY")
)

def _generate_from_single_prompt(
        self, prompt: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
):

response = self.client.chat.completions.create(
model=self.model_string,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
frequency_penalty=0,
presence_penalty=0,
stop=None,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)

response = response.choices[0].message.content
return response

def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
"""Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API.
"""
formatted_content = []
for item in content:
if isinstance(item, bytes):
# For now, bytes are assumed to be images
image_type = get_image_type_from_bytes(item)
base64_image = base64.b64encode(item).decode('utf-8')
formatted_content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/{image_type};base64,{base64_image}"
}
})
elif isinstance(item, str):
formatted_content.append({
"type": "text",
"text": item
})
else:
raise ValueError(f"Unsupported input type: {type(item)}")
return formatted_content

def _generate_from_multiple_input(
self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
):
formatted_content = self._format_content(content)

response = self.client.chat.completions.create(
model=self.model_string,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": formatted_content},
],
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)

response_text = response.choices[0].message.content
return response_text

class AzureChatOpenAI(ChatOpenAI):
def __init__(
self,
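A hedged usage sketch of the new OpenAI engine; the model name and image path are placeholders, and OPENAI_API_KEY must be set in the environment, as enforced in __init__:

from textgrad.engine.openai import OpenAIWithCachedLLM  # module path as in this diff

engine = OpenAIWithCachedLLM("gpt-4o-mini", is_multimodal=True, do_cache=True)

# A plain string is routed through single_prompt_generate -> _generate_from_single_prompt.
print(engine("Summarize what a response cache does in one sentence."))

# A list mixing text and raw image bytes is routed through multimodal_generate,
# where _format_content base64-encodes the bytes as an image_url message part.
with open("figure.png", "rb") as f:  # placeholder path
    image_bytes = f.read()
print(engine.generate(["Describe this image.", image_bytes], system_prompt="Be brief."))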
