From 0f4b69221f13ba355253cf27df7b48c0f7fe3685 Mon Sep 17 00:00:00 2001 From: Evan Blaudy Date: Fri, 13 Dec 2024 15:35:08 +0100 Subject: [PATCH] [client] better handling for refresh_token --- gazu/__init__.py | 4 +- gazu/client.py | 246 ++++++++++++++++++++++++++++++++++--------- gazu/events.py | 10 +- setup.cfg | 2 - tests/test_client.py | 107 ++++--------------- 5 files changed, 224 insertions(+), 145 deletions(-) diff --git a/gazu/__init__.py b/gazu/__init__.py index c7ea0a7..1bd44b3 100644 --- a/gazu/__init__.py +++ b/gazu/__init__.py @@ -87,8 +87,8 @@ def log_out(client=raw.default_client): return tokens -def refresh_token(client=raw.default_client): - return client.refresh_authentication_tokens() +def refresh_access_token(client=raw.default_client): + return client.refresh_access_token() def get_event_host(client=raw.default_client): diff --git a/gazu/client.py b/gazu/client.py index 51379e7..2e37fa4 100644 --- a/gazu/client.py +++ b/gazu/client.py @@ -3,8 +3,6 @@ import json import shutil import os -import jwt -from datetime import datetime from .encoder import CustomJSONEncoder @@ -38,17 +36,25 @@ def __init__( host, ssl_verify=True, cert=None, - automatic_refresh_token=False, + use_refresh_token=True, callback_not_authenticated=None, + tokens={"access_token": None, "refresh_token": None}, + access_token=None, + refresh_token=None, ): - self.tokens = {"access_token": "", "refresh_token": ""} + self.tokens = tokens + if access_token: + self.access_token = access_token + if refresh_token: + self.refresh_token = refresh_token + self.use_refresh_token = use_refresh_token + self.callback_not_authenticated = callback_not_authenticated + self.session = requests.Session() self.session.verify = ssl_verify self.session.cert = cert self.host = host self.event_host = host - self.automatic_refresh_token = automatic_refresh_token - self.callback_not_authenticated = callback_not_authenticated @property def access_token(self): @@ -62,48 +68,42 @@ def access_token(self, token): def refresh_token(self): return self.tokens.get("refresh_token", None) - @property - def access_token_has_expired(self): - """ Returns: Whether this client's access token needs to be refreshed. """ - # Python 2 is too deprecated to support this feature. (Lack of datetime.timestamp()) - if sys.version_info.major == 2: - return False - - if not self.access_token: - # No access token is present, refresh only when able with a refresh token. - return True if self.refresh_token else False + @refresh_token.setter + def refresh_token(self, token): + self.tokens["refresh_token"] = token - # Decode the access token. - decoded_token = jwt.decode(jwt=self.access_token, options={'verify_signature': False}) - expiration_datetime = datetime.fromtimestamp(decoded_token['exp']) + def refresh_access_token(self): + """ + Refresh access tokens for this client. - # NOTE - Due to shenanigans caused by possible timezone differences - # between server and client, the access token is considered - # stale when less than 24 hours remain to its expiration. - - return bool((expiration_datetime - datetime.now()).days < 1) - - def refresh_authentication_tokens(self): - """ Refresh access tokens for this client.""" + Returns: + dict: The new access token. + """ response = self.session.get( get_full_url("auth/refresh-token", client=self), - headers={"User-Agent": "CGWire Gazu " + __version__, - "Authorization": "Bearer " + self.refresh_token}) + headers={ + "User-Agent": "CGWire Gazu " + __version__, + "Authorization": "Bearer " + self.refresh_token, + }, + ) check_status(response, "auth/refresh-token") tokens = response.json() self.access_token = tokens["access_token"] + self.refresh_token = None return tokens def make_auth_header(self): - """ Returns: Headers required to authenticate. """ + """ + Make headers required to authenticate. + + Returns: + dict: Headers required to authenticate. + """ headers = {"User-Agent": "CGWire Gazu " + __version__} if self.access_token: - if self.access_token_has_expired and self.automatic_refresh_token: - self.refresh_authentication_tokens() - headers["Authorization"] = "Bearer " + self.access_token return headers @@ -113,15 +113,30 @@ def create_client( host, ssl_verify=True, cert=None, - automatic_refresh_token=False, + use_refresh_token=False, callback_not_authenticated=None, + **kwargs ): + """ + Create a client with given parameters. + + Args: + host (str): The host to use for requests. + ssl_verify (bool): Whether to verify SSL certificates. + cert (str): Path to a client certificate. + use_refresh_token (bool): Whether to automatically refresh tokens. + callback_not_authenticated (function): Function to call when not authenticated. + + Returns: + KitsuClient: The created client. + """ return KitsuClient( host, ssl_verify, cert=cert, - automatic_refresh_token=automatic_refresh_token, + use_refresh_token=use_refresh_token, callback_not_authenticated=callback_not_authenticated, + **kwargs ) @@ -141,8 +156,13 @@ def create_client( def host_is_up(client=default_client): """ + Check if the host is up. + + Args: + client (KitsuClient): The client to use for the request. + Returns: - True if the host is up. + bool: True if the host is up. """ try: response = client.session.head(client.host) @@ -154,8 +174,12 @@ def host_is_up(client=default_client): def host_is_valid(client=default_client): """ Check if the host is valid by simulating a fake login. + + Args: + client (KitsuClient): The client to use for the request. + Returns: - True if the host is valid. + bool: True if the host is valid. """ if not host_is_up(client): return False @@ -167,14 +191,23 @@ def host_is_valid(client=default_client): def get_host(client=default_client): """ + Get client.host. + + Args: + client (KitsuClient): The client to use for the request. + Returns: - Host on which requests are sent. + str: The host of the client. """ return client.host def get_api_url_from_host(client=default_client): """ + Get the API url from the host. + + Args: + client (KitsuClient): The client to use for the request. Returns: Zou url, retrieved from host. """ @@ -183,8 +216,14 @@ def get_api_url_from_host(client=default_client): def set_host(new_host, client=default_client): """ + Set the host for the client. + + Args: + new_host (str): The new host to set. + client (KitsuClient): The client to use for the request. + Returns: - Set currently configured host on which requests are sent. + str: The new host. """ client.host = new_host return client.host @@ -192,16 +231,27 @@ def set_host(new_host, client=default_client): def get_event_host(client=default_client): """ + Get the host on which listening for events. + + Args: + client (KitsuClient): The client to use for the request. + Returns: - Host on which listening for events. + str: The event host. """ return client.event_host or client.host def set_event_host(new_host, client=default_client): """ + Set the host on which listening for events. + + Args: + new_host (str): The new host to set. + client (KitsuClient): The client to use for the request. + Returns: - Set currently configured host on which listening for events. + str: The new event host. """ client.event_host = new_host return client.event_host @@ -213,12 +263,25 @@ def set_tokens(new_tokens, client=default_client): Args: new_tokens (dict): Tokens to use for authentication. + client (KitsuClient): The client to use for the request. + + Returns: + dict: The stored tokens. """ client.tokens = new_tokens return client.tokens def make_auth_header(client=default_client): + """ + Make headers required to authenticate. + + Args: + client (KitsuClient): The client to use for the request. + + Returns: + dict: Headers required to authenticate. + """ return client.make_auth_header() @@ -229,12 +292,17 @@ def url_path_join(*items): Args: items (list): Path elements + + Returns: + str: The joined path. """ return "/".join([item.lstrip("/").rstrip("/") for item in items]) def get_full_url(path, client=default_client): """ + Join host url with given path. + Args: path (str): The path to integrate to host url. @@ -272,6 +340,12 @@ def get(path, json_response=True, params=None, client=default_client): """ Run a get request toward given path for configured host. + Args: + path (str): The path to query. + json_response (bool): Whether to return a json response. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. + Returns: The request result. """ @@ -296,6 +370,11 @@ def post(path, data, client=default_client): """ Run a post request toward given path for configured host. + Args: + path (str): The path to query. + data (dict): The data to post. + client (KitsuClient): The client to use for the request. + Returns: The request result. """ @@ -323,6 +402,11 @@ def put(path, data, client=default_client): """ Run a put request toward given path for configured host. + Args: + path (str): The path to query. + data (dict): The data to put. + client (KitsuClient): The client to use for the request. + Returns: The request result. """ @@ -344,6 +428,11 @@ def delete(path, params=None, client=default_client): """ Run a delete request toward given path for configured host. + Args: + path (str): The path to query. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. + Returns: The request result. """ @@ -373,7 +462,7 @@ def get_message_from_response( default_message: str - An optional default value to revert to if no message is detected. Returns: - The message of a given response, or a default message - if any. + str: The message to display to the user. """ message = default_message message_json = response.json() @@ -393,6 +482,8 @@ def check_status(request, path, client=None): Args: request (Request): The request to validate. + path (str): The path of the request. + client (KitsuClient): The client to use for the request. Returns: int: Status code @@ -422,8 +513,13 @@ def check_status(request, path, client=None): ) elif status_code in [401, 422]: try: - if client and client.automatic_refresh_token: - client.refresh_token() + if ( + client + and client.refresh_token + and client.use_refresh_token + and request.json()["message"] == "Signature has expired" + ): + client.refresh_access_token() return status_code, True else: raise NotAuthenticatedException(path) @@ -457,6 +553,8 @@ def fetch_all( """ Args: path (str): The path for which we want to retrieve all entries. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. paginated (bool): Will query entries page by page. limit (int): Limit the number of entries per page. @@ -500,6 +598,8 @@ def fetch_first(path, params=None, client=default_client): """ Args: path (str): The path for which we want to retrieve the first entry. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. Returns: dict: The first entry for which a model is required. @@ -519,6 +619,8 @@ def fetch_one(model_name, id, params=None, client=default_client): Args: model_name (str): Model type name. id (str): Model instance ID. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. Returns: dict: The model instance matching id and model name. @@ -533,8 +635,9 @@ def create(model_name, data, client=default_client): Create an entry for given model and data. Args: - model (str): The model type involved - data (str): The data to use for creation + model_name (str): The model type involved. + data (str): The data to use for creation. + client (KitsuClient): The client to use for the request. Returns: dict: Created entry @@ -547,9 +650,10 @@ def update(model_name, model_id, data, client=default_client): Update an entry for given model, id and data. Args: - model (str): The model type involved - id (str): The target model id - data (dict): The data to update + model_name (str): The model type involved. + model_id (str): The target model id. + data (dict): The data to update. + client (KitsuClient): The client to use for the request. Returns: dict: Updated entry @@ -566,6 +670,9 @@ def upload(path, file_path, data={}, extra_files=[], client=default_client): Args: path (str): The url path to upload file. file_path (str): The file location on the hard drive. + data (dict): The data to send with the file. + extra_files (list): List of extra files to upload. + client (KitsuClient): The client to use for the request. Returns: Response: Request response object. @@ -595,6 +702,17 @@ def upload(path, file_path, data={}, extra_files=[], client=default_client): def _build_file_dict(file_path, extra_files): + """ + Build a dictionary of files to upload. + + Args: + file_path (str): The file location on the hard drive. + extra_files (list): List of extra files to upload. + + Returns: + dict: The dictionary of files to upload. + """ + files = {"file": open(file_path, "rb")} i = 2 for file_path in extra_files: @@ -610,6 +728,8 @@ def download(path, file_path, params=None, client=default_client): Args: path (str): The url path to download file from. file_path (str): The location to store the file on the hard drive. + params (dict): The parameters to pass to the request. + client (KitsuClient): The client to use for the request. Returns: Response: Request response object. @@ -629,6 +749,14 @@ def download(path, file_path, params=None, client=default_client): def get_file_data_from_url(url, full=False, client=default_client): """ Return data found at given url. + + Args: + url (str): The url to fetch data from. + full (bool): Whether to use full url. + client (KitsuClient): The client to use for the request. + + Returns: + bytes: The data found at the given url. """ if not full: url = get_full_url(url) @@ -645,15 +773,26 @@ def get_file_data_from_url(url, full=False, client=default_client): def import_data(model_name, data, client=default_client): """ + Import data for given model. + Args: - model_name (str): The data model to import - data (dict): The data to import + model_name (str): The data model to import. + data (dict): The data to import. + client (KitsuClient): The client to use for the request. + + Returns: + dict: The imported data. """ return post("/import/kitsu/%s" % model_name, data, client=client) def get_api_version(client=default_client): """ + Get the current version of the API. + + Args: + client (KitsuClient): The client to use for the request. + Returns: str: Current version of the API. """ @@ -662,6 +801,11 @@ def get_api_version(client=default_client): def get_current_user(client=default_client): """ + Get the current user. + + Args: + client (KitsuClient): The client to use for the request. + Returns: dict: User database information for user linked to auth tokens. """ diff --git a/gazu/events.py b/gazu/events.py index 9223fb0..13592e6 100644 --- a/gazu/events.py +++ b/gazu/events.py @@ -52,11 +52,13 @@ def init( Returns: Event client that will be able to set listeners. """ - params = {"ssl_verify": ssl_verify} + params = { + "ssl_verify": ssl_verify, + "reconnection": reconnection, + "logger": logger, + } params.update(kwargs) - event_client = socketio.Client( - logger=logger, reconnection=reconnection, **params - ) + event_client = socketio.Client(**params) event_client.on("connect_error", connect_error) event_client.register_namespace(EventsNamespace("/events")) event_client.connect(get_event_host(client), make_auth_header()) diff --git a/setup.cfg b/setup.cfg index 4d7b014..1ec4d37 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,8 +35,6 @@ install_requires = requests>=2.25.1 Deprecated==1.2.15 pywin32>=308; sys_platform == 'win32' and python_version != '2.7' - pyjwt==1.7.1; python_version == '2.7' - pyjwt>=2.4.0; python_version >= '3.6' [options.packages.find] # ignore gazutest directory diff --git a/tests/test_client.py b/tests/test_client.py index e904b6b..1d10373 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,7 @@ -import sys import datetime import json import random import string -import jwt import unittest import requests_mock @@ -72,8 +70,10 @@ def test_set_tokens(self): pass def test_make_auth_header(self): - self.assertEqual(first=raw.default_client.make_auth_header(), - second=raw.make_auth_header()) + self.assertEqual( + raw.default_client.make_auth_header(), + raw.make_auth_header(), + ) def test_url_path_join(self): root = raw.get_host() @@ -257,84 +257,15 @@ def test_version(self): self.assertEqual(raw.get_api_version(), "0.2.0") def test_make_auth_token(self): - if sys.version_info.major == 2: - tokens = {"access_token": "secretaccesstoken"} - else: - tokens = {"access_token": jwt.encode( - payload={'exp': (datetime.datetime.now() + datetime.timedelta(days=30)).timestamp()}, - key='secretkey')} + tokens = {"access_token": "token_test"} raw.set_tokens(tokens) - self.assertEqual(raw.make_auth_header(), - {"Authorization": "Bearer " + tokens["access_token"], - "User-Agent": "CGWire Gazu " + __version__}) - - def test_access_token_has_expired(self): - # Automatic token refresh is not supported in Python 2. - if sys.version_info.major == 2: - return True - - client = raw.KitsuClient(host='http://localhost') - test_cases = {'fresh': (datetime.timedelta(days=30), False), - 'expired': (datetime.timedelta(days=-1), True)} - for testcase in test_cases.items(): - client.access_token = jwt.encode( - payload={'exp': (datetime.datetime.now() + testcase[-1][0]).timestamp()}, - key='secretkey') - - self.assertEqual( - first=client.access_token_has_expired, second=testcase[-1][-1], - msg=testcase[0] + ' Access Token correctly detected.') - - client.access_token = None - self.assertEqual(first=client.access_token_has_expired, second=False, - msg='') - - client.tokens["refresh_token"] = 'placeholder' - self.assertEqual(first=client.access_token_has_expired, second=True) - - def test_automatic_token_refresh(self): - # Automatic token refresh is not supported in Python 2. - if sys.version_info.major == 2: - return True - - def encode(timestamp): - return jwt.encode(payload={'exp': timestamp}, key='secretkey') - - expired_access_token = encode((datetime.datetime.now() + datetime.timedelta(days=-30)).timestamp()) - fresh_access_token = encode((datetime.datetime.now() + datetime.timedelta(days=30)).timestamp()) - new_access_token = encode((datetime.datetime.now() + datetime.timedelta(days=90)).timestamp()) - - client = raw.KitsuClient(host='http://localhost') - client.tokens["refresh_token"] = 'placeholder' - - with requests_mock.Mocker() as mock: - mock_route(mock, "GET", "http://localhost/auth/refresh-token", - text={'access_token': new_access_token}) - mock_route(mock, "GET", "http://localhost/test", text={}) - - client.automatic_refresh_token = False - client.access_token = expired_access_token - client.session.get(raw.get_full_url(path='test', client=client), headers=client.make_auth_header()) - # Expired tokens are not refreshed if automatic_refresh is False. - self.assertEqual(client.access_token, expired_access_token) - - client.access_token = fresh_access_token - client.session.get(raw.get_full_url(path='test', client=client), headers=client.make_auth_header()) - # Fresh tokens are not changed. - self.assertEqual(client.access_token, fresh_access_token) - - client.automatic_refresh_token = True - client.access_token = expired_access_token - client.session.get(raw.get_full_url(path='test', client=client), headers=client.make_auth_header()) - # Expired tokens are updated if automatic_refresh is True - self.assertEqual(client.access_token, new_access_token) - - client.automatic_refresh_token = True - client.access_token = fresh_access_token - client.session.get(raw.get_full_url(path='test', client=client), headers=client.make_auth_header()) - # Fresh tokens are not changed. - self.assertEqual(client.access_token, fresh_access_token) + self.assertEqual( + { + "Authorization": "Bearer token_test", + "User-Agent": "CGWire Gazu %s" % __version__, + }, + ) def test_upload(self): with open("./tests/fixtures/v1.png", "rb") as test_file: @@ -501,14 +432,18 @@ def test_init_send_email_otp(self): self.assertEqual(success, {"success": True}) def test_init_refresh_token(self): - access_token = jwt.encode(payload={'exp': datetime.datetime.now()}, key='secretkey') - with requests_mock.mock() as mock: raw.default_client.tokens["refresh_token"] = "refresh_token1" - mock_route(mock, "GET", "auth/refresh-token", text={"access_token": access_token}) - gazu.refresh_token() - - self.assertEqual(raw.default_client.access_token, access_token) + mock_route( + mock, + "GET", + "auth/refresh-token", + text={"access_token": "tokentest1"}, + ) + gazu.refresh_access_token() + self.assertEqual( + raw.default_client.tokens["access_token"], "tokentest1" + ) def test_init_log_in_fail(self): with requests_mock.mock() as mock: