-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adding groq support * made chatgroq subclass chatopenai * remove breakpoint
- Loading branch information
1 parent
5cd7cc0
commit 4a5063e
Showing
2 changed files
with
52 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
try: | ||
from groq import Groq | ||
except ImportError: | ||
raise ImportError("If you'd like to use Groq models, please install the groq package by running `pip install groq`, and add 'GROQ_API_KEY' to your environment variables.") | ||
|
||
import os | ||
import json | ||
import base64 | ||
import platformdirs | ||
from tenacity import ( | ||
retry, | ||
stop_after_attempt, | ||
wait_random_exponential, | ||
) | ||
from typing import List, Union | ||
|
||
from .base import EngineLM, CachedEngine | ||
from .engine_utils import get_image_type_from_bytes | ||
from .openai import ChatOpenAI | ||
|
||
|
||
class ChatGroq(ChatOpenAI): | ||
DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant1." | ||
|
||
def __init__( | ||
self, | ||
model_string: str="groq-llama3-70b-8192", | ||
system_prompt: str=DEFAULT_SYSTEM_PROMPT, | ||
**kwargs): | ||
""" | ||
:param model_string: | ||
:param system_prompt: | ||
:param base_url: Used to support Ollama | ||
""" | ||
root = platformdirs.user_cache_dir("textgrad") | ||
cache_path = os.path.join(root, f"cache_groq_{model_string}.db") | ||
CachedEngine.__init__(self, cache_path=cache_path) | ||
|
||
if os.getenv("GROQ_API_KEY") is None: | ||
raise ValueError("Please set the GROQ_API_KEY environment variable if you'd like to use Groq models.") | ||
self.client = Groq( | ||
api_key=os.getenv("GROQ_API_KEY") | ||
) | ||
|
||
self.model_string = model_string | ||
self.system_prompt = system_prompt | ||
assert isinstance(self.system_prompt, str) | ||
self.is_multimodal = False |