diff --git a/tap_github/authenticator.py b/tap_github/authenticator.py index 7c9528d6..13b20bde 100644 --- a/tap_github/authenticator.py +++ b/tap_github/authenticator.py @@ -2,10 +2,11 @@ import logging import time +from copy import deepcopy from datetime import datetime from os import environ from random import choice, shuffle -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import jwt import requests @@ -13,8 +14,8 @@ from singer_sdk.streams import RESTStream -class TokenRateLimit: - """A class to store token rate limiting information.""" +class TokenManager: + """A class to store a token's attributes and state.""" DEFAULT_RATE_LIMIT = 5000 # The DEFAULT_RATE_LIMIT_BUFFER buffer serves two purposes: @@ -22,9 +23,15 @@ class TokenRateLimit: # - not consume all available calls when we rare using an org or user token. DEFAULT_RATE_LIMIT_BUFFER = 1000 - def __init__(self, token: str, rate_limit_buffer: Optional[int] = None): - """Init TokenRateLimit info.""" + def __init__( + self, + token: str, + rate_limit_buffer: Optional[int] = None, + logger: Optional[Any] = None, + ): + """Init TokenManager info.""" self.token = token + self.logger = logger self.rate_limit = self.DEFAULT_RATE_LIMIT self.rate_limit_remaining = self.DEFAULT_RATE_LIMIT self.rate_limit_reset: Optional[int] = None @@ -41,7 +48,28 @@ def update_rate_limit(self, response_headers: Any) -> None: self.rate_limit_reset = int(response_headers["X-RateLimit-Reset"]) self.rate_limit_used = int(response_headers["X-RateLimit-Used"]) - def is_valid(self) -> bool: + def is_valid_token(self) -> bool: + """Try making a request with the current token. If the request succeeds return True, else False.""" + try: + response = requests.get( + url="https://api.github.com/rate_limit", + headers={ + "Authorization": f"token {self.token}", + }, + ) + response.raise_for_status() + return True + except requests.exceptions.HTTPError: + msg = ( + f"A token was dismissed. " + f"{response.status_code} Client Error: " + f"{str(response.content)} (Reason: {response.reason})" + ) + if self.logger is not None: + self.logger.warning(msg) + return False + + def has_calls_remaining(self) -> bool: """Check if token is valid. Returns: @@ -113,25 +141,37 @@ def generate_app_access_token( class GitHubTokenAuthenticator(APIAuthenticatorBase): """Base class for offloading API auth.""" - def prepare_tokens(self) -> Dict[str, TokenRateLimit]: + def prepare_tokens(self) -> List[TokenManager]: # Save GitHub tokens - available_tokens: List[str] = [] + rate_limit_buffer = self._config.get("rate_limit_buffer", None) + + personal_tokens: Set[str] = set() if "auth_token" in self._config: - available_tokens = available_tokens + [self._config["auth_token"]] + personal_tokens.add(self._config["auth_token"]) if "additional_auth_tokens" in self._config: - available_tokens = available_tokens + self._config["additional_auth_tokens"] + personal_tokens = personal_tokens.union( + self._config["additional_auth_tokens"] + ) else: # Accept multiple tokens using environment variables GITHUB_TOKEN* - env_tokens = [ + env_tokens = { value for key, value in environ.items() if key.startswith("GITHUB_TOKEN") - ] + } if len(env_tokens) > 0: self.logger.info( f"Found {len(env_tokens)} 'GITHUB_TOKEN' environment variables for authentication." ) - available_tokens = env_tokens + personal_tokens = env_tokens + + token_managers: List[TokenManager] = [] + for token in personal_tokens: + token_manager = TokenManager( + token, rate_limit_buffer=rate_limit_buffer, logger=self.logger + ) + if token_manager.is_valid_token(): + token_managers.append(token_manager) # Parse App level private key and generate a token if "GITHUB_APP_PRIVATE_KEY" in environ.keys(): @@ -152,39 +192,17 @@ def prepare_tokens(self) -> Dict[str, TokenRateLimit]: app_token = generate_app_access_token( github_app_id, github_private_key, github_installation_id or None ) - available_tokens = available_tokens + [app_token] - - # Get rate_limit_buffer - rate_limit_buffer = self._config.get("rate_limit_buffer", None) - - # Dedup tokens and test them - filtered_tokens = [] - for token in list(set(available_tokens)): - try: - response = requests.get( - url="https://api.github.com/rate_limit", - headers={ - "Authorization": f"token {token}", - }, - ) - response.raise_for_status() - filtered_tokens.append(token) - except requests.exceptions.HTTPError: - msg = ( - f"A token was dismissed. " - f"{response.status_code} Client Error: " - f"{str(response.content)} (Reason: {response.reason})" + token_manager = TokenManager( + app_token, rate_limit_buffer=rate_limit_buffer, logger=self.logger ) - self.logger.warning(msg) + if token_manager.is_valid_token(): + token_managers.append(token_manager) - self.logger.info(f"Tap will run with {len(filtered_tokens)} auth tokens") + self.logger.info(f"Tap will run with {len(token_managers)} auth tokens") - # Create a dict of TokenRateLimit - # TODO - separate app_token and add logic to refresh the token - # using generate_app_access_token. - return { - token: TokenRateLimit(token, rate_limit_buffer) for token in filtered_tokens - } + # Create a dict of TokenManager + # TODO - separate app_token and add logic to refresh the token using generate_app_access_token. + return token_managers def __init__(self, stream: RESTStream) -> None: """Init authenticator. @@ -196,18 +214,21 @@ def __init__(self, stream: RESTStream) -> None: self.logger: logging.Logger = stream.logger self.tap_name: str = stream.tap_name self._config: Dict[str, Any] = dict(stream.config) - self.tokens_map = self.prepare_tokens() - self.active_token: Optional[TokenRateLimit] = ( - choice(list(self.tokens_map.values())) if len(self.tokens_map) else None + self.token_managers = self.prepare_tokens() + self.active_token: Optional[TokenManager] = ( + choice(self.token_managers) if self.token_managers else None ) def get_next_auth_token(self) -> None: - tokens_list = list(self.tokens_map.items()) current_token = self.active_token.token if self.active_token else "" - shuffle(tokens_list) - for _, token_rate_limit in tokens_list: - if token_rate_limit.is_valid() and current_token != token_rate_limit.token: - self.active_token = token_rate_limit + token_managers = deepcopy(self.token_managers) + shuffle(token_managers) + for token_manager in token_managers: + if ( + token_manager.has_calls_remaining() + and current_token != token_manager.token + ): + self.active_token = token_manager self.logger.info(f"Switching to fresh auth token") return @@ -219,7 +240,7 @@ def update_rate_limit( self, response_headers: requests.models.CaseInsensitiveDict ) -> None: # If no token or only one token is available, return early. - if len(self.tokens_map) <= 1 or self.active_token is None: + if len(self.token_managers) <= 1 or self.active_token is None: return self.active_token.update_rate_limit(response_headers) @@ -236,7 +257,7 @@ def auth_headers(self) -> Dict[str, str]: result = super().auth_headers if self.active_token: # Make sure that our token is still valid or update it. - if not self.active_token.is_valid(): + if not self.active_token.has_calls_remaining(): self.get_next_auth_token() result["Authorization"] = f"token {self.active_token.token}" else: diff --git a/tap_github/tests/test_authenticator.py b/tap_github/tests/test_authenticator.py new file mode 100644 index 00000000..c63853ca --- /dev/null +++ b/tap_github/tests/test_authenticator.py @@ -0,0 +1,114 @@ +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from tap_github.authenticator import TokenManager + + +class TestTokenManager: + + def test_default_rate_limits(self): + token_manager = TokenManager("mytoken", rate_limit_buffer=700) + + assert token_manager.rate_limit == 5000 + assert token_manager.rate_limit_remaining == 5000 + assert token_manager.rate_limit_reset is None + assert token_manager.rate_limit_used == 0 + assert token_manager.rate_limit_buffer == 700 + + token_manager_2 = TokenManager("mytoken") + assert token_manager_2.rate_limit_buffer == 1000 + + def test_update_rate_limit(self): + mock_response_headers = { + "X-RateLimit-Limit": "5000", + "X-RateLimit-Remaining": "4999", + "X-RateLimit-Reset": "1372700873", + "X-RateLimit-Used": "1", + } + + token_manager = TokenManager("mytoken") + token_manager.update_rate_limit(mock_response_headers) + + assert token_manager.rate_limit == 5000 + assert token_manager.rate_limit_remaining == 4999 + assert token_manager.rate_limit_reset == 1372700873 + assert token_manager.rate_limit_used == 1 + + def test_is_valid_token_successful(self): + with patch("requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.raise_for_status.return_value = None + + token_manager = TokenManager("validtoken") + + assert token_manager.is_valid_token() + mock_get.assert_called_once_with( + url="https://api.github.com/rate_limit", + headers={"Authorization": "token validtoken"}, + ) + + def test_is_valid_token_failure(self): + with patch("requests.get") as mock_get: + # Setup for a failed request + mock_response = mock_get.return_value + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError() + mock_response.status_code = 401 + mock_response.content = b"Unauthorized Access" + mock_response.reason = "Unauthorized" + + token_manager = TokenManager("invalidtoken") + token_manager.logger = MagicMock() + + assert not token_manager.is_valid_token() + token_manager.logger.warning.assert_called_once() + assert "401" in token_manager.logger.warning.call_args[0][0] + + def test_has_calls_remaining_succeeds_if_token_never_used(self): + token_manager = TokenManager("mytoken") + assert token_manager.has_calls_remaining() + + def test_has_calls_remaining_succeeds_if_lots_remaining(self): + mock_response_headers = { + "X-RateLimit-Limit": "5000", + "X-RateLimit-Remaining": "4999", + "X-RateLimit-Reset": "1372700873", + "X-RateLimit-Used": "1", + } + + token_manager = TokenManager("mytoken") + token_manager.update_rate_limit(mock_response_headers) + + assert token_manager.has_calls_remaining() + + def test_has_calls_remaining_succeeds_if_reset_time_reached(self): + mock_response_headers = { + "X-RateLimit-Limit": "5000", + "X-RateLimit-Remaining": "1", + "X-RateLimit-Reset": "1372700873", + "X-RateLimit-Used": "4999", + } + + token_manager = TokenManager("mytoken", rate_limit_buffer=1000) + token_manager.update_rate_limit(mock_response_headers) + + assert token_manager.has_calls_remaining() + + def test_has_calls_remaining_fails_if_few_calls_remaining_and_reset_time_not_reached( + self, + ): + mock_response_headers = { + "X-RateLimit-Limit": "5000", + "X-RateLimit-Remaining": "1", + "X-RateLimit-Reset": str( + int((datetime.now() + timedelta(days=100)).timestamp()) + ), + "X-RateLimit-Used": "4999", + } + + token_manager = TokenManager("mytoken", rate_limit_buffer=1000) + token_manager.update_rate_limit(mock_response_headers) + + assert not token_manager.has_calls_remaining()