diff --git a/github_tracker_bot/ai_decide_commits.py b/github_tracker_bot/ai_decide_commits.py index 2d7a15a..328b256 100644 --- a/github_tracker_bot/ai_decide_commits.py +++ b/github_tracker_bot/ai_decide_commits.py @@ -1,20 +1,24 @@ import os import sys import json - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -import config -from typing import TypedDict, List +import aiohttp +import asyncio +import time +from typing import Optional, List, TypedDict from datetime import datetime +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) +import config import log_config import github_tracker_bot.prompts as prompts - from openai import AuthenticationError, NotFoundError, OpenAI, OpenAIError logger = log_config.get_logger(__name__) - client = OpenAI(api_key=config.OPENAI_API_KEY) @@ -37,24 +41,38 @@ def validate_date_format(date_str: str) -> bool: return False +retry_conditions = ( + retry_if_exception_type(AuthenticationError) + | retry_if_exception_type(OpenAIError) + | retry_if_exception_type(aiohttp.ClientError) + | retry_if_exception_type(asyncio.TimeoutError) + | retry_if_exception_type(aiohttp.ClientConnectorError) +) + + +@retry( + wait=wait_exponential(multiplier=2, min=5, max=60), + stop=stop_after_attempt(8), + retry=retry_conditions, +) async def decide_daily_commits( date: str, data_array: List[CommitData], seed: int = 42 -): +) -> Optional[str]: if not validate_date_format(date): raise ValueError("Incorrect date format, should be YYYY-MM-DD") - try: - commit_data = next((data for data in data_array), None) - if not commit_data: - logger.error("Commit data or diff file is empty") - return False + if not data_array: + logger.error("Commit data array is empty") + return None - message = prompts.process_message(date, data_array) - if not message: - logger.error("After processing commit") - return False + message = prompts.process_message(date, data_array) + if not message: + logger.error("Message processing failed") + return None - completion = client.chat.completions.create( + try: + completion = await asyncio.to_thread( + client.chat.completions.create, model="gpt-4o", response_format={"type": "json_object"}, messages=[ @@ -67,12 +85,29 @@ async def decide_daily_commits( seed=seed, temperature=0.1, ) - return completion.choices[0].message.content + except NotFoundError: + logger.error("404 Not Found Error.") + return None + + except aiohttp.ClientResponseError as e: + if e.status == 403: + reset_time = e.headers.get("X-RateLimit-Reset") + try: + sleep_time = ( + max(int(reset_time) - int(time.time()) + 1, 1) if reset_time else 60 + ) + except ValueError: + sleep_time = 60 + logger.warning(f"Rate limit exceeded. Sleeping for {sleep_time} seconds.") + await asyncio.sleep(sleep_time) + raise aiohttp.ClientError("Rate limit exceeded, retrying...") + except OpenAIError as e: - logger.error(f"OpenAI API call failed with error: {e}") + logger.error(f"OpenAI API Error: {e}") + return None except Exception as e: - logger.error(f"An unexpected error occurred: {e}") - return False + logger.error(f"Unexpected error: {e}") + return None diff --git a/github_tracker_bot/mongo_data_handler.py b/github_tracker_bot/mongo_data_handler.py index 7aefdfa..f020e2d 100644 --- a/github_tracker_bot/mongo_data_handler.py +++ b/github_tracker_bot/mongo_data_handler.py @@ -665,4 +665,4 @@ def delete_ai_decisions_and_clean_users( logger.error( f"Failed to delete ai_decisions and clean users between {since_date} and {until_date}: {e}" ) - raise \ No newline at end of file + raise diff --git a/tests/test_process_commits.py b/tests/test_process_commits.py index 9ad371d..ff06528 100644 --- a/tests/test_process_commits.py +++ b/tests/test_process_commits.py @@ -112,7 +112,7 @@ async def test_handle_403_api_rate_limit(self, mock_get, mock_time, mock_sleep): # 1. Once for the rate limit (expected_sleep_time) # 2. Once for the tenacity retry (fixed 2 seconds) self.assertEqual(mock_sleep.call_count, 2) - mock_sleep.assert_has_calls([call(expected_sleep_time), call(2.0)]) + mock_sleep.assert_has_calls([call(expected_sleep_time), call(5.0)]) # Ensure that the second call to `aiohttp.get` was successful self.assertEqual(mock_get.call_count, 2)