diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 5608c50a8..a6c7cf5ec 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -1,8 +1,11 @@ import shlex +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.utils import update_settings_from_args from pr_agent.config_loader import get_settings from pr_agent.git_providers.utils import apply_repo_settings +from pr_agent.log import get_logger from pr_agent.tools.pr_add_docs import PRAddDocs from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_config import PRConfig @@ -38,8 +41,8 @@ commands = list(command2class.keys()) class PRAgent: - def __init__(self): - pass + def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()): + self.ai_handler = ai_handler async def handle_request(self, pr_url, request, notify=None) -> bool: # First, apply repo specific settings if exists @@ -61,13 +64,14 @@ async def handle_request(self, pr_url, request, notify=None) -> bool: if action == "answer": if notify: notify() - await PRReviewer(pr_url, is_answer=True, args=args).run() + await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run() elif action == "auto_review": - await PRReviewer(pr_url, is_auto=True, args=args).run() + await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run() elif action in command2class: if notify: notify() - await command2class[action](pr_url, args=args).run() + + await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run() else: return False return True diff --git a/pr_agent/algo/ai_handlers/base_ai_handler.py b/pr_agent/algo/ai_handlers/base_ai_handler.py new file mode 100644 index 000000000..c8473fb3e --- /dev/null +++ b/pr_agent/algo/ai_handlers/base_ai_handler.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod + +class 
BaseAiHandler(ABC): + """ + This class defines the interface for an AI handler to be used by the PR Agents. + """ + + @abstractmethod + def __init__(self): + pass + + @property + @abstractmethod + def deployment_id(self): + pass + + @abstractmethod + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + """ + This method should be implemented to return a chat completion from the AI model. + Args: + model (str): the name of the model to use for the chat completion + system (str): the system message string to use for the chat completion + user (str): the user message string to use for the chat completion + temperature (float): the temperature to use for the chat completion + """ + pass + diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py new file mode 100644 index 000000000..3e31bcb8b --- /dev/null +++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py @@ -0,0 +1,49 @@ +try: + from langchain.chat_models import ChatOpenAI + from langchain.schema import SystemMessage, HumanMessage +except: # we don't enforce langchain as a dependency, so if it's not installed, just move on + pass + +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.config_loader import get_settings +from pr_agent.log import get_logger + +from openai.error import APIError, RateLimitError, Timeout, TryAgain +from retry import retry + +OPENAI_RETRIES = 5 + +class LangChainOpenAIHandler(BaseAiHandler): + def __init__(self): + # Initialize OpenAIHandler specific attributes here + try: + super().__init__() + self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key) + + except AttributeError as e: + raise ValueError("OpenAI key is required") from e + + @property + def chat(self): + return self._chat + + @property + def deployment_id(self): + """ + Returns the deployment ID for the OpenAI API. 
+ """ + return get_settings().get("OPENAI.DEPLOYMENT_ID", None) + @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), + tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + try: + messages=[SystemMessage(content=system), HumanMessage(content=user)] + + # get a chat completion from the formatted messages + resp = self.chat(messages, model=model, temperature=temperature) + finish_reason="completed" + return resp.content, finish_reason + + except (Exception) as e: + get_logger().error("Unknown error during OpenAI inference: ", e) + raise e \ No newline at end of file diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py similarity index 97% rename from pr_agent/algo/ai_handler.py rename to pr_agent/algo/ai_handlers/litellm_ai_handler.py index 5b6a05f4e..7061ca797 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -6,13 +6,14 @@ from litellm import acompletion from openai.error import APIError, RateLimitError, Timeout, TryAgain from retry import retry +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.config_loader import get_settings from pr_agent.log import get_logger OPENAI_RETRIES = 5 -class AiHandler: +class LiteLLMAIHandler(BaseAiHandler): """ This class handles interactions with the OpenAI API for chat completions. 
It initializes the API key and other settings from a configuration file, @@ -134,4 +135,4 @@ async def chat_completion(self, model: str, system: str, user: str, temperature: usage = response.get("usage") get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, model=model, usage=usage) - return resp, finish_reason + return resp, finish_reason \ No newline at end of file diff --git a/pr_agent/algo/ai_handlers/openai_ai_handler.py b/pr_agent/algo/ai_handlers/openai_ai_handler.py new file mode 100644 index 000000000..3856f6f76 --- /dev/null +++ b/pr_agent/algo/ai_handlers/openai_ai_handler.py @@ -0,0 +1,67 @@ +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +import openai +from openai.error import APIError, RateLimitError, Timeout, TryAgain +from retry import retry + +from pr_agent.config_loader import get_settings +from pr_agent.log import get_logger + +OPENAI_RETRIES = 5 + + +class OpenAIHandler(BaseAiHandler): + def __init__(self): + # Initialize OpenAIHandler specific attributes here + try: + super().__init__() + openai.api_key = get_settings().openai.key + if get_settings().get("OPENAI.ORG", None): + openai.organization = get_settings().openai.org + if get_settings().get("OPENAI.API_TYPE", None): + if get_settings().openai.api_type == "azure": + self.azure = True + openai.azure_key = get_settings().openai.key + if get_settings().get("OPENAI.API_VERSION", None): + openai.api_version = get_settings().openai.api_version + if get_settings().get("OPENAI.API_BASE", None): + openai.api_base = get_settings().openai.api_base + + except AttributeError as e: + raise ValueError("OpenAI key is required") from e + @property + def deployment_id(self): + """ + Returns the deployment ID for the OpenAI API. 
+ """ + return get_settings().get("OPENAI.DEPLOYMENT_ID", None) + + @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), + tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + try: + deployment_id = self.deployment_id + get_logger().info("System: ", system) + get_logger().info("User: ", user) + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + + chat_completion = await openai.ChatCompletion.acreate( + model=model, + deployment_id=deployment_id, + messages=messages, + temperature=temperature, + ) + resp = chat_completion["choices"][0]['message']['content'] + finish_reason = chat_completion["choices"][0]["finish_reason"] + usage = chat_completion.get("usage") + get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, + model=model, usage=usage) + return resp, finish_reason + except (APIError, Timeout, TryAgain) as e: + get_logger().error("Error during OpenAI inference: ", e) + raise + except (RateLimitError) as e: + get_logger().error("Rate limit error during OpenAI inference: ", e) + raise + except (Exception) as e: + get_logger().error("Unknown error during OpenAI inference: ", e) + raise TryAgain from e \ No newline at end of file diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 6d0d3731b..9e1000423 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -447,4 +447,4 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str: return clipped_text except Exception as e: get_logger().warning(f"Failed to clip tokens: {e}") - return text + return text \ No newline at end of file diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index eec75b9cb..a729233d3 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -4,7 +4,8 @@ from jinja2 import Environment, StrictUndefined 
-from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -15,14 +16,15 @@ class PRAddDocs: - def __init__(self, pr_url: str, cli_mode=False, args: list = None): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.patches_diff = None self.prediction = None self.cli_mode = cli_mode diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 3fc96d012..81e1ceabe 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -3,7 +3,8 @@ from typing import Dict, List from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -14,7 +15,8 @@ class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( @@ -31,7 +33,7 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None): else: 
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.patches_diff = None self.prediction = None self.cli_mode = cli_mode diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 2e33a5108..4915c5b68 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -4,7 +4,8 @@ from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels @@ -15,7 +16,8 @@ class PRDescription: - def __init__(self, pr_url: str, args: list = None): + def __init__(self, pr_url: str, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. 
@@ -36,7 +38,7 @@ def __init__(self, pr_url: str, args: list = None): get_settings().pr_description.enable_semantic_files_types = False # Initialize the AI handler - self.ai_handler = AiHandler() + self.ai_handler = ai_handler # Initialize the variables dictionary self.vars = { diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py index fc90ed44f..25e80a55b 100644 --- a/pr_agent/tools/pr_generate_labels.py +++ b/pr_agent/tools/pr_generate_labels.py @@ -4,7 +4,8 @@ from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels @@ -15,7 +16,8 @@ class PRGenerateLabels: - def __init__(self, pr_url: str, args: list = None): + def __init__(self, pr_url: str, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): """ Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels corresponding to the PR using an AI model. 
@@ -31,7 +33,7 @@ def __init__(self, pr_url: str, args: list = None): self.pr_id = self.git_provider.get_pr_id() # Initialize the AI handler - self.ai_handler = AiHandler() + self.ai_handler = ai_handler # Initialize the variables dictionary self.vars = { diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index 059966e1e..a47d511be 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -2,7 +2,8 @@ from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,12 +13,13 @@ class PRInformationFromUser: - def __init__(self, pr_url: str, args: list = None): + def __init__(self, pr_url: str, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.vars = { "title": self.git_provider.pr.title, "branch": self.git_provider.get_pr_branch(), diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 7740fd4ae..5de3d7762 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -2,7 +2,8 @@ from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from 
pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,13 +13,13 @@ class PRQuestions: - def __init__(self, pr_url: str, args=None): + def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()): question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.question_str = question_str self.vars = { "title": self.git_provider.pr.title, diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 5a6f720ac..24a40af31 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -7,7 +7,8 @@ from jinja2 import Environment, StrictUndefined from yaml import SafeLoader -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels @@ -22,13 +23,16 @@ class PRReviewer: """ The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None): + def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Args: pr_url (str): The URL of the pull request to be reviewed. is_answer (bool, optional): Indicates whether the review is being done in answer mode. 
Defaults to False. + is_auto (bool, optional): Indicates whether the review is being done in automatic mode. Defaults to False. + ai_handler (BaseAiHandler): The AI handler to be used for the review. Defaults to a LiteLLMAIHandler instance. args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. """ self.parse_args(args) # -i command @@ -43,7 +47,7 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now") - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.patches_diff = None self.prediction = None diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index a5f24e0da..b8c6187f1 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -5,7 +5,8 @@ from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -17,7 +18,7 @@ class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None): + def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( @@ -25,7 +26,7 @@ def __init__(self, pr_url: str, cli_mode=False, args=None): ) self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes self._get_changlog_file() # self.changelog_file_str - self.ai_handler = AiHandler() + self.ai_handler = ai_handler self.patches_diff = None
self.prediction = None self.cli_mode = cli_mode diff --git a/requirements.txt b/requirements.txt index 2f38da7bc..b293f3b38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ starlette-context==0.3.6 tiktoken==0.5.2 ujson==5.8.0 uvicorn==0.22.0 +# langchain==0.0.349 # uncomment this to support the LangChainOpenAIHandler