Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Authenticator refactoring (preparation for app token refreshing) #281

Merged
merged 5 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 74 additions & 53 deletions tap_github/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,36 @@

import logging
import time
from copy import deepcopy
from datetime import datetime
from os import environ
from random import choice, shuffle
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set

import jwt
import requests
from singer_sdk.authenticators import APIAuthenticatorBase
from singer_sdk.streams import RESTStream


class TokenRateLimit:
"""A class to store token rate limiting information."""
class TokenManager:
"""A class to store a token's attributes and state."""

DEFAULT_RATE_LIMIT = 5000
# The DEFAULT_RATE_LIMIT_BUFFER buffer serves two purposes:
# - keep some leeway and rotate tokens before erroring out on rate limit.
# - not consume all available calls when we are using an org or user token.
DEFAULT_RATE_LIMIT_BUFFER = 1000

def __init__(self, token: str, rate_limit_buffer: Optional[int] = None):
"""Init TokenRateLimit info."""
def __init__(
self,
token: str,
rate_limit_buffer: Optional[int] = None,
logger: Optional[Any] = None,
):
"""Init TokenManager info."""
self.token = token
self.logger = logger
self.rate_limit = self.DEFAULT_RATE_LIMIT
self.rate_limit_remaining = self.DEFAULT_RATE_LIMIT
self.rate_limit_reset: Optional[int] = None
Expand All @@ -41,7 +48,28 @@ def update_rate_limit(self, response_headers: Any) -> None:
self.rate_limit_reset = int(response_headers["X-RateLimit-Reset"])
self.rate_limit_used = int(response_headers["X-RateLimit-Used"])

def is_valid(self) -> bool:
def is_valid_token(self) -> bool:
    """Probe GitHub's ``/rate_limit`` endpoint using this token.

    Returns:
        True when the request completes without an HTTP error status.
        False when ``raise_for_status`` raises ``HTTPError``; in that case a
        warning describing the response is logged if a logger was supplied.
    """
    auth_headers = {
        "Authorization": f"token {self.token}",
    }
    try:
        response = requests.get(
            url="https://api.github.com/rate_limit",
            headers=auth_headers,
        )
        response.raise_for_status()
    except requests.exceptions.HTTPError:
        # The message is always built, but only emitted when a logger exists.
        details = (
            f"A token was dismissed. "
            f"{response.status_code} Client Error: "
            f"{str(response.content)} (Reason: {response.reason})"
        )
        if self.logger is not None:
            self.logger.warning(details)
        return False
    return True

def has_calls_remaining(self) -> bool:
"""Check if token is valid.

Returns:
Expand Down Expand Up @@ -113,25 +141,37 @@ def generate_app_access_token(
class GitHubTokenAuthenticator(APIAuthenticatorBase):
"""Base class for offloading API auth."""

def prepare_tokens(self) -> Dict[str, TokenRateLimit]:
def prepare_tokens(self) -> List[TokenManager]:
# Save GitHub tokens
available_tokens: List[str] = []
rate_limit_buffer = self._config.get("rate_limit_buffer", None)

personal_tokens: Set[str] = set()
edgarrmondragon marked this conversation as resolved.
Show resolved Hide resolved
if "auth_token" in self._config:
available_tokens = available_tokens + [self._config["auth_token"]]
personal_tokens.add(self._config["auth_token"])
if "additional_auth_tokens" in self._config:
available_tokens = available_tokens + self._config["additional_auth_tokens"]
personal_tokens = personal_tokens.union(
self._config["additional_auth_tokens"]
)
else:
# Accept multiple tokens using environment variables GITHUB_TOKEN*
env_tokens = [
env_tokens = {
value
for key, value in environ.items()
if key.startswith("GITHUB_TOKEN")
]
}
if len(env_tokens) > 0:
self.logger.info(
f"Found {len(env_tokens)} 'GITHUB_TOKEN' environment variables for authentication."
)
available_tokens = env_tokens
personal_tokens = env_tokens

token_managers: List[TokenManager] = []
edgarrmondragon marked this conversation as resolved.
Show resolved Hide resolved
for token in personal_tokens:
token_manager = TokenManager(
token, rate_limit_buffer=rate_limit_buffer, logger=self.logger
)
if token_manager.is_valid_token():
token_managers.append(token_manager)

# Parse App level private key and generate a token
if "GITHUB_APP_PRIVATE_KEY" in environ.keys():
Expand All @@ -152,39 +192,17 @@ def prepare_tokens(self) -> Dict[str, TokenRateLimit]:
app_token = generate_app_access_token(
github_app_id, github_private_key, github_installation_id or None
)
available_tokens = available_tokens + [app_token]

# Get rate_limit_buffer
rate_limit_buffer = self._config.get("rate_limit_buffer", None)

# Dedup tokens and test them
filtered_tokens = []
for token in list(set(available_tokens)):
try:
response = requests.get(
url="https://api.github.com/rate_limit",
headers={
"Authorization": f"token {token}",
},
)
response.raise_for_status()
filtered_tokens.append(token)
except requests.exceptions.HTTPError:
msg = (
f"A token was dismissed. "
f"{response.status_code} Client Error: "
f"{str(response.content)} (Reason: {response.reason})"
token_manager = TokenManager(
app_token, rate_limit_buffer=rate_limit_buffer, logger=self.logger
)
self.logger.warning(msg)
if token_manager.is_valid_token():
token_managers.append(token_manager)

self.logger.info(f"Tap will run with {len(filtered_tokens)} auth tokens")
self.logger.info(f"Tap will run with {len(token_managers)} auth tokens")

# Create a dict of TokenRateLimit
# TODO - separate app_token and add logic to refresh the token
# using generate_app_access_token.
return {
token: TokenRateLimit(token, rate_limit_buffer) for token in filtered_tokens
}
# Return the list of validated TokenManager instances
# TODO - separate app_token and add logic to refresh the token using generate_app_access_token.
return token_managers

def __init__(self, stream: RESTStream) -> None:
"""Init authenticator.
Expand All @@ -196,18 +214,21 @@ def __init__(self, stream: RESTStream) -> None:
self.logger: logging.Logger = stream.logger
self.tap_name: str = stream.tap_name
self._config: Dict[str, Any] = dict(stream.config)
self.tokens_map = self.prepare_tokens()
self.active_token: Optional[TokenRateLimit] = (
choice(list(self.tokens_map.values())) if len(self.tokens_map) else None
self.token_managers = self.prepare_tokens()
self.active_token: Optional[TokenManager] = (
choice(self.token_managers) if self.token_managers else None
)

def get_next_auth_token(self) -> None:
tokens_list = list(self.tokens_map.items())
current_token = self.active_token.token if self.active_token else ""
shuffle(tokens_list)
for _, token_rate_limit in tokens_list:
if token_rate_limit.is_valid() and current_token != token_rate_limit.token:
self.active_token = token_rate_limit
token_managers = deepcopy(self.token_managers)
shuffle(token_managers)
edgarrmondragon marked this conversation as resolved.
Show resolved Hide resolved
edgarrmondragon marked this conversation as resolved.
Show resolved Hide resolved
for token_manager in token_managers:
if (
token_manager.has_calls_remaining()
and current_token != token_manager.token
):
self.active_token = token_manager
self.logger.info(f"Switching to fresh auth token")
return

Expand All @@ -219,7 +240,7 @@ def update_rate_limit(
self, response_headers: requests.models.CaseInsensitiveDict
) -> None:
# If no token or only one token is available, return early.
if len(self.tokens_map) <= 1 or self.active_token is None:
if len(self.token_managers) <= 1 or self.active_token is None:
return

self.active_token.update_rate_limit(response_headers)
Expand All @@ -236,7 +257,7 @@ def auth_headers(self) -> Dict[str, str]:
result = super().auth_headers
if self.active_token:
# Make sure that our token is still valid or update it.
if not self.active_token.is_valid():
if not self.active_token.has_calls_remaining():
self.get_next_auth_token()
result["Authorization"] = f"token {self.active_token.token}"
else:
Expand Down
114 changes: 114 additions & 0 deletions tap_github/tests/test_authenticator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

import pytest
import requests

from tap_github.authenticator import TokenManager


class TestTokenManager:
    """Tests for TokenManager: defaults, header parsing, and validity checks."""

    @staticmethod
    def _rate_limit_headers(remaining, used, reset="1372700873"):
        """Build a fake rate-limit header mapping with string values."""
        return {
            "X-RateLimit-Limit": "5000",
            "X-RateLimit-Remaining": str(remaining),
            "X-RateLimit-Reset": str(reset),
            "X-RateLimit-Used": str(used),
        }

    def test_default_rate_limits(self):
        """A fresh manager exposes the default limits; the buffer can be overridden."""
        custom_buffer = TokenManager("mytoken", rate_limit_buffer=700)

        assert custom_buffer.rate_limit == 5000
        assert custom_buffer.rate_limit_remaining == 5000
        assert custom_buffer.rate_limit_reset is None
        assert custom_buffer.rate_limit_used == 0
        assert custom_buffer.rate_limit_buffer == 700

        default_buffer = TokenManager("mytoken")
        assert default_buffer.rate_limit_buffer == 1000

    def test_update_rate_limit(self):
        """Header values are parsed into integer attributes."""
        manager = TokenManager("mytoken")
        manager.update_rate_limit(self._rate_limit_headers(remaining=4999, used=1))

        assert manager.rate_limit == 5000
        assert manager.rate_limit_remaining == 4999
        assert manager.rate_limit_reset == 1372700873
        assert manager.rate_limit_used == 1

    def test_is_valid_token_successful(self):
        """A response that does not raise makes the token valid."""
        with patch("requests.get") as fake_get:
            fake_get.return_value.raise_for_status.return_value = None

            manager = TokenManager("validtoken")

            assert manager.is_valid_token()
            fake_get.assert_called_once_with(
                url="https://api.github.com/rate_limit",
                headers={"Authorization": "token validtoken"},
            )

    def test_is_valid_token_failure(self):
        """An HTTPError marks the token invalid and logs a single warning."""
        with patch("requests.get") as fake_get:
            # Configure the mocked response to fail on raise_for_status().
            failing_response = fake_get.return_value
            failing_response.raise_for_status.side_effect = (
                requests.exceptions.HTTPError()
            )
            failing_response.status_code = 401
            failing_response.content = b"Unauthorized Access"
            failing_response.reason = "Unauthorized"

            manager = TokenManager("invalidtoken")
            manager.logger = MagicMock()

            assert not manager.is_valid_token()
            manager.logger.warning.assert_called_once()
            assert "401" in manager.logger.warning.call_args[0][0]

    def test_has_calls_remaining_succeeds_if_token_never_used(self):
        """A manager that never saw response headers still has calls remaining."""
        manager = TokenManager("mytoken")
        assert manager.has_calls_remaining()

    def test_has_calls_remaining_succeeds_if_lots_remaining(self):
        """Plenty of remaining calls keeps the token usable."""
        manager = TokenManager("mytoken")
        manager.update_rate_limit(self._rate_limit_headers(remaining=4999, used=1))

        assert manager.has_calls_remaining()

    def test_has_calls_remaining_succeeds_if_reset_time_reached(self):
        """Few calls left, but the reset timestamp lies in the past."""
        manager = TokenManager("mytoken", rate_limit_buffer=1000)
        manager.update_rate_limit(self._rate_limit_headers(remaining=1, used=4999))

        assert manager.has_calls_remaining()

    def test_has_calls_remaining_fails_if_few_calls_remaining_and_reset_time_not_reached(
        self,
    ):
        """Few calls left and a reset time far in the future: token is exhausted."""
        future_reset = str(int((datetime.now() + timedelta(days=100)).timestamp()))

        manager = TokenManager("mytoken", rate_limit_buffer=1000)
        manager.update_rate_limit(
            self._rate_limit_headers(remaining=1, used=4999, reset=future_reset)
        )

        assert not manager.has_calls_remaining()