Skip to content

Commit

Permalink
Add throttling when rate limiting
Browse files Browse the repository at this point in the history
Change-type: minor
  • Loading branch information
otaviojacobi committed Oct 31, 2023
1 parent c35338c commit 5cac810
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 116 deletions.
190 changes: 99 additions & 91 deletions DOCUMENTATION.md

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions balena/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"timeout": str(30 * 1000), # request timeout, 30s
"request_limit": str(300), # the number of requests per request_limit_interval that the SDK should respect, defaults to unlimited.
"request_limit_interval": str(60), # the timespan that the request_limit should apply to in seconds, defaults to 60s (1 minute).
"retry_rate_limited_request": False, # awaits and retry once a request is rate limited (429)
})
```
Expand Down Expand Up @@ -65,6 +66,13 @@
})
```
By default the SDK will throw once a request is Rate limited by the API (with a 429 status code).
A 429 request will contain a header called "retry-after" which informs how long the client should wait before trying a new request.
If you would like the SDK to use this header and wait and automatically retry the request, just do:
```python
balena = Balena({"retry_rate_limited_request": True})
```
If you feel something is missing, not clear or could be improved, [please don't
hesitate to open an issue in GitHub](https://github.com/balena-io/balena-sdk-python/issues), we'll be happy to help.
Expand Down
2 changes: 1 addition & 1 deletion balena/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def get_token(self) -> Optional[str]:
>>> balena.auth.get_token()
"""
try:
return self.__settings.get(TOKEN_KEY)
return cast(str, self.__settings.get(TOKEN_KEY))
except exceptions.InvalidOption:
return None

Expand Down
4 changes: 2 additions & 2 deletions balena/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from threading import Thread
from urllib.parse import urljoin
from typing import Union, Optional, Literal, Callable, TypedDict, Any, List
from typing import Union, Optional, Literal, Callable, TypedDict, Any, List, cast

from twisted.internet import reactor
from twisted.internet.protocol import Protocol
Expand Down Expand Up @@ -89,7 +89,7 @@ def add(
if count:
query = f"stream=1&count={count}"

url = urljoin(self.__settings.get("api_endpoint"), f"/device/v2/{uuid}/logs?{query}")
url = urljoin(cast(str, self.__settings.get("api_endpoint")), f"/device/v2/{uuid}/logs?{query}")
headers = Headers({"Authorization": [f"Bearer {get_token(self.__settings)}"]})

agent = Agent(reactor)
Expand Down
4 changes: 2 additions & 2 deletions balena/models/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from math import isinf
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional, Union, cast
from urllib.parse import urljoin

from .. import exceptions
Expand Down Expand Up @@ -136,7 +136,7 @@ def get_dashboard_url(self, app_id: int) -> str:
raise exceptions.InvalidParameter("app_id", app_id)

return urljoin(
self.__settings.get("api_endpoint").replace("api", "dashboard"),
cast(str, self.__settings.get("api_endpoint")).replace("api", "dashboard"),
f"/apps/{app_id}",
)

Expand Down
4 changes: 2 additions & 2 deletions balena/models/device.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import binascii
import datetime
import os
from typing import Any, Callable, List, Optional, TypedDict, Union
from typing import Any, Callable, List, Optional, TypedDict, Union, cast
from urllib.parse import urljoin

from deprecated import deprecated
Expand Down Expand Up @@ -266,7 +266,7 @@ def get_dashboard_url(self, uuid: str):

if not isinstance(uuid, str) or len(uuid) == 0:
raise ValueError("Device UUID must be a non empty string")
dashboard_url = self.__settings.get("api_endpoint").replace("api", "dashboard")
dashboard_url = cast(str, self.__settings.get("api_endpoint")).replace("api", "dashboard")
return urljoin(dashboard_url, f"/devices/{uuid}/summary")

def get_all(self, options: AnyObject = {}) -> List[TypeDevice]:
Expand Down
17 changes: 14 additions & 3 deletions balena/pine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from typing import Any, Optional
from typing import Any, Optional, cast
from urllib.parse import urljoin
from ratelimit import limits, sleep_and_retry
from time import sleep

import requests
from pine_client import PinejsClientCore
Expand All @@ -20,8 +21,8 @@ def __init__(self, settings: Settings, sdk_version: str, params: Optional[Params
self.__settings = settings
self.__sdk_version = sdk_version

api_url = settings.get("api_endpoint")
api_version = settings.get("api_version")
api_url = cast(str, settings.get("api_endpoint"))
api_version = cast(str, settings.get("api_version"))

try:
calls = int(self.__settings.get("request_limit"))
Expand Down Expand Up @@ -55,4 +56,14 @@ def __base_request(self, method: str, url: str, body: Optional[Any] = None) -> A
except Exception:
return req.content.decode()
else:
retry_after = req.headers.get("retry-after")
if (
self.__settings.get("retry_rate_limited_request") is True
and req.status_code == 429
and retry_after is not None
and retry_after.isdigit()
):
sleep(int(retry_after))
return self.__base_request(method, url, body)

raise RequestError(body=req.content.decode(), status_code=req.status_code)
36 changes: 23 additions & 13 deletions balena/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class SettingsConfig(TypedDict, total=False):
timeout: str
request_limit: str
request_limit_interval: str
retry_rate_limited_request: bool


class SettingsProviderInterface(ABC):
Expand All @@ -29,15 +30,15 @@ def has(self, key: str) -> bool:
pass

@abstractmethod
def get(self, key: str) -> str:
def get(self, key: str) -> Union[str, bool]:
pass

@abstractmethod
def get_all(self) -> Dict[str, str]:
def get_all(self) -> Dict[str, Union[str, bool]]:
pass

@abstractmethod
def set(self, key: str, value: str) -> None:
def set(self, key: str, value: Union[str, bool]) -> None:
pass

@abstractmethod
Expand All @@ -58,6 +59,7 @@ def remove(self, key: str) -> bool:
"timeout": str(30 * 1000),
# requests timeout: 60 seconds in seconds
"request_limit_interval": str(60),
"retry_rate_limited_request": False,
}


Expand Down Expand Up @@ -88,6 +90,7 @@ class FileStorageSettingsProvider(SettingsProviderInterface):
"cache_directory",
"timeout",
"device_actions_endpoint_version",
"retry_rate_limited_request",
]
)

Expand Down Expand Up @@ -146,7 +149,10 @@ def __write_settings(self, default=None):
config = configparser.ConfigParser()
config.add_section(self.CONFIG_SECTION)
for key in self._setting:
config.set(self.CONFIG_SECTION, key, self._setting[key])
value = self._setting[key]
if isinstance(value, bool):
value = "true" if value else "false"
config.set(self.CONFIG_SECTION, key, value)
if not Path.isdir(self._setting["data_directory"]):
os.makedirs(self._setting["data_directory"])
with open(Path.join(self._setting["data_directory"], self.CONFIG_FILENAME), "w") as config_file:
Expand All @@ -160,6 +166,10 @@ def __read_settings(self):
for option in options:
try:
config_data[option] = config_reader.get(self.CONFIG_SECTION, option)
if config_data[option] == "true":
config_data[option] = True
if config_data[option] == "false":
config_data[option] = False
except Exception:
config_data[option] = None
self._setting = config_data
Expand All @@ -170,18 +180,18 @@ def has(self, key: str) -> bool:
return True
return False

def get(self, key: str) -> str:
def get(self, key: str) -> Union[str, bool]:
try:
self.__read_settings()
return self._setting[key]
except KeyError:
raise exceptions.InvalidOption(key)

def get_all(self) -> Dict[str, str]:
def get_all(self) -> Dict[str, Union[str, bool]]:
self.__read_settings()
return self._setting

def set(self, key: str, value: str) -> None:
def set(self, key: str, value: Union[str, bool]) -> None:
self._setting[key] = str(value)
self.__write_settings()

Expand Down Expand Up @@ -211,16 +221,16 @@ def has(self, key: str) -> bool:
return True
return False

def get(self, key: str) -> str:
def get(self, key: str) -> Union[str, bool]:
try:
return self._settings[key]
except KeyError:
raise exceptions.InvalidOption(key)

def get_all(self) -> Dict[str, str]:
def get_all(self) -> Dict[str, Union[str, bool]]:
return self._settings

def set(self, key: str, value: str) -> None:
def set(self, key: str, value: Union[str, bool]) -> None:
self._settings[key] = str(value)

def remove(self, key: str) -> bool:
Expand Down Expand Up @@ -259,7 +269,7 @@ def has(self, key: str) -> bool:
"""
return self.__settings_provider.has(key)

def get(self, key: str) -> str:
def get(self, key: str) -> Union[str, bool]:
"""
Get a setting value.
Expand All @@ -277,7 +287,7 @@ def get(self, key: str) -> str:
"""
return self.__settings_provider.get(key)

def get_all(self) -> Dict[str, str]:
def get_all(self) -> Dict[str, Union[str, bool]]:
"""
Get all settings.
Expand All @@ -289,7 +299,7 @@ def get_all(self) -> Dict[str, str]:
"""
return self.__settings_provider.get_all()

def set(self, key: str, value: str) -> None:
def set(self, key: str, value: Union[str, bool]) -> None:
"""
Set value for a setting.
Expand Down
5 changes: 3 additions & 2 deletions balena/twofactor_auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from urllib.parse import parse_qs
from typing import cast
import jwt

from . import exceptions
Expand Down Expand Up @@ -28,7 +29,7 @@ def is_enabled(self) -> bool:
>>> balena.twofactor_auth.is_enabled()
"""
try:
token = self.__settings.get(TOKEN_KEY)
token = cast(str, self.__settings.get(TOKEN_KEY))
token_data = jwt.decode(token, algorithms=["HS256"], options={"verify_signature": False})
return "twoFactorRequired" in token_data
except jwt.InvalidTokenError:
Expand All @@ -46,7 +47,7 @@ def is_passed(self) -> bool:
>>> balena.twofactor_auth.is_passed()
"""
try:
token = self.__settings.get(TOKEN_KEY)
token = cast(str, self.__settings.get(TOKEN_KEY))
token_data = jwt.decode(token, algorithms=["HS256"], options={"verify_signature": False})
if "twoFactorRequired" in token_data:
return not token_data["twoFactorRequired"]
Expand Down

0 comments on commit 5cac810

Please sign in to comment.