-
Notifications
You must be signed in to change notification settings - Fork 615
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #514 from brianpham93/abstract-BaseAiHandler
Abstract AiHandler to BaseAiHandler
- Loading branch information
Showing
15 changed files
with
198 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
class BaseAiHandler(ABC): | ||
""" | ||
This class defines the interface for an AI handler to be used by the PR Agents. | ||
""" | ||
|
||
@abstractmethod | ||
def __init__(self): | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def deployment_id(self): | ||
pass | ||
|
||
@abstractmethod | ||
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): | ||
""" | ||
This method should be implemented to return a chat completion from the AI model. | ||
Args: | ||
model (str): the name of the model to use for the chat completion | ||
system (str): the system message string to use for the chat completion | ||
user (str): the user message string to use for the chat completion | ||
temperature (float): the temperature to use for the chat completion | ||
""" | ||
pass | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
try: | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.schema import SystemMessage, HumanMessage | ||
except: # we don't enforce langchain as a dependency, so if it's not installed, just move on | ||
pass | ||
|
||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler | ||
from pr_agent.config_loader import get_settings | ||
from pr_agent.log import get_logger | ||
|
||
from openai.error import APIError, RateLimitError, Timeout, TryAgain | ||
from retry import retry | ||
|
||
OPENAI_RETRIES = 5 | ||
|
||
class LangChainOpenAIHandler(BaseAiHandler): | ||
def __init__(self): | ||
# Initialize OpenAIHandler specific attributes here | ||
try: | ||
super().__init__() | ||
self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key) | ||
|
||
except AttributeError as e: | ||
raise ValueError("OpenAI key is required") from e | ||
|
||
@property | ||
def chat(self): | ||
return self._chat | ||
|
||
@property | ||
def deployment_id(self): | ||
""" | ||
Returns the deployment ID for the OpenAI API. | ||
""" | ||
return get_settings().get("OPENAI.DEPLOYMENT_ID", None) | ||
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), | ||
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) | ||
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): | ||
try: | ||
messages=[SystemMessage(content=system), HumanMessage(content=user)] | ||
|
||
# get a chat completion from the formatted messages | ||
resp = self.chat(messages, model=model, temperature=temperature) | ||
finish_reason="completed" | ||
return resp.content, finish_reason | ||
|
||
except (Exception) as e: | ||
get_logger().error("Unknown error during OpenAI inference: ", e) | ||
raise e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler | ||
import openai | ||
from openai.error import APIError, RateLimitError, Timeout, TryAgain | ||
from retry import retry | ||
|
||
from pr_agent.config_loader import get_settings | ||
from pr_agent.log import get_logger | ||
|
||
OPENAI_RETRIES = 5 | ||
|
||
|
||
class OpenAIHandler(BaseAiHandler): | ||
def __init__(self): | ||
# Initialize OpenAIHandler specific attributes here | ||
try: | ||
super().__init__() | ||
openai.api_key = get_settings().openai.key | ||
if get_settings().get("OPENAI.ORG", None): | ||
openai.organization = get_settings().openai.org | ||
if get_settings().get("OPENAI.API_TYPE", None): | ||
if get_settings().openai.api_type == "azure": | ||
self.azure = True | ||
openai.azure_key = get_settings().openai.key | ||
if get_settings().get("OPENAI.API_VERSION", None): | ||
openai.api_version = get_settings().openai.api_version | ||
if get_settings().get("OPENAI.API_BASE", None): | ||
openai.api_base = get_settings().openai.api_base | ||
|
||
except AttributeError as e: | ||
raise ValueError("OpenAI key is required") from e | ||
@property | ||
def deployment_id(self): | ||
""" | ||
Returns the deployment ID for the OpenAI API. | ||
""" | ||
return get_settings().get("OPENAI.DEPLOYMENT_ID", None) | ||
|
||
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), | ||
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) | ||
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): | ||
try: | ||
deployment_id = self.deployment_id | ||
get_logger().info("System: ", system) | ||
get_logger().info("User: ", user) | ||
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] | ||
|
||
chat_completion = await openai.ChatCompletion.acreate( | ||
model=model, | ||
deployment_id=deployment_id, | ||
messages=messages, | ||
temperature=temperature, | ||
) | ||
resp = chat_completion["choices"][0]['message']['content'] | ||
finish_reason = chat_completion["choices"][0]["finish_reason"] | ||
usage = chat_completion.get("usage") | ||
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, | ||
model=model, usage=usage) | ||
return resp, finish_reason | ||
except (APIError, Timeout, TryAgain) as e: | ||
get_logger().error("Error during OpenAI inference: ", e) | ||
raise | ||
except (RateLimitError) as e: | ||
get_logger().error("Rate limit error during OpenAI inference: ", e) | ||
raise | ||
except (Exception) as e: | ||
get_logger().error("Unknown error during OpenAI inference: ", e) | ||
raise TryAgain from e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.