From 7fb9f10529f6ad5569b4d7f723461a2eaf9f62d1 Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Fri, 10 Jun 2022 16:32:17 -0400 Subject: [PATCH 1/8] add AuthorizedSession support --- nasdaqdatalink/__init__.py | 1 + nasdaqdatalink/api_config.py | 20 ++++++++ nasdaqdatalink/connection.py | 21 +++++--- nasdaqdatalink/get_point_in_time.py | 7 +-- nasdaqdatalink/get_table.py | 8 +-- nasdaqdatalink/model/authorized_session.py | 57 ++++++++++++++++++++++ nasdaqdatalink/model/database.py | 13 ++--- nasdaqdatalink/utils/request_type_util.py | 5 +- 8 files changed, 110 insertions(+), 22 deletions(-) create mode 100644 nasdaqdatalink/model/authorized_session.py diff --git a/nasdaqdatalink/__init__.py b/nasdaqdatalink/__init__.py index d1e4654..7aacf70 100644 --- a/nasdaqdatalink/__init__.py +++ b/nasdaqdatalink/__init__.py @@ -10,6 +10,7 @@ from .model.point_in_time import PointInTime from .model.data import Data from .model.merged_dataset import MergedDataset +from .model.authorized_session import AuthorizedSession from .get import get from .bulkdownload import bulkdownload from .export_table import export_table diff --git a/nasdaqdatalink/api_config.py b/nasdaqdatalink/api_config.py index dea1dd0..974d918 100644 --- a/nasdaqdatalink/api_config.py +++ b/nasdaqdatalink/api_config.py @@ -17,6 +17,18 @@ class ApiConfig: retry_status_codes = [429] + list(range(500, 512)) verify_ssl = True + def read_key(self, filename): + if not os.path.isfile(filename): + raise_empty_file(filename) + + with open(filename, 'r') as f: + apikey = get_first_non_empty(f) + + if not apikey: + raise_empty_file(filename) + + self.api_key = apikey + def create_file(config_filename): # Create the file as well as the parent dir if needed. @@ -102,3 +114,11 @@ def read_key(filename=None): read_key_from_environment_variable() elif config_file_exists(filename): read_key_from_file(filename) + + +def get_config_from_kwargs(kwargs): + params = getattr(kwargs, "params", None) + result = getattr(params, "api_config", None) + if result is None: + result = ApiConfig + return result diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 350ed49..129e1c7 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -8,7 +8,7 @@ from .util import Util from .version import VERSION -from .api_config import ApiConfig +from .api_config import ApiConfig, get_config_from_kwargs from nasdaqdatalink.errors.data_link_error import ( DataLinkError, LimitExceededError, InternalServerError, AuthenticationError, ForbiddenError, InvalidRequestError, @@ -22,31 +22,36 @@ def request(cls, http_verb, url, **options): headers = options['headers'] else: headers = {} + api_config = get_config_from_kwargs(options) accept_value = 'application/json' - if ApiConfig.api_version: - accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version + if api_config.api_version: + accept_value += ", application/vnd.data.nasdaq+json;version=%s" % api_config.api_version headers = Util.merge_to_dicts({'accept': accept_value, 'request-source': 'python', 'request-source-version': VERSION}, headers) - if ApiConfig.api_key: - headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) + if api_config.api_key: + headers = Util.merge_to_dicts({'x-api-token': api_config.api_key}, headers) options['headers'] = headers - abs_url = '%s/%s' % (ApiConfig.api_base, url) + abs_url = '%s/%s' % (api_config.api_base, url) return cls.execute_request(http_verb, abs_url, **options) @classmethod def execute_request(cls, http_verb, url, **options): - session = cls.get_session() + params = getattr(options, 'params', None) + session = getattr(params, 'session', None) + if session is None: + session = cls.get_session() + api_config = get_config_from_kwargs(options) try: response = session.request(method=http_verb, url=url, - verify=ApiConfig.verify_ssl, + verify=api_config.verify_ssl, **options) if response.status_code < 200 or response.status_code >= 300: cls.handle_api_error(response) diff --git a/nasdaqdatalink/get_point_in_time.py b/nasdaqdatalink/get_point_in_time.py index c4f7578..6bfa907 100644 --- a/nasdaqdatalink/get_point_in_time.py +++ b/nasdaqdatalink/get_point_in_time.py @@ -1,6 +1,6 @@ from nasdaqdatalink.model.point_in_time import PointInTime from nasdaqdatalink.errors.data_link_error import LimitExceededError -from .api_config import ApiConfig +from .api_config import get_config_from_kwargs from .message import Message from nasdaqdatalink.errors.data_link_error import InvalidRequestError import warnings @@ -23,6 +23,7 @@ def get_point_in_time(datatable_code, **options): data = None page_count = 0 + api_config = get_config_from_kwargs(options) while True: next_options = copy.deepcopy(options) next_data = PointInTime(datatable_code, pit=pit_options).data(params=next_options) @@ -32,10 +33,10 @@ def get_point_in_time(datatable_code, **options): else: data.extend(next_data) - if page_count >= ApiConfig.page_limit: + if page_count >= api_config.page_limit: raise LimitExceededError( Message.WARN_DATA_LIMIT_EXCEEDED % (datatable_code, - ApiConfig.api_key + api_config.api_key ) ) diff --git a/nasdaqdatalink/get_table.py b/nasdaqdatalink/get_table.py index c07d3c8..6a35988 100644 --- a/nasdaqdatalink/get_table.py +++ b/nasdaqdatalink/get_table.py @@ -1,6 +1,6 @@ from nasdaqdatalink.model.datatable import Datatable from nasdaqdatalink.errors.data_link_error import LimitExceededError -from .api_config import ApiConfig +from .api_config import get_config_from_kwargs from .message import Message import warnings import copy @@ -14,6 +14,8 @@ def get_table(datatable_code, **options): data = None page_count = 0 + api_config = get_config_from_kwargs(options) + while True: next_options = copy.deepcopy(options) next_data = Datatable(datatable_code).data(params=next_options) @@ -23,10 +25,10 @@ def get_table(datatable_code, **options): else: data.extend(next_data) - if page_count >= ApiConfig.page_limit: + if page_count >= api_config.page_limit: raise LimitExceededError( Message.WARN_DATA_LIMIT_EXCEEDED % (datatable_code, - ApiConfig.api_key + api_config.api_key ) ) diff --git a/nasdaqdatalink/model/authorized_session.py b/nasdaqdatalink/model/authorized_session.py new file mode 100644 index 0000000..7c2fe17 --- /dev/null +++ b/nasdaqdatalink/model/authorized_session.py @@ -0,0 +1,57 @@ +from nasdaqdatalink.api_config import ApiConfig +from nasdaqdatalink.get import get +from nasdaqdatalink.bulkdownload import bulkdownload +from nasdaqdatalink.export_table import export_table +from nasdaqdatalink.get_table import get_table +from nasdaqdatalink.get_point_in_time import get_point_in_time +from urllib3.util.retry import Retry +from requests.adapters import HTTPAdapter +import requests +import urllib + + +def get_retries(api_config): + if isinstance(api_config, ApiConfig): + if not api_config.use_retries: + return Retry(total=0) + + Retry.BACKOFF_MAX = api_config.max_wait_between_retries + retries = Retry(total=api_config.number_of_retries, + connect=api_config.number_of_retries, + read=api_config.number_of_retries, + status_forcelist=api_config.retry_status_codes, + backoff_factor=api_config.retry_backoff_factor, + raise_on_status=False) + return retries + + +class AuthorizedSession: + def __init__(self, api_config=ApiConfig) -> None: + super(AuthorizedSession, self).__init__() + self._api_config = api_config + self._auth_session = requests.Session() + retries = get_retries(self._api_config) + adapter = HTTPAdapter(max_retries=retries) + self._auth_session.mount(api_config.api_protocol, adapter) + + proxies = urllib.request.getproxies() + if proxies is not None: + self._auth_session.proxies.update(proxies) + + def get(self, dataset, **kwargs): + get(dataset, session=self._auth_session, api_config=self._api_config, **kwargs) + + def bulkdownload(self, database, **kwargs): + bulkdownload(database, session=self._auth_session, api_config=self._api_config, **kwargs) + + def export_table(self, datatable_code, **kwargs): + export_table(datatable_code, session=self._auth_session, + api_config=self._api_config, **kwargs) + + def get_table(self, datatable_code, **options): + get_table(datatable_code, session=self._auth_session, + api_config=self._api_config, **options) + + def get_point_in_time(self, datatable_code, **options): + get_point_in_time(datatable_code, session=self._auth_session, + api_config=self._api_config, **options) diff --git a/nasdaqdatalink/model/database.py b/nasdaqdatalink/model/database.py index 870dedc..443daf9 100644 --- a/nasdaqdatalink/model/database.py +++ b/nasdaqdatalink/model/database.py @@ -3,7 +3,7 @@ from six.moves.urllib.parse import urlencode, urlparse import nasdaqdatalink.model.dataset -from nasdaqdatalink.api_config import ApiConfig +from nasdaqdatalink.api_config import get_config_from_kwargs from nasdaqdatalink.connection import Connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message @@ -21,15 +21,16 @@ def get_code_from_meta(cls, metadata): return metadata['database_code'] def bulk_download_url(self, **options): + api_config = get_config_from_kwargs(options) url = self._bulk_download_path() - url = ApiConfig.api_base + '/' + url + url = api_config.api_base + '/' + url if 'params' not in options: options['params'] = {} - if ApiConfig.api_key: - options['params']['api_key'] = ApiConfig.api_key - if ApiConfig.api_version: - options['params']['api_version'] = ApiConfig.api_version + if api_config.api_key: + options['params']['api_key'] = api_config.api_key + if api_config.api_version: + options['params']['api_version'] = api_config.api_version if list(options.keys()): url += '?' + urlencode(options['params']) diff --git a/nasdaqdatalink/utils/request_type_util.py b/nasdaqdatalink/utils/request_type_util.py index a53af61..97d63cc 100644 --- a/nasdaqdatalink/utils/request_type_util.py +++ b/nasdaqdatalink/utils/request_type_util.py @@ -1,5 +1,5 @@ from urllib.parse import urlencode -from nasdaqdatalink.api_config import ApiConfig +from nasdaqdatalink.api_config import get_config_from_kwargs class RequestType(object): @@ -13,7 +13,8 @@ class RequestType(object): @classmethod def get_request_type(cls, url, **params): query_string = urlencode(params['params']) - request_url = '%s/%s/%s' % (ApiConfig.api_base, url, query_string) + api_config = get_config_from_kwargs(params) + request_url = '%s/%s/%s' % (api_config.api_base, url, query_string) if RequestType.USE_GET_REQUEST and (len(request_url) < cls.MAX_URL_LENGTH_FOR_GET): return 'get' else: From 784a11f226885caebb90c80d4dba2ed4914075b8 Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Mon, 13 Jun 2022 16:39:30 -0400 Subject: [PATCH 2/8] add test for api_config --- nasdaqdatalink/api_config.py | 13 +++-- nasdaqdatalink/model/authorized_session.py | 26 +++++----- test/test_api_config.py | 55 ++++++++++++++++++++++ test/test_authorized_session.py | 24 ++++++++++ 4 files changed, 101 insertions(+), 17 deletions(-) create mode 100644 test/test_authorized_session.py diff --git a/nasdaqdatalink/api_config.py b/nasdaqdatalink/api_config.py index 974d918..d86d576 100644 --- a/nasdaqdatalink/api_config.py +++ b/nasdaqdatalink/api_config.py @@ -17,7 +17,7 @@ class ApiConfig: retry_status_codes = [429] + list(range(500, 512)) verify_ssl = True - def read_key(self, filename): + def read_key(self, filename=None): if not os.path.isfile(filename): raise_empty_file(filename) @@ -117,8 +117,11 @@ def read_key(filename=None): def get_config_from_kwargs(kwargs): - params = getattr(kwargs, "params", None) - result = getattr(params, "api_config", None) - if result is None: - result = ApiConfig + result = ApiConfig + if isinstance(kwargs, dict): + params = kwargs.get('params') + if isinstance(params, dict): + result = params.get('api_config') + if not isinstance(result, ApiConfig): + result = ApiConfig return result diff --git a/nasdaqdatalink/model/authorized_session.py b/nasdaqdatalink/model/authorized_session.py index 7c2fe17..ad2924e 100644 --- a/nasdaqdatalink/model/authorized_session.py +++ b/nasdaqdatalink/model/authorized_session.py @@ -10,24 +10,26 @@ import urllib -def get_retries(api_config): - if isinstance(api_config, ApiConfig): - if not api_config.use_retries: - return Retry(total=0) - - Retry.BACKOFF_MAX = api_config.max_wait_between_retries - retries = Retry(total=api_config.number_of_retries, - connect=api_config.number_of_retries, - read=api_config.number_of_retries, - status_forcelist=api_config.retry_status_codes, - backoff_factor=api_config.retry_backoff_factor, - raise_on_status=False) +def get_retries(api_config=ApiConfig): + retries = None + if not api_config.use_retries: + return Retry(total=0) + + Retry.BACKOFF_MAX = api_config.max_wait_between_retries + retries = Retry(total=api_config.number_of_retries, + connect=api_config.number_of_retries, + read=api_config.number_of_retries, + status_forcelist=api_config.retry_status_codes, + backoff_factor=api_config.retry_backoff_factor, + raise_on_status=False) return retries class AuthorizedSession: def __init__(self, api_config=ApiConfig) -> None: super(AuthorizedSession, self).__init__() + if not isinstance(api_config, ApiConfig): + api_config = ApiConfig self._api_config = api_config self._auth_session = requests.Session() retries = get_retries(self._api_config) diff --git a/test/test_api_config.py b/test/test_api_config.py index 2c183b5..c21efe0 100644 --- a/test/test_api_config.py +++ b/test/test_api_config.py @@ -132,3 +132,58 @@ def test_read_key_from_file_with_tab(self): def test_read_key_from_file_with_multi_newline(self): given = "keyfordefaultfile\n\nanotherkey\n" self._read_key_from_file_helper(given, TEST_DEFAULT_FILE_CONTENTS) + + def test_default_instance_will_have_share_values_with_singleton(self): + os.environ['NASDAQ_DATA_LINK_API_KEY'] = 'setinenv' + ApiConfig.api_key = None + read_key() + api_config = ApiConfig() + self.assertEqual(api_config.api_key, "setinenv") + # make sure change in instance will not affect the singleton + api_config.api_key = None + self.assertEqual(ApiConfig.api_key, "setinenv") + + def test_get_config_from_kwargs_return_api_config_if_present(self): + api_config = get_config_from_kwargs({ + 'params': { + 'api_config': ApiConfig() + } + }) + self.assertTrue(isinstance(api_config, ApiConfig)) + + def test_get_config_from_kwargs_return_singleton_if_not_present_or_wrong_type(self): + api_config = get_config_from_kwargs(None) + self.assertTrue(issubclass(api_config, ApiConfig)) + self.assertFalse(isinstance(api_config, ApiConfig)) + api_config = get_config_from_kwargs(1) + self.assertTrue(issubclass(api_config, ApiConfig)) + self.assertFalse(isinstance(api_config, ApiConfig)) + api_config = get_config_from_kwargs({ + 'params': None + }) + self.assertTrue(issubclass(api_config, ApiConfig)) + self.assertFalse(isinstance(api_config, ApiConfig)) + + def test_instance_read_key_should_raise_error(self): + api_config = ApiConfig() + with self.assertRaises(TypeError): + api_config.read_key(None) + with self.assertRaises(ValueError): + api_config.read_key('') + + def test_instance_read_key_should_raise_error_when_empty(self): + save_key("", TEST_KEY_FILE) + api_config = ApiConfig() + with self.assertRaises(ValueError): + # read empty file + api_config.read_key(TEST_KEY_FILE) + + def test_instance_read_the_right_key(self): + expected_key = 'ilovepython' + save_key(expected_key, TEST_KEY_FILE) + api_config = ApiConfig() + api_config.api_key = '' + api_config.read_key(TEST_KEY_FILE) + self.assertEqual(ApiConfig.api_key, expected_key) + + diff --git a/test/test_authorized_session.py b/test/test_authorized_session.py new file mode 100644 index 0000000..474cb4d --- /dev/null +++ b/test/test_authorized_session.py @@ -0,0 +1,24 @@ +from unittest import TestCase +from unittest.mock import patch + +from nasdaqdatalink.model.authorized_session import AuthorizedSession +from nasdaqdatalink.api_config import ApiConfig +from requests.sessions import Session +from requests.adapters import HTTPAdapter + + +class AuthorizedSessionTest(TestCase): + def test_authorized_session_assign_correct_internal_config(self): + authed_session = AuthorizedSession() + self.assertTrue(issubclass(authed_session._api_config, ApiConfig)) + authed_session = AuthorizedSession(None) + self.assertTrue(issubclass(authed_session._api_config, ApiConfig)) + api_config = ApiConfig() + authed_session = AuthorizedSession(api_config) + self.assertTrue(isinstance(authed_session._api_config, ApiConfig)) + + def test_authorized_session_pass_created_session(self): + authed_session = AuthorizedSession() + self.assertTrue(isinstance(authed_session._auth_session, Session)) + adapter = authed_session._auth_session.get_adapter(ApiConfig.api_protocol) + self.assertTrue(isinstance(adapter, HTTPAdapter)) From bfaa3b05ede7ae38501ff28b5da1b21500d12bef Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Wed, 15 Jun 2022 09:45:44 -0400 Subject: [PATCH 3/8] fix session object not getting through --- nasdaqdatalink/connection.py | 3 +-- test/test_authorized_session.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 129e1c7..c13fa37 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -42,8 +42,7 @@ def request(cls, http_verb, url, **options): @classmethod def execute_request(cls, http_verb, url, **options): - params = getattr(options, 'params', None) - session = getattr(params, 'session', None) + session = options.get('params', {}).get('session', None) if session is None: session = cls.get_session() diff --git a/test/test_authorized_session.py b/test/test_authorized_session.py index 474cb4d..d69237b 100644 --- a/test/test_authorized_session.py +++ b/test/test_authorized_session.py @@ -1,6 +1,4 @@ from unittest import TestCase -from unittest.mock import patch - from nasdaqdatalink.model.authorized_session import AuthorizedSession from nasdaqdatalink.api_config import ApiConfig from requests.sessions import Session @@ -18,7 +16,10 @@ def test_authorized_session_assign_correct_internal_config(self): self.assertTrue(isinstance(authed_session._api_config, ApiConfig)) def test_authorized_session_pass_created_session(self): + ApiConfig.use_retries = True + ApiConfig.number_of_retries = 130 authed_session = AuthorizedSession() self.assertTrue(isinstance(authed_session._auth_session, Session)) adapter = authed_session._auth_session.get_adapter(ApiConfig.api_protocol) self.assertTrue(isinstance(adapter, HTTPAdapter)) + self.assertEqual(adapter.max_retries.connect, 130) From b7dca2b2e70145a73423881e05e4dcc61e4264b4 Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Wed, 15 Jun 2022 10:17:09 -0400 Subject: [PATCH 4/8] clean request payload --- nasdaqdatalink/connection.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index c13fa37..fd45048 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -47,6 +47,10 @@ def execute_request(cls, http_verb, url, **options): session = cls.get_session() api_config = get_config_from_kwargs(options) + + # clean the request payload + options.get('params', {}).pop('session', None) + options.get('params', {}).pop('api_config', None) try: response = session.request(method=http_verb, url=url, From 01e4a11ab5cee6d2942e2e29560351a9017ab0c0 Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Wed, 15 Jun 2022 10:33:36 -0400 Subject: [PATCH 5/8] fix get api_config before params assignment --- nasdaqdatalink/get_point_in_time.py | 4 ++-- nasdaqdatalink/get_table.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/nasdaqdatalink/get_point_in_time.py b/nasdaqdatalink/get_point_in_time.py index 6bfa907..73d6ed2 100644 --- a/nasdaqdatalink/get_point_in_time.py +++ b/nasdaqdatalink/get_point_in_time.py @@ -1,6 +1,6 @@ from nasdaqdatalink.model.point_in_time import PointInTime from nasdaqdatalink.errors.data_link_error import LimitExceededError -from .api_config import get_config_from_kwargs +from .api_config import ApiConfig from .message import Message from nasdaqdatalink.errors.data_link_error import InvalidRequestError import warnings @@ -23,7 +23,7 @@ def get_point_in_time(datatable_code, **options): data = None page_count = 0 - api_config = get_config_from_kwargs(options) + api_config = options.get('api_config', ApiConfig) while True: next_options = copy.deepcopy(options) next_data = PointInTime(datatable_code, pit=pit_options).data(params=next_options) diff --git a/nasdaqdatalink/get_table.py b/nasdaqdatalink/get_table.py index 6a35988..32188bb 100644 --- a/nasdaqdatalink/get_table.py +++ b/nasdaqdatalink/get_table.py @@ -1,6 +1,6 @@ from nasdaqdatalink.model.datatable import Datatable from nasdaqdatalink.errors.data_link_error import LimitExceededError -from .api_config import get_config_from_kwargs +from .api_config import ApiConfig from .message import Message import warnings import copy @@ -14,8 +14,7 @@ def get_table(datatable_code, **options): data = None page_count = 0 - api_config = get_config_from_kwargs(options) - + api_config = options.get('api_config', ApiConfig) while True: next_options = copy.deepcopy(options) next_data = Datatable(datatable_code).data(params=next_options) From 1a990da91aca7aa89ecb72da114367cbf67ae33f Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Thu, 16 Jun 2022 09:51:24 -0400 Subject: [PATCH 6/8] add test and fix lint --- nasdaqdatalink/model/authorized_session.py | 26 ++++++----- test/test_authorized_session.py | 43 ++++++++++++++++++- test/test_connection.py | 50 +++++++++++++++++++++- 3 files changed, 101 insertions(+), 18 deletions(-) diff --git a/nasdaqdatalink/model/authorized_session.py b/nasdaqdatalink/model/authorized_session.py index ad2924e..7c2c9e2 100644 --- a/nasdaqdatalink/model/authorized_session.py +++ b/nasdaqdatalink/model/authorized_session.py @@ -1,16 +1,12 @@ +import nasdaqdatalink from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.get import get -from nasdaqdatalink.bulkdownload import bulkdownload -from nasdaqdatalink.export_table import export_table -from nasdaqdatalink.get_table import get_table -from nasdaqdatalink.get_point_in_time import get_point_in_time from urllib3.util.retry import Retry from requests.adapters import HTTPAdapter import requests import urllib -def get_retries(api_config=ApiConfig): +def get_retries(api_config=nasdaqdatalink.ApiConfig): retries = None if not api_config.use_retries: return Retry(total=0) @@ -41,19 +37,21 @@ def __init__(self, api_config=ApiConfig) -> None: self._auth_session.proxies.update(proxies) def get(self, dataset, **kwargs): - get(dataset, session=self._auth_session, api_config=self._api_config, **kwargs) + nasdaqdatalink.get(dataset, session=self._auth_session, + api_config=self._api_config, **kwargs) def bulkdownload(self, database, **kwargs): - bulkdownload(database, session=self._auth_session, api_config=self._api_config, **kwargs) + nasdaqdatalink.bulkdownload(database, session=self._auth_session, + api_config=self._api_config, **kwargs) def export_table(self, datatable_code, **kwargs): - export_table(datatable_code, session=self._auth_session, - api_config=self._api_config, **kwargs) + nasdaqdatalink.export_table(datatable_code, session=self._auth_session, + api_config=self._api_config, **kwargs) def get_table(self, datatable_code, **options): - get_table(datatable_code, session=self._auth_session, - api_config=self._api_config, **options) + nasdaqdatalink.get_table(datatable_code, session=self._auth_session, + api_config=self._api_config, **options) def get_point_in_time(self, datatable_code, **options): - get_point_in_time(datatable_code, session=self._auth_session, - api_config=self._api_config, **options) + nasdaqdatalink.get_point_in_time(datatable_code, session=self._auth_session, + api_config=self._api_config, **options) diff --git a/test/test_authorized_session.py b/test/test_authorized_session.py index d69237b..60f20ba 100644 --- a/test/test_authorized_session.py +++ b/test/test_authorized_session.py @@ -1,11 +1,12 @@ -from unittest import TestCase +import unittest from nasdaqdatalink.model.authorized_session import AuthorizedSession from nasdaqdatalink.api_config import ApiConfig from requests.sessions import Session from requests.adapters import HTTPAdapter +from mock import patch -class AuthorizedSessionTest(TestCase): +class AuthorizedSessionTest(unittest.TestCase): def test_authorized_session_assign_correct_internal_config(self): authed_session = AuthorizedSession() self.assertTrue(issubclass(authed_session._api_config, ApiConfig)) @@ -23,3 +24,41 @@ def test_authorized_session_pass_created_session(self): adapter = authed_session._auth_session.get_adapter(ApiConfig.api_protocol) self.assertTrue(isinstance(adapter, HTTPAdapter)) self.assertEqual(adapter.max_retries.connect, 130) + + @patch("nasdaqdatalink.get") + def test_call_get_with_session_and_api_config(self, mock): + api_config = ApiConfig() + authed_session = AuthorizedSession(api_config) + authed_session.get('WIKI/AAPL') + mock.assert_called_with('WIKI/AAPL', api_config=api_config, + session=authed_session._auth_session) + + @patch("nasdaqdatalink.bulkdownload") + def test_call_bulkdownload_with_session_and_api_config(self, mock): + api_config = ApiConfig() + authed_session = AuthorizedSession(api_config) + authed_session.bulkdownload('NSE') + mock.assert_called_with('NSE', api_config=api_config, + session=authed_session._auth_session) + + @patch("nasdaqdatalink.export_table") + def test_call_export_table_with_session_and_api_config(self, mock): + authed_session = AuthorizedSession() + authed_session.export_table('WIKI/AAPL') + mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig, + session=authed_session._auth_session) + + @patch("nasdaqdatalink.get_table") + def test_call_get_table_with_session_and_api_config(self, mock): + authed_session = AuthorizedSession() + authed_session.get_table('WIKI/AAPL') + mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig, + session=authed_session._auth_session) + + @patch("nasdaqdatalink.get_point_in_time") + def test_call_get_point_in_time_with_session_and_api_config(self, mock): + authed_session = AuthorizedSession() + authed_session.get_point_in_time('DATABASE/CODE', interval='asofdate', date='2020-01-01') + mock.assert_called_with('DATABASE/CODE', interval='asofdate', + date='2020-01-01', api_config=ApiConfig, + session=authed_session._auth_session) diff --git a/test/test_connection.py b/test/test_connection.py index 7777d6e..3d6d1bc 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -6,6 +6,7 @@ NotFoundError, ServiceUnavailableError) from test.test_retries import ModifyRetrySettingsTestCase from test.helpers.httpretty_extension import httpretty +import requests import json from mock import patch, call from nasdaqdatalink.version import VERSION @@ -59,8 +60,8 @@ def test_non_data_link_error(self, request_method): httpretty.register_uri(getattr(httpretty, request_method), "https://data.nasdaq.com/api/v3/databases", body=json.dumps( - {'foobar': - {'code': 'blah', 'message': 'something went wrong'}}), status=500) + {'foobar': + {'code': 'blah', 'message': 'something went wrong'}}), status=500) self.assertRaises( DataLinkError, lambda: Connection.request(request_method, 'databases')) @@ -81,3 +82,48 @@ def test_build_request(self, request_method, mock): 'request-source-version': VERSION}, params={'per_page': 10, 'page': 2}) self.assertEqual(mock.call_args, expected) + + @parameterized.expand(['GET', 'POST']) + @patch('nasdaqdatalink.connection.Connection.execute_request') + def test_build_request_with_custom_api_config(self, request_method, mock): + ApiConfig.api_key = 'api_token' + ApiConfig.api_version = '2015-04-09' + api_config = ApiConfig() + api_config.api_key = 'custom_api_token' + api_config.api_version = '2022-06-09' + session = requests.session() + params = {'per_page': 10, 'page': 2, 'api_config': api_config, 'session': session} + headers = {'x-custom-header': 'header value'} + Connection.request(request_method, 'databases', headers=headers, params=params) + expected = call(request_method, 'https://data.nasdaq.com/api/v3/databases', + headers={'x-custom-header': 'header value', + 'x-api-token': 'custom_api_token', + 'accept': ('application/json, ' + 'application/vnd.data.nasdaq+json;version=2022-06-09'), + 'request-source': 'python', + 'request-source-version': VERSION}, + params={'per_page': 10, 'page': 2, + 'session': session, 'api_config': api_config}) + self.assertEqual(mock.call_args, expected) + + def test_remove_session_and_api_config_param(self): + ApiConfig.api_key = 'api_token' + ApiConfig.api_version = '2015-04-09' + ApiConfig.verify_ssl = True + api_config = ApiConfig() + api_config.api_key = 'custom_api_token' + api_config.api_version = '2022-06-09' + api_config.verify_ssl = False + session = requests.Session() + params = {'per_page': 10, 'page': 2, 'api_config': api_config, 'session': session} + headers = {'x-custom-header': 'header value'} + dummy_response = requests.Response() + dummy_response.status_code = 200 + with patch.object(session, 'request', return_value=dummy_response) as mock: + Connection.execute_request( + 'GET', 'https://data.nasdaq.com/api/v3/databases', headers=headers, params=params) + mock.assert_called_once_with(method='GET', + url='https://data.nasdaq.com/api/v3/databases', + verify=False, + headers=headers, + params={'per_page': 10, 'page': 2}) From 78cacaecb95ce7581f0487adc07d824eea6914b0 Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Thu, 16 Jun 2022 16:10:15 -0400 Subject: [PATCH 7/8] add strip function to remove property in option --- nasdaqdatalink/connection.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index fd45048..d820d24 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -14,6 +14,11 @@ AuthenticationError, ForbiddenError, InvalidRequestError, NotFoundError, ServiceUnavailableError) +KW_TO_REMOVE = [ + 'session', + 'api_config' +] + class Connection: @classmethod @@ -48,9 +53,7 @@ def execute_request(cls, http_verb, url, **options): api_config = get_config_from_kwargs(options) - # clean the request payload - options.get('params', {}).pop('session', None) - options.get('params', {}).pop('api_config', None) + cls.options_kw_strip(options) try: response = session.request(method=http_verb, url=url, @@ -126,3 +129,8 @@ def handle_api_error(cls, resp): klass = d_klass.get(code_letter, DataLinkError) raise klass(message, resp.status_code, resp.text, resp.headers, code) + + @classmethod + def options_kw_strip(self, options): + for kw in KW_TO_REMOVE: + options.get('params', {}).pop(kw, None) From e87eee2d202242bea94cf09a60658388d070bf31 Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Fri, 12 Aug 2022 11:10:29 -0400 Subject: [PATCH 8/8] fix lint, add doc for AuthorizedSession --- README.md | 12 ++++++++++++ test/test_connection.py | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index dd3b199..1b61993 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,18 @@ data_link_log = logging.getLogger("nasdaqdatalink") data_link_log.setLevel(logging.DEBUG) ``` +### Session + +By default, every API request will create a new session; This will have a performance impact when you wish to make multiple requests(see #16). You can use `AuthorizedSession` to take advantage of the reusing session: + +```python +import nasdaqdatalink +session = nasdaqdatalink.AuthorizedSession() +data1 = session.get_table('ZACKS/FC', ticker='AAPL') +data2 = session.get_table('ZACKS/FC', ticker='MFST') +data3 = session.get_table('ZACKS/FC', ticker='NVDA') +``` + ### Detailed Usage Our API can provide more than just data. It can also be used to search and provide metadata or to programmatically retrieve data. For these more advanced techniques please follow our [Detailed Method Guide](./FOR_DEVELOPERS.md). diff --git a/test/test_connection.py b/test/test_connection.py index 3d6d1bc..f8a513f 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -60,8 +60,8 @@ def test_non_data_link_error(self, request_method): httpretty.register_uri(getattr(httpretty, request_method), "https://data.nasdaq.com/api/v3/databases", body=json.dumps( - {'foobar': - {'code': 'blah', 'message': 'something went wrong'}}), status=500) + {'foobar': + {'code': 'blah', 'message': 'something went wrong'}}), status=500) self.assertRaises( DataLinkError, lambda: Connection.request(request_method, 'databases'))