-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 1 changed file with 230 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
import hashlib | ||
|
||
try: | ||
from openai import OpenAI, AzureOpenAI | ||
except ImportError: | ||
raise ImportError("If you'd like to use OpenAI models, please install the openai package by running `pip install openai`, and add 'OPENAI_API_KEY' to your environment variables.") | ||
|
||
import os | ||
import json | ||
from abc import ABC, abstractmethod | ||
import base64 | ||
import platformdirs | ||
from tenacity import ( | ||
retry, | ||
stop_after_attempt, | ||
wait_random_exponential, | ||
) | ||
from typing import List, Union | ||
from functools import wraps | ||
from .engine_utils import get_image_type_from_bytes | ||
|
||
def cached(func):
    """Memoize an engine method in ``self.cache`` when ``self.use_cache`` is set.

    The decorated method's instance must expose a ``use_cache`` attribute
    and, when that attribute is truthy, a ``cache`` object that supports
    ``in``, item lookup and item assignment.

    The cache key is the SHA-256 hex digest of ``str(args) + str(kwargs)``,
    so two calls share an entry only when their positional and keyword
    arguments have identical string representations (keyword order included).

    :param func: the method to wrap; it is called as ``func(self, *args, **kwargs)``.
    :return: the wrapping function, carrying ``func``'s metadata.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self.use_cache:
            return func(self, *args, **kwargs)

        # Digest the argument text directly. Passing it through the builtin
        # hash() first would tie the key to the per-process string hash seed
        # (so a cache that outlives the process could never hit) and would
        # shrink the key space to a single machine word before hashing.
        signature = str(args) + str(kwargs)
        key = hashlib.sha256(signature.encode()).hexdigest()

        if key in self.cache:
            return self.cache[key]

        result = func(self, *args, **kwargs)
        self.cache[key] = result
        return result

    return wrapper
|
||
|
||
class EngineLM(ABC):
    """Abstract base for language-model engines.

    Subclasses implement ``_generate_from_single_prompt`` (plain text prompt)
    and ``_generate_from_multiple_input`` (list of ``str``/``bytes`` items);
    ``generate`` picks between them based on the type of ``content``.
    """

    system_prompt: str = "You are a helpful, creative, and smart assistant."
    model_string: str
    is_multimodal: bool
    use_cache: bool = False
    cache_path: str

    def __init__(self, model_string=None, system_prompt=None, is_multimodal=None,
                 use_cache=None, cache_path=None):
        """Store the engine configuration on the instance.

        Every argument is optional. An argument left as ``None`` is not
        assigned, so class-level defaults and attributes a subclass has
        already set stay untouched.

        :param model_string: model identifier.
        :param system_prompt: default system prompt for generation calls.
        :param is_multimodal: whether ``bytes`` items are accepted by ``generate``.
        :param use_cache: whether methods wrapped with ``cached`` consult ``self.cache``.
        :param cache_path: path associated with the cache; it is only recorded
            here, this class does not read or write that location.
        """
        provided = {
            "model_string": model_string,
            "system_prompt": system_prompt,
            "is_multimodal": is_multimodal,
            "use_cache": use_cache,
            "cache_path": cache_path,
        }
        for attr_name, value in provided.items():
            if value is not None:
                setattr(self, attr_name, value)

        # The ``cached`` decorator needs a mapping on the instance; provide an
        # in-memory one unless a subclass has already installed its own store.
        if not hasattr(self, "cache"):
            self.cache = {}

    @abstractmethod
    def _generate_from_multiple_input(self, prompt, system_prompt=None, **kwargs) -> str:
        """Generate a response from a list of text/bytes items."""
        pass

    @abstractmethod
    def _generate_from_single_prompt(self, prompt, system_prompt=None, **kwargs) -> str:
        """Generate a response from a single text prompt."""
        pass

    # TBF this could be simplified to a single generate method
    def _prepare_generate_from_single_prompt(self, prompt: str, system_prompt: str = None, **kwargs):
        """Resolve the system prompt, then delegate to ``_generate_from_single_prompt``.

        A falsy ``system_prompt`` (``None`` or ``""``) falls back to
        ``self.system_prompt``.
        """
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
        return self._generate_from_single_prompt(prompt, system_prompt=sys_prompt_arg, **kwargs)

    def _prepare_generate_from_multiple_input(self, content: List[Union[str, bytes]], system_prompt=None, **kwargs):
        """Resolve the system prompt, then delegate to ``_generate_from_multiple_input``.

        A falsy ``system_prompt`` (``None`` or ``""``) falls back to
        ``self.system_prompt``.
        """
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
        return self._generate_from_multiple_input(content, system_prompt=sys_prompt_arg, **kwargs)

    def generate(self, content, system_prompt=None, **kwargs):
        """Generate a response for ``content``.

        :param content: a ``str`` prompt, or a ``list`` whose items are
            ``str`` and/or ``bytes``.
        :param system_prompt: optional override of ``self.system_prompt``.
        :param kwargs: forwarded unchanged to the subclass implementation.
        :return: whatever the subclass implementation returns.
        :raises NotImplementedError: if the list contains ``bytes`` but the
            engine is not multimodal.
        :raises TypeError: if ``content`` is neither ``str`` nor ``list``.
        """
        if isinstance(content, str):
            return self._prepare_generate_from_single_prompt(content, system_prompt=system_prompt, **kwargs)

        if isinstance(content, list):
            has_multimodal_input = any(isinstance(item, bytes) for item in content)
            # getattr: ``is_multimodal`` has no class-level default, so it may
            # be missing when a subclass never set it.
            if has_multimodal_input and not getattr(self, "is_multimodal", False):
                raise NotImplementedError(
                    "Multimodal input was provided, but this engine is not configured as multimodal "
                    "(is_multimodal=False)."
                )

            return self._prepare_generate_from_multiple_input(content, system_prompt=system_prompt, **kwargs)

        # Previously this fell through and returned None, hiding caller mistakes.
        raise TypeError(
            f"Unsupported content type: {type(content).__name__}; expected str or a list of str/bytes."
        )

    def __call__(self, *args, **kwargs):
        """Placeholder that does nothing and returns ``None``; subclasses override it."""
        pass
|
||
class OpenAIEngine(EngineLM):
    """Engine that sends chat-completion requests through an ``OpenAI`` client."""

    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
    # Default OpenAI-compatible endpoint of a local Ollama server; the only
    # non-empty ``base_url`` this engine accepts.
    OLLAMA_BASE_URL = "http://localhost:11434/v1"

    def __init__(
            self,
            model_string: str = "gpt-3.5-turbo-0613",
            system_prompt: str = DEFAULT_SYSTEM_PROMPT,
            is_multimodal: bool = False,
            use_cache: bool = False,
            base_url: str = None):
        """
        :param model_string: model name sent with every request; also part of the cache file name.
        :param system_prompt: default system prompt.
        :param is_multimodal: whether ``bytes`` (image) inputs are allowed.
        :param use_cache: whether the ``cached`` generation methods consult ``self.cache``.
        :param base_url: Used to support Ollama. Must be ``None``/empty (use the
            OpenAI API with ``OPENAI_API_KEY``) or equal to ``OLLAMA_BASE_URL``.
        :raises ValueError: if no ``base_url`` is given and ``OPENAI_API_KEY`` is
            unset, or if ``base_url`` is anything other than ``OLLAMA_BASE_URL``.
        """
        root = platformdirs.user_cache_dir("textgrad")
        cache_path = os.path.join(root, f"cache_openai_{model_string}.db")

        # Set the configuration directly instead of passing keyword arguments
        # up the chain: the base class is not guaranteed to define an
        # ``__init__`` that accepts them.
        super().__init__()
        self.system_prompt = system_prompt
        self.cache_path = cache_path
        self.use_cache = use_cache
        self.is_multimodal = is_multimodal
        self.model_string = model_string
        # ``cached`` requires a mapping on the instance. ``cache_path`` is only
        # recorded; nothing in this class persists entries to it.
        if not hasattr(self, "cache"):
            self.cache = {}

        self.base_url = base_url

        if not base_url:
            if os.getenv("OPENAI_API_KEY") is None:
                raise ValueError(
                    "Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY")
            )
        elif base_url == self.OLLAMA_BASE_URL:
            # The client insists on some API key; "ollama" is a placeholder
            # (presumably ignored by the local server; confirm against Ollama docs).
            self.client = OpenAI(
                base_url=base_url,
                api_key="ollama"
            )
        else:
            raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.")

    @cached
    def _generate_from_single_prompt(
            self, prompt: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
    ):
        """Send one system message plus one user text message.

        :return: ``message.content`` of the first returned choice.
        """
        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        response = response.choices[0].message.content
        return response

    @cached
    def _generate_from_multiple_input(
            self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
    ):
        """Send one system message plus one user message built from mixed text/bytes items.

        :return: ``message.content`` of the first returned choice.
        """
        formatted_content = self._format_content(content)

        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": formatted_content},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        response_text = response.choices[0].message.content
        return response_text

    def __call__(self, prompt, **kwargs):
        """Shorthand for ``self.generate(prompt, **kwargs)``."""
        return self.generate(prompt, **kwargs)

    def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
        """Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API.

        ``bytes`` items become ``image_url`` parts carrying a base64 data URL;
        ``str`` items become ``text`` parts.

        :raises ValueError: if an item is neither ``bytes`` nor ``str``.
        """
        formatted_content = []
        for item in content:
            if isinstance(item, bytes):
                # For now, bytes are assumed to be images
                image_type = get_image_type_from_bytes(item)
                base64_image = base64.b64encode(item).decode('utf-8')
                formatted_content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{image_type};base64,{base64_image}"
                    }
                })
            elif isinstance(item, str):
                formatted_content.append({
                    "type": "text",
                    "text": item
                })
            else:
                raise ValueError(f"Unsupported input type: {type(item)}")
        return formatted_content
|
||
|
||
class OpenAICompatibleEngine(OpenAIEngine):
    """
    Same behaviour as ``OpenAIEngine``, but requests go through an externally
    constructed OpenAI client that the caller passes in.
    """

    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
    client = None

    def __init__(
            self,
            client: OpenAI,
            model_string: str = "gpt-3.5-turbo-0613",
            system_prompt: str = DEFAULT_SYSTEM_PROMPT,
            is_multimodal: bool = False,
            use_cache: bool = False,
            base_url: str = None):
        """
        :param client: an OpenAI client object.
        :param model_string: the model name, used for the cache file name and chat completion requests.
        :param system_prompt: the system prompt to use in chat completions.
        :param is_multimodal: whether ``bytes`` (image) inputs are allowed.
        :param use_cache: whether generation results are cached.
        :param base_url: recorded on ``self.base_url`` only; requests use the
            endpoint already configured on ``client``.

        Example usage with lm-studio local server, but any client that follows the OpenAI API will work.

        ```python
        from openai import OpenAI
        # plus: import OpenAICompatibleEngine from this module

        client = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")
        engine = OpenAICompatibleEngine(client=client, model_string="your-model-name")
        print(engine.generate("What is the meaning of life?", max_tokens=40))
        ```
        """
        # The parent initialiser refuses to run without OPENAI_API_KEY, so
        # fall back to the injected client's key. Note this mutates the
        # process environment.
        if os.getenv("OPENAI_API_KEY") is None:
            os.environ["OPENAI_API_KEY"] = client.api_key

        # ``super()`` must be called: the bare ``super`` type has no bound
        # ``__init__`` for this instance. The parent gets ``base_url=None``
        # because it validates that argument for its own client, which is
        # replaced below anyway.
        super().__init__(model_string=model_string,
                         system_prompt=system_prompt,
                         is_multimodal=is_multimodal,
                         use_cache=use_cache,
                         base_url=None)

        # Assign after the parent initialiser so the caller-supplied client
        # replaces the one the parent constructs.
        self.client = client
        self.base_url = base_url