From 24e8ff31fee358bd0a8e6106af76ab3c57cd2873 Mon Sep 17 00:00:00 2001 From: "Economou, Matthew (NIH/NIAID) [C]" Date: Wed, 24 Jan 2024 23:08:02 -0500 Subject: [PATCH] refactor: sort imports and apply the Black style --- setup.py | 26 +- src/satosa/attribute_mapping.py | 89 ++- src/satosa/backends/apple.py | 10 +- src/satosa/backends/bitbucket.py | 61 +- src/satosa/backends/github.py | 56 +- src/satosa/backends/idpy_oidc.py | 24 +- src/satosa/backends/linkedin.py | 67 +-- src/satosa/backends/oauth.py | 116 ++-- src/satosa/backends/openid_connect.py | 120 ++-- src/satosa/backends/orcid.py | 68 +-- src/satosa/backends/reflector.py | 5 +- src/satosa/backends/saml2.py | 249 +++++---- src/satosa/base.py | 170 +++--- src/satosa/context.py | 17 +- src/satosa/cookies.py | 1 - src/satosa/exception.py | 11 + src/satosa/frontends/openid_connect.py | 184 +++--- src/satosa/frontends/ping.py | 5 +- src/satosa/frontends/saml2.py | 439 +++++++++------ src/satosa/metadata_creation/saml_metadata.py | 101 +++- src/satosa/micro_services/account_linking.py | 70 ++- .../micro_services/attribute_authorization.py | 29 +- .../micro_services/attribute_generation.py | 137 ++--- .../micro_services/attribute_modifications.py | 54 +- src/satosa/micro_services/attribute_policy.py | 2 +- .../micro_services/attribute_processor.py | 12 +- src/satosa/micro_services/base.py | 2 + src/satosa/micro_services/consent.py | 63 ++- src/satosa/micro_services/custom_logging.py | 60 +- src/satosa/micro_services/custom_routing.py | 77 +-- src/satosa/micro_services/disco.py | 29 +- src/satosa/micro_services/hasher.py | 5 +- src/satosa/micro_services/idp_hinting.py | 7 +- .../micro_services/ldap_attribute_store.py | 23 +- .../micro_services/primary_identifier.py | 166 ++++-- .../processors/gender_processor.py | 13 +- .../processors/hash_processor.py | 23 +- .../processors/regex_sub_processor.py | 39 +- .../processors/scope_extractor_processor.py | 28 +- .../processors/scope_processor.py | 9 +- .../processors/scope_remover_processor.py | 8 +- src/satosa/plugin_loader.py | 132 +++-- src/satosa/proxy_server.py | 29 +- src/satosa/response.py | 7 +- src/satosa/routing.py | 49 +- src/satosa/saml_util.py | 2 +- src/satosa/satosa_config.py | 33 +- src/satosa/scripts/satosa_saml_metadata.py | 99 +++- src/satosa/state.py | 34 +- src/satosa/util.py | 9 +- src/satosa/version.py | 4 +- src/satosa/wsgi.py | 4 +- tests/conftest.py | 207 ++++--- tests/flows/test_account_linking.py | 20 +- tests/flows/test_consent.py | 16 +- tests/flows/test_oidc-saml.py | 238 +++++--- tests/flows/test_saml-oidc.py | 51 +- tests/flows/test_saml-saml.py | 54 +- tests/flows/test_wsgi_flow.py | 18 +- tests/satosa/backends/test_bitbucket.py | 145 +++-- tests/satosa/backends/test_idpy_oidc.py | 99 ++-- tests/satosa/backends/test_oauth.py | 120 ++-- tests/satosa/backends/test_openid_connect.py | 184 ++++-- tests/satosa/backends/test_orcid.py | 78 +-- tests/satosa/backends/test_saml2.py | 250 ++++++--- tests/satosa/frontends/test_openid_connect.py | 269 ++++++--- tests/satosa/frontends/test_saml2.py | 522 +++++++++++++----- .../metadata_creation/test_description.py | 15 +- .../metadata_creation/test_saml_metadata.py | 340 ++++++++---- .../micro_services/test_account_linking.py | 73 ++- .../test_attribute_authorization.py | 84 ++- .../test_attribute_generation.py | 60 +- .../test_attribute_modifications.py | 238 +++----- .../micro_services/test_attribute_policy.py | 19 +- tests/satosa/micro_services/test_consent.py | 224 +++++--- .../micro_services/test_custom_routing.py | 97 ++-- tests/satosa/micro_services/test_disco.py | 15 +- .../satosa/micro_services/test_idp_hinting.py | 29 +- .../test_ldap_attribute_store.py | 112 ++-- .../scripts/test_satosa_saml_metadata.py | 127 +++-- tests/satosa/test_attribute_mapping.py | 207 +++---- tests/satosa/test_base.py | 46 +- tests/satosa/test_plugin_loader.py | 8 +- tests/satosa/test_response.py | 11 +- tests/satosa/test_routing.py | 89 ++- tests/satosa/test_satosa_config.py | 37 +- tests/satosa/test_state.py | 43 +- tests/users.py | 2 +- tests/util.py | 185 ++++--- 89 files changed, 4599 insertions(+), 2810 deletions(-) diff --git a/setup.py b/setup.py index 51bb389ea..5380483f5 100644 --- a/setup.py +++ b/setup.py @@ -2,18 +2,18 @@ setup.py """ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( - name='SATOSA', - version='8.4.0', - description='Protocol proxy (SAML/OIDC).', - author='DIRG', - author_email='satosa-dev@lists.sunet.se', - license='Apache 2.0', - url='https://github.com/SUNET/SATOSA', - packages=find_packages('src/'), - package_dir={'': 'src'}, + name="SATOSA", + version="8.4.0", + description="Protocol proxy (SAML/OIDC).", + author="DIRG", + author_email="satosa-dev@lists.sunet.se", + license="Apache 2.0", + url="https://github.com/SUNET/SATOSA", + packages=find_packages("src/"), + package_dir={"": "src"}, install_requires=[ "pyop >= v3.4.0", "pysaml2 >= 6.5.1", @@ -44,6 +44,8 @@ "Programming Language :: Python :: 3.11", ], entry_points={ - "console_scripts": ["satosa-saml-metadata=satosa.scripts.satosa_saml_metadata:construct_saml_metadata"] - } + "console_scripts": [ + "satosa-saml-metadata=satosa.scripts.satosa_saml_metadata:construct_saml_metadata" + ] + }, ) diff --git a/src/satosa/attribute_mapping.py b/src/satosa/attribute_mapping.py index d5745864c..124e9594e 100644 --- a/src/satosa/attribute_mapping.py +++ b/src/satosa/attribute_mapping.py @@ -14,9 +14,9 @@ def scope(s): :param s: string to extract scope from (filtered string in mako template) :return: the scope """ - if '@' not in s: + if "@" not in s: raise ValueError("Unscoped string") - (local_part, _, domain_part) = s.partition('@') + (local_part, _, domain_part) = s.partition("@") return domain_part @@ -31,8 +31,12 @@ def __init__(self, internal_attributes): :param internal_attributes: A map of how to convert the attributes (dict[internal_name, dict[attribute_profile, external_name]]) """ - self.separator = "." # separator for nested attribute values, e.g. address.street_address - self.multivalue_separator = ";" # separates multiple values, e.g. when using templates + self.separator = ( + "." # separator for nested attribute values, e.g. address.street_address + ) + self.multivalue_separator = ( + ";" # separates multiple values, e.g. when using templates + ) self.from_internal_attributes = internal_attributes["attributes"] self.template_attributes = internal_attributes.get("template_attributes", None) @@ -40,7 +44,9 @@ def __init__(self, internal_attributes): for internal_attribute_name, mappings in self.from_internal_attributes.items(): for profile, external_attribute_names in mappings.items(): for external_attribute_name in external_attribute_names: - self.to_internal_attributes[profile][external_attribute_name] = internal_attribute_name + self.to_internal_attributes[profile][ + external_attribute_name + ] = internal_attribute_name def to_internal_filter(self, attribute_profile, external_attribute_names): """ @@ -59,7 +65,11 @@ def to_internal_filter(self, attribute_profile, external_attribute_names): try: profile_mapping = self.to_internal_attributes[attribute_profile] except KeyError: - logline = "no attribute mapping found for the given attribute profile {}".format(attribute_profile) + logline = ( + "no attribute mapping found for the given attribute profile {}".format( + attribute_profile + ) + ) logger.warn(logline) # no attributes since the given profile is not configured return [] @@ -103,7 +113,9 @@ def to_internal(self, attribute_profile, external_dict): ) if attribute_values: # Only insert key if it has some values logline = "backend attribute {external} mapped to {internal} ({value})".format( - external=external_attribute_name, internal=internal_attribute_name, value=attribute_values + external=external_attribute_name, + internal=internal_attribute_name, + value=attribute_values, ) logger.debug(logline) internal_dict[internal_attribute_name] = attribute_values @@ -112,7 +124,9 @@ def to_internal(self, attribute_profile, external_dict): external_attribute_name ) logger.debug(logline) - internal_dict = self._handle_template_attributes(attribute_profile, internal_dict) + internal_dict = self._handle_template_attributes( + attribute_profile, internal_dict + ) return internal_dict def _collate_attribute_values_by_priority_order(self, attribute_names, data): @@ -128,7 +142,11 @@ def _collate_attribute_values_by_priority_order(self, attribute_names, data): return result def _render_attribute_template(self, template, data): - t = Template(template, cache_enabled=True, imports=["from satosa.attribute_mapping import scope"]) + t = Template( + template, + cache_enabled=True, + imports=["from satosa.attribute_mapping import scope"], + ) try: return t.render(**data).split(self.multivalue_separator) except (NameError, TypeError): @@ -144,11 +162,19 @@ def _handle_template_attributes(self, attribute_profile, internal_dict): continue external_attribute_name = mapping[attribute_profile] - templates = [t for t in external_attribute_name if "$" in t] # these looks like templates... - template_attribute_values = [self._render_attribute_template(template, internal_dict) for template in - templates] - flattened_attribute_values = list(chain.from_iterable(template_attribute_values)) - attribute_values = flattened_attribute_values or internal_dict.get(internal_attribute_name, None) + templates = [ + t for t in external_attribute_name if "$" in t + ] # these looks like templates... + template_attribute_values = [ + self._render_attribute_template(template, internal_dict) + for template in templates + ] + flattened_attribute_values = list( + chain.from_iterable(template_attribute_values) + ) + attribute_values = flattened_attribute_values or internal_dict.get( + internal_attribute_name, None + ) if attribute_values: # only insert key if it has some values internal_dict[internal_attribute_name] = attribute_values @@ -172,7 +198,9 @@ def _create_nested_attribute_value(self, nested_attribute_names, value): return {nested_attribute_names[0]: value} # keep digging further into the nested attribute names - child_dict = self._create_nested_attribute_value(nested_attribute_names[1:], value) + child_dict = self._create_nested_attribute_value( + nested_attribute_names[1:], value + ) return {nested_attribute_names[0]: child_dict} def from_internal(self, attribute_profile, internal_dict): @@ -190,10 +218,14 @@ def from_internal(self, attribute_profile, internal_dict): external_dict = {} for internal_attribute_name in internal_dict: try: - attribute_mapping = self.from_internal_attributes[internal_attribute_name] - except KeyError: - logline = "no attribute mapping found for the internal attribute {}".format( + attribute_mapping = self.from_internal_attributes[ internal_attribute_name + ] + except KeyError: + logline = ( + "no attribute mapping found for the internal attribute {}".format( + internal_attribute_name + ) ) logger.debug(logline) continue @@ -206,20 +238,29 @@ def from_internal(self, attribute_profile, internal_dict): logger.debug(logline) continue - external_attribute_names = self.from_internal_attributes[internal_attribute_name][attribute_profile] + external_attribute_names = self.from_internal_attributes[ + internal_attribute_name + ][attribute_profile] # select the first attribute name external_attribute_name = external_attribute_names[0] - logline = "frontend attribute {external} mapped from {internal} ({value})".format( - external=external_attribute_name, internal=internal_attribute_name, value=internal_dict[internal_attribute_name] + logline = ( + "frontend attribute {external} mapped from {internal} ({value})".format( + external=external_attribute_name, + internal=internal_attribute_name, + value=internal_dict[internal_attribute_name], + ) ) logger.debug(logline) if self.separator in external_attribute_name: nested_attribute_names = external_attribute_name.split(self.separator) - nested_dict = self._create_nested_attribute_value(nested_attribute_names[1:], - internal_dict[internal_attribute_name]) + nested_dict = self._create_nested_attribute_value( + nested_attribute_names[1:], internal_dict[internal_attribute_name] + ) external_dict[nested_attribute_names[0]] = nested_dict else: - external_dict[external_attribute_name] = internal_dict[internal_attribute_name] + external_dict[external_attribute_name] = internal_dict[ + internal_attribute_name + ] return external_dict diff --git a/src/satosa/backends/apple.py b/src/satosa/backends/apple.py index f7c1189ea..d1d2fdf24 100644 --- a/src/satosa/backends/apple.py +++ b/src/satosa/backends/apple.py @@ -1,15 +1,17 @@ """ Apple backend module. """ +import json import logging -from .openid_connect import OpenIDConnectBackend, STATE_KEY + +import requests from oic.oauth2.message import Message from oic.oic.message import AuthorizationResponse + import satosa.logging_util as lu -from ..exception import SATOSAAuthenticationError -import json -import requests +from ..exception import SATOSAAuthenticationError +from .openid_connect import STATE_KEY, OpenIDConnectBackend logger = logging.getLogger(__name__) diff --git a/src/satosa/backends/bitbucket.py b/src/satosa/backends/bitbucket.py index 6932ce901..dc44bc6f9 100644 --- a/src/satosa/backends/bitbucket.py +++ b/src/satosa/backends/bitbucket.py @@ -3,10 +3,10 @@ """ import json import logging -import requests -from oic.utils.authn.authn_context import UNSPECIFIED +import requests from oic.oauth2.consumer import stateID +from oic.utils.authn.authn_context import UNSPECIFIED from satosa.backends.oauth import _OAuthBackend from satosa.internal import AuthenticationInformation @@ -37,10 +37,17 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name): :type base_url: str :type name: str """ - config.setdefault('response_type', 'code') - config['verify_accesstoken_state'] = False - super().__init__(outgoing, internal_attributes, config, base_url, - name, 'bitbucket', 'account_id') + config.setdefault("response_type", "code") + config["verify_accesstoken_state"] = False + super().__init__( + outgoing, + internal_attributes, + config, + base_url, + name, + "bitbucket", + "account_id", + ) def get_request_args(self, get_state=stateID): request_args = super().get_request_args(get_state=get_state) @@ -58,28 +65,36 @@ def get_request_args(self, get_state=stateID): def auth_info(self, request): return AuthenticationInformation( - UNSPECIFIED, None, - self.config['server_info']['authorization_endpoint']) + UNSPECIFIED, None, self.config["server_info"]["authorization_endpoint"] + ) def user_information(self, access_token): - url = self.config['server_info']['user_endpoint'] + url = self.config["server_info"]["user_endpoint"] email_url = "{}/emails".format(url) - headers = {'Authorization': 'Bearer {}'.format(access_token)} + headers = {"Authorization": "Bearer {}".format(access_token)} resp = requests.get(url, headers=headers) data = json.loads(resp.text) - if 'email' in self.config['scope']: + if "email" in self.config["scope"]: resp = requests.get(email_url, headers=headers) emails = json.loads(resp.text) - data.update({ - 'email': [e for e in [d.get('email') - for d in emails.get('values') - if d.get('is_primary') - ] - ], - 'email_confirmed': [e for e in [d.get('email') - for d in emails.get('values') - if d.get('is_confirmed') - ] - ] - }) + data.update( + { + "email": [ + e + for e in [ + d.get("email") + for d in emails.get("values") + if d.get("is_primary") + ] + ], + "email_confirmed": [ + e + for e in [ + d.get("email") + for d in emails.get("values") + if d.get("is_confirmed") + ] + ], + } + ) return data diff --git a/src/satosa/backends/github.py b/src/satosa/backends/github.py index 70944e371..6b4f6ce0f 100644 --- a/src/satosa/backends/github.py +++ b/src/satosa/backends/github.py @@ -3,15 +3,14 @@ """ import json import logging -import requests -from oic.utils.authn.authn_context import UNSPECIFIED +import requests from oic.oauth2.consumer import stateID from oic.oauth2.message import AuthorizationResponse +from oic.utils.authn.authn_context import UNSPECIFIED from satosa.backends.oauth import _OAuthBackend -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.internal import AuthenticationInformation, InternalData from satosa.response import Redirect from satosa.util import rndstr @@ -39,11 +38,11 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name): :type base_url: str :type name: str """ - config.setdefault('response_type', 'code') - config['verify_accesstoken_state'] = False + config.setdefault("response_type", "code") + config["verify_accesstoken_state"] = False super().__init__( - outgoing, internal_attributes, config, base_url, name, 'github', - 'id') + outgoing, internal_attributes, config, base_url, name, "github", "id" + ) def start_auth(self, context, internal_request, get_state=stateID): """ @@ -58,53 +57,56 @@ def start_auth(self, context, internal_request, get_state=stateID): context.state[self.name] = dict(state=oauth_state) request_args = dict( - client_id=self.config['client_config']['client_id'], + client_id=self.config["client_config"]["client_id"], redirect_uri=self.redirect_url, state=oauth_state, - allow_signup=self.config.get('allow_signup', False)) - scope = ' '.join(self.config['scope']) + allow_signup=self.config.get("allow_signup", False), + ) + scope = " ".join(self.config["scope"]) if scope: - request_args['scope'] = scope + request_args["scope"] = scope - cis = self.consumer.construct_AuthorizationRequest( - request_args=request_args) + cis = self.consumer.construct_AuthorizationRequest(request_args=request_args) return Redirect(cis.request(self.consumer.authorization_endpoint)) def auth_info(self, requrest): return AuthenticationInformation( - UNSPECIFIED, None, - self.config['server_info']['authorization_endpoint']) + UNSPECIFIED, None, self.config["server_info"]["authorization_endpoint"] + ) def _authn_response(self, context): state_data = context.state[self.name] aresp = self.consumer.parse_response( - AuthorizationResponse, info=json.dumps(context.request)) + AuthorizationResponse, info=json.dumps(context.request) + ) self._verify_state(aresp, state_data, context.state) - url = self.config['server_info']['token_endpoint'] + url = self.config["server_info"]["token_endpoint"] data = dict( - code=aresp['code'], + code=aresp["code"], redirect_uri=self.redirect_url, - client_id=self.config['client_config']['client_id'], - client_secret=self.config['client_secret'], ) - headers = {'Accept': 'application/json'} + client_id=self.config["client_config"]["client_id"], + client_secret=self.config["client_secret"], + ) + headers = {"Accept": "application/json"} r = requests.post(url, data=data, headers=headers) response = r.json() - if self.config.get('verify_accesstoken_state', True): + if self.config.get("verify_accesstoken_state", True): self._verify_state(response, state_data, context.state) user_info = self.user_information(response["access_token"]) auth_info = self.auth_info(context.request) internal_response = InternalData(auth_info=auth_info) internal_response.attributes = self.converter.to_internal( - self.external_type, user_info) + self.external_type, user_info + ) internal_response.subject_id = str(user_info[self.user_id_attr]) return self.auth_callback_func(context, internal_response) def user_information(self, access_token): - url = self.config['server_info']['user_info'] - headers = {'Authorization': 'token {}'.format(access_token)} + url = self.config["server_info"]["user_info"] + headers = {"Authorization": "token {}".format(access_token)} r = requests.get(url, headers=headers) ret = r.json() - ret['id'] = str(ret['id']) + ret["id"] = str(ret["id"]) return ret diff --git a/src/satosa/backends/idpy_oidc.py b/src/satosa/backends/idpy_oidc.py index f3ea43f61..16b4cd4f3 100644 --- a/src/satosa/backends/idpy_oidc.py +++ b/src/satosa/backends/idpy_oidc.py @@ -8,14 +8,12 @@ from idpyoidc.client.oauth2.stand_alone_client import StandAloneClient from idpyoidc.server.user_authn.authn_context import UNSPECIFIED -from satosa.backends.base import BackendModule -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData import satosa.logging_util as lu -from ..exception import SATOSAAuthenticationError -from ..exception import SATOSAError -from ..response import Redirect +from satosa.backends.base import BackendModule +from satosa.internal import AuthenticationInformation, InternalData +from ..exception import SATOSAAuthenticationError, SATOSAError +from ..response import Redirect UTC = datetime.timezone.utc logger = logging.getLogger(__name__) @@ -52,7 +50,7 @@ def __init__(self, auth_callback_func, internal_attributes, config, base_url, na self.client.do_provider_info() self.client.do_client_registration() - _redirect_uris = self.client.context.claims.get_usage('redirect_uris') + _redirect_uris = self.client.context.claims.get_usage("redirect_uris") if not _redirect_uris: raise SATOSAError("Missing path in redirect uri") self.redirect_path = urlparse(_redirect_uris[0]).path @@ -94,12 +92,14 @@ def response_endpoint(self, context, *args): _info = self.client.finalize(context.request) self._check_error_response(_info, context) - userinfo = _info.get('userinfo') - id_token = _info.get('id_token') + userinfo = _info.get("userinfo") + id_token = _info.get("id_token") if not id_token and not userinfo: msg = "No id_token or userinfo, nothing to do.." - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) raise SATOSAAuthenticationError(context.state, "No user info available.") @@ -151,6 +151,8 @@ def _check_error_response(self, response, context): error=response["error"], description=response.get("error_description", ""), ) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) raise SATOSAAuthenticationError(context.state, "Access denied") diff --git a/src/satosa/backends/linkedin.py b/src/satosa/backends/linkedin.py index 8d3a85b4c..095d384bf 100644 --- a/src/satosa/backends/linkedin.py +++ b/src/satosa/backends/linkedin.py @@ -3,19 +3,17 @@ """ import json import logging -import requests -from oic.utils.authn.authn_context import UNSPECIFIED +import requests from oic.oauth2.consumer import stateID from oic.oauth2.message import AuthorizationResponse +from oic.utils.authn.authn_context import UNSPECIFIED from satosa.backends.oauth import _OAuthBackend -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.internal import AuthenticationInformation, InternalData from satosa.response import Redirect from satosa.util import rndstr - logger = logging.getLogger(__name__) @@ -40,11 +38,11 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name): :type base_url: str :type name: str """ - config.setdefault('response_type', 'code') - config['verify_accesstoken_state'] = False + config.setdefault("response_type", "code") + config["verify_accesstoken_state"] = False super().__init__( - outgoing, internal_attributes, config, base_url, name, 'linkedin', - 'id') + outgoing, internal_attributes, config, base_url, name, "linkedin", "id" + ) def start_auth(self, context, internal_request, get_state=stateID): """ @@ -59,61 +57,66 @@ def start_auth(self, context, internal_request, get_state=stateID): context.state[self.name] = dict(state=oauth_state) request_args = dict( - response_type='code', - client_id=self.config['client_config']['client_id'], + response_type="code", + client_id=self.config["client_config"]["client_id"], redirect_uri=self.redirect_url, - state=oauth_state) - scope = ' '.join(self.config['scope']) + state=oauth_state, + ) + scope = " ".join(self.config["scope"]) if scope: - request_args['scope'] = scope + request_args["scope"] = scope - cis = self.consumer.construct_AuthorizationRequest( - request_args=request_args) + cis = self.consumer.construct_AuthorizationRequest(request_args=request_args) return Redirect(cis.request(self.consumer.authorization_endpoint)) def auth_info(self, requrest): return AuthenticationInformation( - UNSPECIFIED, None, - self.config['server_info']['authorization_endpoint']) + UNSPECIFIED, None, self.config["server_info"]["authorization_endpoint"] + ) def _authn_response(self, context): state_data = context.state[self.name] aresp = self.consumer.parse_response( - AuthorizationResponse, info=json.dumps(context.request)) + AuthorizationResponse, info=json.dumps(context.request) + ) self._verify_state(aresp, state_data, context.state) - url = self.config['server_info']['token_endpoint'] + url = self.config["server_info"]["token_endpoint"] data = dict( - grant_type='authorization_code', - code=aresp['code'], + grant_type="authorization_code", + code=aresp["code"], redirect_uri=self.redirect_url, - client_id=self.config['client_config']['client_id'], - client_secret=self.config['client_secret']) + client_id=self.config["client_config"]["client_id"], + client_secret=self.config["client_secret"], + ) r = requests.post(url, data=data) response = r.json() - if self.config.get('verify_accesstoken_state', True): + if self.config.get("verify_accesstoken_state", True): self._verify_state(response, state_data, context.state) auth_info = self.auth_info(context.request) - user_email_response = self.user_information(response["access_token"], 'email_info') - user_info = self.user_information(response["access_token"], 'user_info') + user_email_response = self.user_information( + response["access_token"], "email_info" + ) + user_info = self.user_information(response["access_token"], "user_info") user_email = { "emailAddress": [ - element['handle~']['emailAddress'] - for element in user_email_response['elements'] + element["handle~"]["emailAddress"] + for element in user_email_response["elements"] ] } user_info.update(user_email) internal_response = InternalData(auth_info=auth_info) internal_response.attributes = self.converter.to_internal( - self.external_type, user_info) + self.external_type, user_info + ) internal_response.subject_id = user_info[self.user_id_attr] return self.auth_callback_func(context, internal_response) def user_information(self, access_token, api): - url = self.config['server_info'][api] - headers = {'Authorization': 'Bearer {}'.format(access_token)} + url = self.config["server_info"][api] + headers = {"Authorization": "Bearer {}".format(access_token)} r = requests.get(url, headers=headers) return r.json() diff --git a/src/satosa/backends/oauth.py b/src/satosa/backends/oauth.py index 3e2bd041b..3207f9518 100644 --- a/src/satosa/backends/oauth.py +++ b/src/satosa/backends/oauth.py @@ -6,22 +6,22 @@ from base64 import urlsafe_b64encode import requests - from oic.oauth2.consumer import Consumer, stateID from oic.oauth2.message import AuthorizationResponse from oic.utils.authn.authn_context import UNSPECIFIED import satosa.logging_util as lu -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.backends.base import BackendModule from satosa.exception import SATOSAAuthenticationError -from satosa.response import Redirect -from satosa.util import rndstr +from satosa.internal import AuthenticationInformation, InternalData from satosa.metadata_creation.description import ( - OrganizationDesc, UIInfoDesc, ContactPersonDesc, MetadataDescription + ContactPersonDesc, + MetadataDescription, + OrganizationDesc, + UIInfoDesc, ) -from satosa.backends.base import BackendModule - +from satosa.response import Redirect +from satosa.util import rndstr logger = logging.getLogger(__name__) @@ -32,7 +32,16 @@ class _OAuthBackend(BackendModule): See satosa.backends.oauth.FacebookBackend. """ - def __init__(self, outgoing, internal_attributes, config, base_url, name, external_type, user_id_attr): + def __init__( + self, + outgoing, + internal_attributes, + config, + base_url, + name, + external_type, + user_id_attr, + ): """ :param outgoing: Callback should be called by the module after the authorization in the backend is done. @@ -54,7 +63,10 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name, extern """ super().__init__(outgoing, internal_attributes, base_url, name) self.config = config - self.redirect_url = "%s/%s" % (self.config["base_url"], self.config["authz_page"]) + self.redirect_url = "%s/%s" % ( + self.config["base_url"], + self.config["authz_page"], + ) self.external_type = external_type self.user_id_attr = user_id_attr self.consumer = Consumer( @@ -62,7 +74,8 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name, extern client_config=self.config["client_config"], server_info=self.config["server_info"], authz_page=self.config["authz_page"], - response_type=self.config["response_type"]) + response_type=self.config["response_type"], + ) self.consumer.client_secret = self.config["client_secret"] def start_auth(self, context, internal_request, get_state=stateID): @@ -110,15 +123,19 @@ def _verify_state(self, resp, state_data, state): :param state: The current state for the proxy and this backend. Only used for raising errors. """ - is_known_state = "state" in resp and "state" in state_data and resp["state"] == state_data["state"] + is_known_state = ( + "state" in resp + and "state" in state_data + and resp["state"] == state_data["state"] + ) if not is_known_state: received_state = resp.get("state", "") msg = "Missing or invalid state [{}] in response!".format(received_state) logline = lu.LOG_FMT.format(id=lu.get_session_id(state), message=msg) logger.debug(logline) - raise SATOSAAuthenticationError(state, - "Missing or invalid state [%s] in response!" % - received_state) + raise SATOSAAuthenticationError( + state, "Missing or invalid state [%s] in response!" % received_state + ) def _authn_response(self, context): """ @@ -131,19 +148,31 @@ def _authn_response(self, context): which generates the Response object. """ state_data = context.state[self.name] - aresp = self.consumer.parse_response(AuthorizationResponse, info=json.dumps(context.request)) + aresp = self.consumer.parse_response( + AuthorizationResponse, info=json.dumps(context.request) + ) self._verify_state(aresp, state_data, context.state) - rargs = {"code": aresp["code"], "redirect_uri": self.redirect_url, - "state": state_data["state"]} + rargs = { + "code": aresp["code"], + "redirect_uri": self.redirect_url, + "state": state_data["state"], + } - atresp = self.consumer.do_access_token_request(request_args=rargs, state=aresp["state"]) - if "verify_accesstoken_state" not in self.config or self.config["verify_accesstoken_state"]: + atresp = self.consumer.do_access_token_request( + request_args=rargs, state=aresp["state"] + ) + if ( + "verify_accesstoken_state" not in self.config + or self.config["verify_accesstoken_state"] + ): self._verify_state(atresp, state_data, context.state) user_info = self.user_information(atresp["access_token"]) internal_response = InternalData(auth_info=self.auth_info(context.request)) - internal_response.attributes = self.converter.to_internal(self.external_type, user_info) + internal_response.attributes = self.converter.to_internal( + self.external_type, user_info + ) internal_response.subject_id = user_info[self.user_id_attr] return self.auth_callback_func(context, internal_response) @@ -156,7 +185,9 @@ def auth_info(self, request): :param request: The request parameters in the authentication response sent by the AS. :return: How, who and when the autentication took place. """ - raise NotImplementedError("Method 'auth_info' must be implemented in the subclass!") + raise NotImplementedError( + "Method 'auth_info' must be implemented in the subclass!" + ) def user_information(self, access_token): """ @@ -167,7 +198,9 @@ def user_information(self, access_token): :param access_token: The access token to be used to retrieve the data. :return: Dictionary with attribute name as key and attribute value as value. """ - raise NotImplementedError("Method 'user_information' must be implemented in the subclass!") + raise NotImplementedError( + "Method 'user_information' must be implemented in the subclass!" + ) def get_metadata_desc(self): """ @@ -175,7 +208,8 @@ def get_metadata_desc(self): :rtype: satosa.metadata_creation.description.MetadataDescription """ return get_metadata_desc_for_oauth_backend( - self.config["server_info"]["authorization_endpoint"], self.config) + self.config["server_info"]["authorization_endpoint"], self.config + ) class FacebookBackend(_OAuthBackend): @@ -210,7 +244,9 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name): """ config.setdefault("response_type", "code") config["verify_accesstoken_state"] = False - super().__init__(outgoing, internal_attributes, config, base_url, name, "facebook", "id") + super().__init__( + outgoing, internal_attributes, config, base_url, name, "facebook", "id" + ) def get_request_args(self, get_state=stateID): request_args = super().get_request_args(get_state=get_state) @@ -235,9 +271,9 @@ def auth_info(self, request): :param request: The request parameters in the authentication response sent by the AS. :return: How, who and when the autentication took place. """ - auth_info = AuthenticationInformation(UNSPECIFIED, - None, - self.config["server_info"]["authorization_endpoint"]) + auth_info = AuthenticationInformation( + UNSPECIFIED, None, self.config["server_info"]["authorization_endpoint"] + ) return auth_info def user_information(self, access_token): @@ -250,7 +286,9 @@ def user_information(self, access_token): :return: Dictionary with attribute name as key and attribute value as value. """ payload = {"access_token": access_token} - url = self.config["server_info"].get("graph_endpoint", self.DEFAULT_GRAPH_ENDPOINT) + url = self.config["server_info"].get( + "graph_endpoint", self.DEFAULT_GRAPH_ENDPOINT + ) if self.config["fields"]: payload["fields"] = ",".join(self.config["fields"]) resp = requests.get(url, params=payload) @@ -301,8 +339,12 @@ def get_metadata_desc_for_oauth_backend(entity_id, config): for name_info in organization_info.get("organization_name", []): organization.add_name(name_info[0], name_info[1]) - for display_name_info in organization_info.get("organization_display_name", []): - organization.add_display_name(display_name_info[0], display_name_info[1]) + for display_name_info in organization_info.get( + "organization_display_name", [] + ): + organization.add_display_name( + display_name_info[0], display_name_info[1] + ) for url_info in organization_info.get("organization_url", []): organization.add_url(url_info[0], url_info[1]) @@ -317,11 +359,17 @@ def get_metadata_desc_for_oauth_backend(entity_id, config): for name in ui_info.get("display_name", []): ui_description.add_display_name(name[0], name[1]) for logo in ui_info.get("logo", []): - ui_description.add_logo(logo["image"], logo["width"], logo["height"], logo["lang"]) + ui_description.add_logo( + logo["image"], logo["width"], logo["height"], logo["lang"] + ) for keywords in ui_info.get("keywords", []): - ui_description.add_keywords(keywords.get("text", []), keywords.get("lang")) + ui_description.add_keywords( + keywords.get("text", []), keywords.get("lang") + ) for information_url in ui_info.get("information_url", []): - ui_description.add_information_url(information_url.get("text"), information_url.get("lang")) + ui_description.add_information_url( + information_url.get("text"), information_url.get("lang") + ) for privacy_statement_url in ui_info.get("privacy_statement_url", []): ui_description.add_information_url( privacy_statement_url.get("text"), privacy_statement_url.get("lang") diff --git a/src/satosa/backends/openid_connect.py b/src/satosa/backends/openid_connect.py index 58d47af9b..80d363a88 100644 --- a/src/satosa/backends/openid_connect.py +++ b/src/satosa/backends/openid_connect.py @@ -5,25 +5,23 @@ from datetime import datetime from urllib.parse import urlparse -from oic import oic -from oic import rndstr -from oic.oic.message import AuthorizationResponse -from oic.oic.message import ProviderConfigurationResponse -from oic.oic.message import RegistrationRequest +from oic import oic, rndstr +from oic.oic.message import ( + AuthorizationResponse, + ProviderConfigurationResponse, + RegistrationRequest, +) from oic.utils.authn.authn_context import UNSPECIFIED from oic.utils.authn.client import CLIENT_AUTHN_METHOD from oic.utils.settings import PyoidcSettings import satosa.logging_util as lu -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.internal import AuthenticationInformation, InternalData + +from ..exception import SATOSAAuthenticationError, SATOSAError, SATOSAMissingStateError +from ..response import Redirect from .base import BackendModule from .oauth import get_metadata_desc_for_oauth_backend -from ..exception import SATOSAAuthenticationError -from ..exception import SATOSAError -from ..exception import SATOSAMissingStateError -from ..response import Redirect - logger = logging.getLogger(__name__) @@ -71,10 +69,12 @@ def __init__(self, auth_callback_func, internal_attributes, config, base_url, na msg = { "message": f"Failed to initialize client", "error": str(exc), - "client_metadata": self.config['client']['client_metadata'], - "provider_metadata": self.config['provider_metadata'], + "client_metadata": self.config["client"]["client_metadata"], + "provider_metadata": self.config["provider_metadata"], } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) raise SATOSAAuthenticationError(context.state, msg) from exc @@ -91,10 +91,7 @@ def start_auth(self, context, request_info): """ oidc_nonce = rndstr() oidc_state = rndstr() - state_data = { - NONCE_KEY: oidc_nonce, - STATE_KEY: oidc_state - } + state_data = {NONCE_KEY: oidc_nonce, STATE_KEY: oidc_state} context.state[self.name] = state_data args = { @@ -103,7 +100,7 @@ def start_auth(self, context, request_info): "client_id": self.client.client_id, "redirect_uri": self.client.registration_response["redirect_uris"][0], "state": oidc_state, - "nonce": oidc_nonce + "nonce": oidc_nonce, } args.update(self.config["client"]["auth_req_params"]) auth_req = self.client.construct_AuthorizationRequest(request_args=args) @@ -119,7 +116,9 @@ def register_endpoints(self): :return: A list that can be used to map the request to SATOSA to this endpoint. """ url_map = [] - redirect_path = urlparse(self.config["client"]["client_metadata"]["redirect_uris"][0]).path + redirect_path = urlparse( + self.config["client"]["client_metadata"]["redirect_uris"][0] + ).path if not redirect_path: raise SATOSAError("Missing path in redirect uri") @@ -137,10 +136,16 @@ def _verify_nonce(self, nonce, context): """ backend_state = context.state[self.name] if nonce != backend_state[NONCE_KEY]: - msg = "Missing or invalid nonce in authn response for state: {}".format(backend_state) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + msg = "Missing or invalid nonce in authn response for state: {}".format( + backend_state + ) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) - raise SATOSAAuthenticationError(context.state, "Missing or invalid nonce in authn response") + raise SATOSAAuthenticationError( + context.state, "Missing or invalid nonce in authn response" + ) def _get_tokens(self, authn_response, context): """ @@ -153,13 +158,17 @@ def _get_tokens(self, authn_response, context): # make token request args = { "code": authn_response["code"], - "redirect_uri": self.client.registration_response['redirect_uris'][0], + "redirect_uri": self.client.registration_response["redirect_uris"][0], } - token_resp = self.client.do_access_token_request(scope="openid", state=authn_response["state"], - request_args=args, - authn_method=self.client.registration_response[ - "token_endpoint_auth_method"]) + token_resp = self.client.do_access_token_request( + scope="openid", + state=authn_response["state"], + request_args=args, + authn_method=self.client.registration_response[ + "token_endpoint_auth_method" + ], + ) self._check_error_response(token_resp, context) return token_resp["access_token"], token_resp["id_token"] @@ -179,7 +188,9 @@ def _check_error_response(self, response, context): error=response["error"], description=response.get("error_description", ""), ) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) raise SATOSAAuthenticationError(context.state, "Access denied") @@ -217,12 +228,20 @@ def response_endpoint(self, context, *args): raise SATOSAMissingStateError(error) backend_state = context.state[self.name] - authn_resp = self.client.parse_response(AuthorizationResponse, info=context.request, sformat="dict") + authn_resp = self.client.parse_response( + AuthorizationResponse, info=context.request, sformat="dict" + ) if backend_state[STATE_KEY] != authn_resp["state"]: - msg = "Missing or invalid state in authn response for state: {}".format(backend_state) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + msg = "Missing or invalid state in authn response for state: {}".format( + backend_state + ) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) - raise SATOSAAuthenticationError(context.state, "Missing or invalid state in authn response") + raise SATOSAAuthenticationError( + context.state, "Missing or invalid state in authn response" + ) self._check_error_response(authn_resp, context) access_token, id_token_claims = self._get_tokens(authn_resp, context) @@ -238,7 +257,9 @@ def response_endpoint(self, context, *args): if not id_token_claims and not userinfo: msg = "No id_token or userinfo, nothing to do.." - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) raise SATOSAAuthenticationError(context.state, "No user info available.") @@ -246,7 +267,9 @@ def response_endpoint(self, context, *args): msg = "UserInfo: {}".format(all_user_claims) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.debug(logline) - internal_resp = self._translate_response(all_user_claims, self.client.authorization_endpoint) + internal_resp = self._translate_response( + all_user_claims, self.client.authorization_endpoint + ) return self.auth_callback_func(context, internal_resp) def _translate_response(self, response, issuer): @@ -273,7 +296,9 @@ def get_metadata_desc(self): See satosa.backends.oauth.get_metadata_desc :rtype: satosa.metadata_creation.description.MetadataDescription """ - return get_metadata_desc_for_oauth_backend(self.config["provider_metadata"]["issuer"], self.config) + return get_metadata_desc_for_oauth_backend( + self.config["provider_metadata"]["issuer"], self.config + ) def _create_client(provider_metadata, client_metadata, settings=None): @@ -286,15 +311,15 @@ def _create_client(provider_metadata, client_metadata, settings=None): :return: client instance to use for communicating with the configured provider :rtype: oic.oic.Client """ - client = oic.Client( - client_authn_method=CLIENT_AUTHN_METHOD, settings=settings - ) + client = oic.Client(client_authn_method=CLIENT_AUTHN_METHOD, settings=settings) # Provider configuration information if "authorization_endpoint" in provider_metadata: # no dynamic discovery necessary - client.handle_provider_config(ProviderConfigurationResponse(**provider_metadata), - provider_metadata["issuer"]) + client.handle_provider_config( + ProviderConfigurationResponse(**provider_metadata), + provider_metadata["issuer"], + ) else: # do dynamic discovery client.provider_config(provider_metadata["issuer"]) @@ -305,9 +330,12 @@ def _create_client(provider_metadata, client_metadata, settings=None): client.store_registration_info(RegistrationRequest(**client_metadata)) else: # do dynamic registration - client.register(client.provider_info['registration_endpoint'], - **client_metadata) + client.register( + client.provider_info["registration_endpoint"], **client_metadata + ) - client.subject_type = (client.registration_response.get("subject_type") or - client.provider_info["subject_types_supported"][0]) + client.subject_type = ( + client.registration_response.get("subject_type") + or client.provider_info["subject_types_supported"][0] + ) return client diff --git a/src/satosa/backends/orcid.py b/src/satosa/backends/orcid.py index 649e72451..26aa9688f 100644 --- a/src/satosa/backends/orcid.py +++ b/src/satosa/backends/orcid.py @@ -2,17 +2,16 @@ OAuth backend for Orcid """ import json -import requests import logging from urllib.parse import urljoin -from oic.utils.authn.authn_context import UNSPECIFIED +import requests from oic.oauth2.consumer import stateID from oic.oauth2.message import AuthorizationResponse +from oic.utils.authn.authn_context import UNSPECIFIED from satosa.backends.oauth import _OAuthBackend -from satosa.internal import InternalData -from satosa.internal import AuthenticationInformation +from satosa.internal import AuthenticationInformation, InternalData from satosa.util import rndstr logger = logging.getLogger(__name__) @@ -39,67 +38,74 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name): :type base_url: str :type name: str """ - config.setdefault('response_type', 'code') - config['verify_accesstoken_state'] = False + config.setdefault("response_type", "code") + config["verify_accesstoken_state"] = False super().__init__( - outgoing, internal_attributes, config, base_url, name, 'orcid', - 'orcid') + outgoing, internal_attributes, config, base_url, name, "orcid", "orcid" + ) def get_request_args(self, get_state=stateID): oauth_state = get_state(self.config["base_url"], rndstr().encode()) request_args = { - "client_id": self.config['client_config']['client_id'], + "client_id": self.config["client_config"]["client_id"], "redirect_uri": self.redirect_url, - "scope": ' '.join(self.config['scope']), + "scope": " ".join(self.config["scope"]), "state": oauth_state, } return request_args def auth_info(self, requrest): return AuthenticationInformation( - UNSPECIFIED, None, - self.config['server_info']['authorization_endpoint']) + UNSPECIFIED, None, self.config["server_info"]["authorization_endpoint"] + ) def _authn_response(self, context): state_data = context.state[self.name] aresp = self.consumer.parse_response( - AuthorizationResponse, info=json.dumps(context.request)) + AuthorizationResponse, info=json.dumps(context.request) + ) self._verify_state(aresp, state_data, context.state) - rargs = {"code": aresp["code"], "redirect_uri": self.redirect_url, - "state": state_data["state"]} + rargs = { + "code": aresp["code"], + "redirect_uri": self.redirect_url, + "state": state_data["state"], + } atresp = self.consumer.do_access_token_request( - request_args=rargs, state=aresp['state']) + request_args=rargs, state=aresp["state"] + ) user_info = self.user_information( - atresp['access_token'], atresp['orcid'], atresp.get('name')) - internal_response = InternalData( - auth_info=self.auth_info(context.request)) + atresp["access_token"], atresp["orcid"], atresp.get("name") + ) + internal_response = InternalData(auth_info=self.auth_info(context.request)) internal_response.attributes = self.converter.to_internal( - self.external_type, user_info) + self.external_type, user_info + ) internal_response.subject_id = user_info[self.user_id_attr] return self.auth_callback_func(context, internal_response) def user_information(self, access_token, orcid, name=None): - base_url = self.config['server_info']['user_info'] - url = urljoin(base_url, '{}/person'.format(orcid)) + base_url = self.config["server_info"]["user_info"] + url = urljoin(base_url, "{}/person".format(orcid)) headers = { - 'Accept': 'application/orcid+json', - 'Authorization': "Bearer {}".format(access_token) + "Accept": "application/orcid+json", + "Authorization": "Bearer {}".format(access_token), } r = requests.get(url, headers=headers) r = r.json() - emails, addresses = r['emails']['email'], r['addresses']['address'] - rname = r.get('name') or {} + emails, addresses = r["emails"]["email"], r["addresses"]["address"] + rname = r.get("name") or {} ret = dict( - address=', '.join([e['country']['value'] for e in addresses]), + address=", ".join([e["country"]["value"] for e in addresses]), displayname=name, - edupersontargetedid=orcid, orcid=orcid, - mail=' '.join([e['email'] for e in emails]), + edupersontargetedid=orcid, + orcid=orcid, + mail=" ".join([e["email"] for e in emails]), name=name, - givenname=(rname.get('given-names') or {}).get('value'), - surname=(rname.get('family-name') or {}).get('value'), + givenname=(rname.get("given-names") or {}).get("value"), + surname=(rname.get("family-name") or {}).get("value"), ) return ret diff --git a/src/satosa/backends/reflector.py b/src/satosa/backends/reflector.py index 6a9055485..89911e658 100644 --- a/src/satosa/backends/reflector.py +++ b/src/satosa/backends/reflector.py @@ -4,10 +4,9 @@ import base64 from datetime import datetime -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData -from satosa.metadata_creation.description import MetadataDescription from satosa.backends.base import BackendModule +from satosa.internal import AuthenticationInformation, InternalData +from satosa.metadata_creation.description import MetadataDescription class ReflectorBackend(BackendModule): diff --git a/src/satosa/backends/saml2.py b/src/satosa/backends/saml2.py index 8be4572d4..33e2d2450 100644 --- a/src/satosa/backends/saml2.py +++ b/src/satosa/backends/saml2.py @@ -10,40 +10,40 @@ from urllib.parse import urlparse from saml2 import BINDING_HTTP_REDIRECT +from saml2.authn_context import requested_authn_context from saml2.client import Saml2Client from saml2.config import SPConfig from saml2.extension.mdui import NAMESPACE as UI_NAMESPACE from saml2.metadata import create_metadata_string -from saml2.authn_context import requested_authn_context -from saml2.samlp import RequesterID -from saml2.samlp import Scoping +from saml2.samlp import RequesterID, Scoping import satosa.logging_util as lu import satosa.util as util -from satosa.base import SAMLBaseModule -from satosa.base import SAMLEIDASBaseModule +from satosa.backends.base import BackendModule from satosa.base import STATE_KEY as STATE_KEY_BASE +from satosa.base import SAMLBaseModule, SAMLEIDASBaseModule from satosa.context import Context -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData -from satosa.exception import SATOSAAuthenticationError -from satosa.exception import SATOSAMissingStateError -from satosa.exception import SATOSAAuthenticationFlowError -from satosa.response import SeeOther, Response -from satosa.saml_util import make_saml_response +from satosa.exception import ( + SATOSAAuthenticationError, + SATOSAAuthenticationFlowError, + SATOSAMissingStateError, +) +from satosa.internal import AuthenticationInformation, InternalData from satosa.metadata_creation.description import ( - MetadataDescription, OrganizationDesc, ContactPersonDesc, UIInfoDesc + ContactPersonDesc, + MetadataDescription, + OrganizationDesc, + UIInfoDesc, ) -from satosa.backends.base import BackendModule - +from satosa.response import Response, SeeOther +from satosa.saml_util import make_saml_response logger = logging.getLogger(__name__) def get_memorized_idp(context, config, force_authn): - memorized_idp = ( - config.get(SAMLBackend.KEY_MEMORIZE_IDP) - and context.state.get(Context.KEY_MEMORIZED_IDP) + memorized_idp = config.get(SAMLBackend.KEY_MEMORIZE_IDP) and context.state.get( + Context.KEY_MEMORIZED_IDP ) use_when_force_authn = config.get( SAMLBackend.KEY_USE_MEMORIZED_IDP_WHEN_FORCE_AUTHN @@ -67,9 +67,10 @@ def get_force_authn(context, config, sp_config): """ mirror = config.get(SAMLBackend.KEY_MIRROR_FORCE_AUTHN) from_state = mirror and context.state.get(Context.KEY_FORCE_AUTHN) - from_context = ( - mirror and context.get_decoration(Context.KEY_FORCE_AUTHN) in ["true", "1"] - ) + from_context = mirror and context.get_decoration(Context.KEY_FORCE_AUTHN) in [ + "true", + "1", + ] from_config = sp_config.getattr("force_authn", "sp") is_set = str(from_state or from_context or from_config).lower() == "true" value = "true" if is_set else None @@ -80,17 +81,18 @@ class SAMLBackend(BackendModule, SAMLBaseModule): """ A saml2 backend module (acting as a SP). """ - KEY_DISCO_SRV = 'disco_srv' - KEY_SAML_DISCOVERY_SERVICE_URL = 'saml_discovery_service_url' - KEY_SAML_DISCOVERY_SERVICE_POLICY = 'saml_discovery_service_policy' - KEY_SP_CONFIG = 'sp_config' - KEY_SEND_REQUESTER_ID = 'send_requester_id' - KEY_MIRROR_FORCE_AUTHN = 'mirror_force_authn' - KEY_IS_PASSIVE = 'is_passive' - KEY_MEMORIZE_IDP = 'memorize_idp' - KEY_USE_MEMORIZED_IDP_WHEN_FORCE_AUTHN = 'use_memorized_idp_when_force_authn' - - VALUE_ACR_COMPARISON_DEFAULT = 'exact' + + KEY_DISCO_SRV = "disco_srv" + KEY_SAML_DISCOVERY_SERVICE_URL = "saml_discovery_service_url" + KEY_SAML_DISCOVERY_SERVICE_POLICY = "saml_discovery_service_policy" + KEY_SP_CONFIG = "sp_config" + KEY_SEND_REQUESTER_ID = "send_requester_id" + KEY_MIRROR_FORCE_AUTHN = "mirror_force_authn" + KEY_IS_PASSIVE = "is_passive" + KEY_MEMORIZE_IDP = "memorize_idp" + KEY_USE_MEMORIZED_IDP_WHEN_FORCE_AUTHN = "use_memorized_idp_when_force_authn" + + VALUE_ACR_COMPARISON_DEFAULT = "exact" def __init__(self, outgoing, internal_attributes, config, base_url, name): """ @@ -114,7 +116,7 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name): self.discosrv = config.get(SAMLBackend.KEY_DISCO_SRV) self.encryption_keys = [] self.outstanding_queries = {} - self.idp_blacklist_file = config.get('idp_blacklist_file', None) + self.idp_blacklist_file = config.get("idp_blacklist_file", None) sp_config = SPConfig().load(copy.deepcopy(config[SAMLBackend.KEY_SP_CONFIG])) @@ -122,20 +124,20 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name): # else, if key_file and cert_file are defined, use them for decryption # otherwise, do not use any decryption key. # ensure the choice is reflected back in the configuration. - sp_conf_encryption_keypairs = sp_config.getattr('encryption_keypairs', '') - sp_conf_key_file = sp_config.getattr('key_file', '') - sp_conf_cert_file = sp_config.getattr('cert_file', '') + sp_conf_encryption_keypairs = sp_config.getattr("encryption_keypairs", "") + sp_conf_key_file = sp_config.getattr("key_file", "") + sp_conf_cert_file = sp_config.getattr("cert_file", "") sp_keypairs = ( sp_conf_encryption_keypairs if sp_conf_encryption_keypairs - else [{'key_file': sp_conf_key_file, 'cert_file': sp_conf_cert_file}] + else [{"key_file": sp_conf_key_file, "cert_file": sp_conf_cert_file}] if sp_conf_key_file and sp_conf_cert_file else [] ) - sp_config.setattr('', 'encryption_keypairs', sp_keypairs) + sp_config.setattr("", "encryption_keypairs", sp_keypairs) # load the encryption keys - key_file_paths = [pair['key_file'] for pair in sp_keypairs] + key_file_paths = [pair["key_file"] for pair in sp_keypairs] for p in key_file_paths: with open(p) as key_file: self.encryption_keys.append(key_file.read()) @@ -154,8 +156,7 @@ def get_idp_entity_id(self, context): idps = self.sp.metadata.identity_providers() only_one_idp_in_metadata = ( - "mdq" not in self.config["sp_config"]["metadata"] - and len(idps) == 1 + "mdq" not in self.config["sp_config"]["metadata"] and len(idps) == 1 ) only_idp = only_one_idp_in_metadata and idps[0] @@ -227,19 +228,15 @@ def disco_query(self, context): disco_url, self.sp.config.entityid, **args ) - msg = { - "message": "Sending user to the discovery service", - "disco_url": loc - } + msg = {"message": "Sending user to the discovery service", "disco_url": loc} logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.info(logline) return SeeOther(loc) def construct_requested_authn_context(self, entity_id, *, target_accr=None): - acr_entry = ( - target_accr - or util.get_dict_defaults(self.acr_mapping or {}, entity_id) + acr_entry = target_accr or util.get_dict_defaults( + self.acr_mapping or {}, entity_id ) if not acr_entry: return None @@ -251,9 +248,8 @@ def construct_requested_authn_context(self, entity_id, *, target_accr=None): } authn_context = requested_authn_context( - acr_entry['class_ref'], comparison=acr_entry.get( - 'comparison', self.VALUE_ACR_COMPARISON_DEFAULT - ) + acr_entry["class_ref"], + comparison=acr_entry.get("comparison", self.VALUE_ACR_COMPARISON_DEFAULT), ) return authn_context @@ -276,19 +272,23 @@ def authn_request(self, context, entity_id): # stop here if self.idp_blacklist_file: with open(self.idp_blacklist_file) as blacklist_file: - blacklist_array = json.load(blacklist_file)['blacklist'] + blacklist_array = json.load(blacklist_file)["blacklist"] if entity_id in blacklist_array: msg = { "message": "AuthnRequest Failed", "error": f"Selected IdP with EntityID {entity_id} is blacklisted for this backend", } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) raise SATOSAAuthenticationError(context.state, msg) kwargs = {} target_accr = context.state.get(Context.KEY_TARGET_AUTHN_CONTEXT_CLASS_REF) - authn_context = self.construct_requested_authn_context(entity_id, target_accr=target_accr) + authn_context = self.construct_requested_authn_context( + entity_id, target_accr=target_accr + ) if authn_context: kwargs["requested_authn_context"] = authn_context if self.config.get(SAMLBackend.KEY_MIRROR_FORCE_AUTHN): @@ -296,7 +296,7 @@ def authn_request(self, context, entity_id): context, self.config, self.sp.config ) if self.config.get(SAMLBackend.KEY_SEND_REQUESTER_ID): - requester = context.state.state_dict[STATE_KEY_BASE]['requester'] + requester = context.state.state_dict[STATE_KEY_BASE]["requester"] kwargs["scoping"] = Scoping(requester_id=[RequesterID(text=requester)]) if self.config.get(SAMLBackend.KEY_IS_PASSIVE): kwargs["is_passive"] = "true" @@ -316,17 +316,21 @@ def authn_request(self, context, entity_id): "message": "AuthnRequest Failed", "error": f"Failed to construct the AuthnRequest for state: {e}", } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) raise SATOSAAuthenticationError(context.state, msg) from e - if self.sp.config.getattr('allow_unsolicited', 'sp') is False: + if self.sp.config.getattr("allow_unsolicited", "sp") is False: if req_id in self.outstanding_queries: msg = { "message": "AuthnRequest Failed", "error": f"Request with duplicate id {req_id}", } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) raise SATOSAAuthenticationError(context.state, msg) self.outstanding_queries[req_id] = req_id @@ -414,7 +418,9 @@ def authn_response(self, context, binding): "message": "Authentication failed", "error": "Received AuthN response without a SATOSA session cookie", } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) raise SATOSAMissingStateError(msg) @@ -424,7 +430,9 @@ def authn_response(self, context, binding): "message": "Authentication failed", "error": "SAML Response not found in context.request", } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) raise SATOSAAuthenticationError(context.state, msg) @@ -437,19 +445,23 @@ def authn_response(self, context, binding): "message": "Authentication failed", "error": f"Failed to parse Authn response: {e}", } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline, exc_info=True) logger.info(logline) raise SATOSAAuthenticationError(context.state, msg) from e - if self.sp.config.getattr('allow_unsolicited', 'sp') is False: + if self.sp.config.getattr("allow_unsolicited", "sp") is False: req_id = authn_response.in_response_to if req_id not in self.outstanding_queries: msg = { "message": "Authentication failed", "error": f"No corresponding request with id: {req_id}", } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) raise SATOSAAuthenticationError(context.state, msg) del self.outstanding_queries[req_id] @@ -460,7 +472,9 @@ def authn_response(self, context, binding): "message": "Authentication failed", "error": "Response state query param did not match relay state for request", } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) raise SATOSAAuthenticationError(context.state, msg) @@ -469,7 +483,9 @@ def authn_response(self, context, binding): issuer = authn_response.response.issuer.text.strip() context.state[Context.KEY_MEMORIZED_IDP] = issuer context.state.pop(Context.KEY_FORCE_AUTHN, None) - return self.auth_callback_func(context, self._translate_response(authn_response, context.state)) + return self.auth_callback_func( + context, self._translate_response(authn_response, context.state) + ) def disco_response(self, context): """ @@ -484,8 +500,10 @@ def disco_response(self, context): info = context.request state = context.state - if 'SATOSA_BASE' not in state: - raise SATOSAAuthenticationFlowError("Discovery response without AuthN request") + if "SATOSA_BASE" not in state: + raise SATOSAAuthenticationFlowError( + "Discovery response without AuthN request" + ) entity_id = info.get("entityID") msg = { @@ -519,9 +537,7 @@ def _translate_response(self, response, state): iter(response.authn_info()), [None, None, None] ) authenticating_authority = ( - authenticating_authorities[-1] - if authenticating_authorities - else None + authenticating_authorities[-1] if authenticating_authorities else None ) auth_info = AuthenticationInformation( auth_class_ref=authn_context_ref, @@ -536,7 +552,8 @@ def _translate_response(self, response, state): name_id_format = subject.format if subject else None attributes = self.converter.to_internal( - self.attribute_profile, response.ava, + self.attribute_profile, + response.ava, ) internal_resp = InternalData( @@ -553,10 +570,10 @@ def _translate_response(self, response, state): msg = { "message": "Attributes received by the backend", "issuer": issuer, - "attributes": " ".join(list(response.ava.keys())) + "attributes": " ".join(list(response.ava.keys())), } if name_id_format: - msg['name_id'] = name_id_format + msg["name_id"] = name_id_format logline = lu.LOG_FMT.format(id=lu.get_session_id(state), message=msg) logger.info(logline) @@ -571,7 +588,9 @@ def _metadata_endpoint(self, context): :param context: The current context :return: response with metadata """ - msg = "Sending metadata response for entityId = {}".format(self.sp.config.entityid) + msg = "Sending metadata response for entityId = {}".format( + self.sp.config.entityid + ) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.debug(logline) @@ -589,7 +608,12 @@ def register_endpoints(self): sp_endpoints = self.sp.config.getattr("endpoints", "sp") for endp, binding in sp_endpoints["assertion_consumer_service"]: parsed_endp = urlparse(endp) - url_map.append(("^%s$" % parsed_endp.path[1:], functools.partial(self.authn_response, binding=binding))) + url_map.append( + ( + "^%s$" % parsed_endp.path[1:], + functools.partial(self.authn_response, binding=binding), + ) + ) if binding == BINDING_HTTP_REDIRECT: msg = " ".join( [ @@ -608,18 +632,21 @@ def register_endpoints(self): if self.discosrv: for endp, binding in sp_endpoints["discovery_response"]: parsed_endp = urlparse(endp) - url_map.append( - ("^%s$" % parsed_endp.path[1:], self.disco_response)) + url_map.append(("^%s$" % parsed_endp.path[1:], self.disco_response)) if self.expose_entityid_endpoint(): - logger.debug("Exposing backend entity endpoint = {}".format(self.sp.config.entityid)) + logger.debug( + "Exposing backend entity endpoint = {}".format(self.sp.config.entityid) + ) parsed_entity_id = urlparse(self.sp.config.entityid) - url_map.append(("^{0}".format(parsed_entity_id.path[1:]), - self._metadata_endpoint)) + url_map.append( + ("^{0}".format(parsed_entity_id.path[1:]), self._metadata_endpoint) + ) if self.enable_metadata_reload(): url_map.append( - ("^%s/%s$" % (self.name, "reload-metadata"), self._reload_metadata)) + ("^%s/%s$" % (self.name, "reload-metadata"), self._reload_metadata) + ) return url_map @@ -629,7 +656,7 @@ def _reload_metadata(self, context): """ logger.debug("Reloading metadata") res = self.sp.reload_metadata( - copy.deepcopy(self.config[SAMLBackend.KEY_SP_CONFIG]['metadata']) + copy.deepcopy(self.config[SAMLBackend.KEY_SP_CONFIG]["metadata"]) ) message = "Metadata reload %s" % ("OK" if res else "failed") status = "200 OK" if res else "500 FAILED" @@ -644,7 +671,9 @@ def get_metadata_desc(self): idp_entities = self.sp.metadata.with_descriptor("idpsso") for entity_id, entity in idp_entities.items(): - description = MetadataDescription(urlsafe_b64encode(entity_id.encode("utf-8")).decode("utf-8")) + description = MetadataDescription( + urlsafe_b64encode(entity_id.encode("utf-8")).decode("utf-8") + ) # Add organization info try: @@ -655,8 +684,12 @@ def get_metadata_desc(self): organization = OrganizationDesc() for name_info in organization_info.get("organization_name", []): organization.add_name(name_info["text"], name_info["lang"]) - for display_name_info in organization_info.get("organization_display_name", []): - organization.add_display_name(display_name_info["text"], display_name_info["lang"]) + for display_name_info in organization_info.get( + "organization_display_name", [] + ): + organization.add_display_name( + display_name_info["text"], display_name_info["lang"] + ) for url_info in organization_info.get("organization_url", []): organization.add_url(url_info["text"], url_info["lang"]) description.organization = organization @@ -670,7 +703,7 @@ def get_metadata_desc(self): for person in contact_persons: person_desc = ContactPersonDesc() person_desc.contact_type = person.get("contact_type") - for address in person.get('email_address', []): + for address in person.get("email_address", []): person_desc.add_email_address(address["text"]) if "given_name" in person: person_desc.given_name = person["given_name"]["text"] @@ -680,7 +713,9 @@ def get_metadata_desc(self): description.add_contact_person(person_desc) # Add UI info - ui_info = self.sp.metadata.extension(entity_id, "idpsso_descriptor", "{}&UIInfo".format(UI_NAMESPACE)) + ui_info = self.sp.metadata.extension( + entity_id, "idpsso_descriptor", "{}&UIInfo".format(UI_NAMESPACE) + ) if ui_info: ui_info = ui_info[0] ui_info_desc = UIInfoDesc() @@ -689,14 +724,21 @@ def get_metadata_desc(self): for name in ui_info.get("display_name", []): ui_info_desc.add_display_name(name["text"], name["lang"]) for logo in ui_info.get("logo", []): - ui_info_desc.add_logo(logo["text"], logo["width"], logo["height"], logo.get("lang")) + ui_info_desc.add_logo( + logo["text"], logo["width"], logo["height"], logo.get("lang") + ) for keywords in ui_info.get("keywords", []): - ui_info_desc.add_keywords(keywords.get("text", []), keywords.get("lang")) + ui_info_desc.add_keywords( + keywords.get("text", []), keywords.get("lang") + ) for information_url in ui_info.get("information_url", []): - ui_info_desc.add_information_url(information_url.get("text"), information_url.get("lang")) + ui_info_desc.add_information_url( + information_url.get("text"), information_url.get("lang") + ) for privacy_statement_url in ui_info.get("privacy_statement_url", []): ui_info_desc.add_privacy_statement_url( - privacy_statement_url.get("text"), privacy_statement_url.get("lang") + privacy_statement_url.get("text"), + privacy_statement_url.get("lang"), ) description.ui_info = ui_info_desc @@ -708,26 +750,27 @@ class SAMLEIDASBackend(SAMLBackend, SAMLEIDASBaseModule): """ A saml2 eidas backend module (acting as a SP). """ - VALUE_ACR_CLASS_REF_DEFAULT = 'http://eidas.europa.eu/LoA/high' - VALUE_ACR_COMPARISON_DEFAULT = 'minimum' + + VALUE_ACR_CLASS_REF_DEFAULT = "http://eidas.europa.eu/LoA/high" + VALUE_ACR_COMPARISON_DEFAULT = "minimum" def init_config(self, config): config = super().init_config(config) spec_eidas_sp = { - 'acr_mapping': { + "acr_mapping": { "": { - 'class_ref': self.VALUE_ACR_CLASS_REF_DEFAULT, - 'comparison': self.VALUE_ACR_COMPARISON_DEFAULT, + "class_ref": self.VALUE_ACR_CLASS_REF_DEFAULT, + "comparison": self.VALUE_ACR_COMPARISON_DEFAULT, }, }, - 'sp_config.service.sp.authn_requests_signed': True, - 'sp_config.service.sp.want_response_signed': True, - 'sp_config.service.sp.allow_unsolicited': False, - 'sp_config.service.sp.force_authn': True, - 'sp_config.service.sp.hide_assertion_consumer_service': True, - 'sp_config.service.sp.sp_type': ['private', 'public'], - 'sp_config.service.sp.sp_type_in_metadata': [True, False], + "sp_config.service.sp.authn_requests_signed": True, + "sp_config.service.sp.want_response_signed": True, + "sp_config.service.sp.allow_unsolicited": False, + "sp_config.service.sp.force_authn": True, + "sp_config.service.sp.hide_assertion_consumer_service": True, + "sp_config.service.sp.sp_type": ["private", "public"], + "sp_config.service.sp.sp_type_in_metadata": [True, False], } return util.check_set_dict_defaults(config, spec_eidas_sp) diff --git a/src/satosa/base.py b/src/satosa/base.py index 1e17c8cbe..4862dfb9d 100644 --- a/src/satosa/base.py +++ b/src/satosa/base.py @@ -7,30 +7,29 @@ from saml2.s_utils import UnknownSystemEntity +import satosa.logging_util as lu from satosa import util -from satosa.response import BadRequest -from satosa.response import NotFound -from satosa.response import Redirect +from satosa.response import BadRequest, NotFound, Redirect + from .context import Context -from .exception import SATOSAAuthenticationError -from .exception import SATOSAAuthenticationFlowError -from .exception import SATOSABadRequestError -from .exception import SATOSAError -from .exception import SATOSAMissingStateError -from .exception import SATOSANoBoundEndpointError -from .exception import SATOSAUnknownError -from .exception import SATOSAStateError -from .plugin_loader import load_backends -from .plugin_loader import load_frontends -from .plugin_loader import load_request_microservices -from .plugin_loader import load_response_microservices +from .exception import ( + SATOSAAuthenticationError, + SATOSAAuthenticationFlowError, + SATOSABadRequestError, + SATOSAError, + SATOSAMissingStateError, + SATOSANoBoundEndpointError, + SATOSAStateError, + SATOSAUnknownError, +) +from .plugin_loader import ( + load_backends, + load_frontends, + load_request_microservices, + load_response_microservices, +) from .routing import ModuleRouter -from .state import State -from .state import cookie_to_state -from .state import state_to_cookie - -import satosa.logging_util as lu - +from .state import State, cookie_to_state, state_to_cookie logger = logging.getLogger(__name__) @@ -53,32 +52,51 @@ def __init__(self, config): self.config = config logger.info("Loading backend modules...") - backends = load_backends(self.config, self._auth_resp_callback_func, - self.config["INTERNAL_ATTRIBUTES"]) + backends = load_backends( + self.config, + self._auth_resp_callback_func, + self.config["INTERNAL_ATTRIBUTES"], + ) logger.info("Loading frontend modules...") - frontends = load_frontends(self.config, self._auth_req_callback_func, - self.config["INTERNAL_ATTRIBUTES"]) + frontends = load_frontends( + self.config, + self._auth_req_callback_func, + self.config["INTERNAL_ATTRIBUTES"], + ) self.response_micro_services = [] self.request_micro_services = [] logger.info("Loading micro services...") if "MICRO_SERVICES" in self.config: - self.request_micro_services.extend(load_request_microservices( - self.config.get("CUSTOM_PLUGIN_MODULE_PATHS"), - self.config["MICRO_SERVICES"], - self.config["INTERNAL_ATTRIBUTES"], - self.config["BASE"])) - self._link_micro_services(self.request_micro_services, self._auth_req_finish) + self.request_micro_services.extend( + load_request_microservices( + self.config.get("CUSTOM_PLUGIN_MODULE_PATHS"), + self.config["MICRO_SERVICES"], + self.config["INTERNAL_ATTRIBUTES"], + self.config["BASE"], + ) + ) + self._link_micro_services( + self.request_micro_services, self._auth_req_finish + ) self.response_micro_services.extend( - load_response_microservices(self.config.get("CUSTOM_PLUGIN_MODULE_PATHS"), - self.config["MICRO_SERVICES"], - self.config["INTERNAL_ATTRIBUTES"], - self.config["BASE"])) - self._link_micro_services(self.response_micro_services, self._auth_resp_finish) + load_response_microservices( + self.config.get("CUSTOM_PLUGIN_MODULE_PATHS"), + self.config["MICRO_SERVICES"], + self.config["INTERNAL_ATTRIBUTES"], + self.config["BASE"], + ) + ) + self._link_micro_services( + self.response_micro_services, self._auth_resp_finish + ) - self.module_router = ModuleRouter(frontends, backends, - self.request_micro_services + self.response_micro_services) + self.module_router = ModuleRouter( + frontends, + backends, + self.request_micro_services + self.response_micro_services, + ) def _link_micro_services(self, micro_services, finisher): if not micro_services: @@ -121,9 +139,13 @@ def _auth_req_finish(self, context, internal_request): return backend.start_auth(context, internal_request) def _auth_resp_finish(self, context, internal_response): - user_id_to_attr = self.config["INTERNAL_ATTRIBUTES"].get("user_id_to_attr", None) + user_id_to_attr = self.config["INTERNAL_ATTRIBUTES"].get( + "user_id_to_attr", None + ) if user_id_to_attr: - internal_response.attributes[user_id_to_attr] = [internal_response.subject_id] + internal_response.attributes[user_id_to_attr] = [ + internal_response.subject_id + ] # remove all session state unless CONTEXT_STATE_DELETE is False context.state.delete = self.config.get("CONTEXT_STATE_DELETE", True) @@ -152,14 +174,13 @@ def _auth_resp_callback_func(self, context, internal_response): # If configured construct the user id from attribute values. if "user_id_from_attrs" in self.config["INTERNAL_ATTRIBUTES"]: subject_id = [ - "".join(internal_response.attributes[attr]) for attr in - self.config["INTERNAL_ATTRIBUTES"]["user_id_from_attrs"] + "".join(internal_response.attributes[attr]) + for attr in self.config["INTERNAL_ATTRIBUTES"]["user_id_from_attrs"] ] internal_response.subject_id = "".join(subject_id) if self.response_micro_services: - return self.response_micro_services[0].process( - context, internal_response) + return self.response_micro_services[0].process(context, internal_response) return self._auth_resp_finish(context, internal_response) @@ -197,7 +218,9 @@ def _run_bound_endpoint(self, context, spec): msg = "ERROR_ID [{err_id}]\nSTATE:\n{state}".format( err_id=error.error_id, state=state ) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline, exc_info=True) return self._handle_satosa_authentication_error(error) @@ -219,7 +242,9 @@ def _load_state(self, context): finally: context.state = state msg = f"Loaded state {state} from cookie {context.cookie}" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) def _save_state(self, resp, context): @@ -247,8 +272,7 @@ def _save_state(self, resp, context): resp.headers = [ (name, value) for (name, value) in resp.headers - if name != "Set-Cookie" - or not value.startswith(f"{cookie_name}=") + if name != "Set-Cookie" or not value.startswith(f"{cookie_name}=") ] resp.headers.append(tuple(cookie.output().split(": ", 1))) @@ -274,7 +298,9 @@ def run(self, context): "error": str(e), "error_id": error_id, } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) generic_error_url = self.config.get("ERROR_URL") if generic_error_url: @@ -288,7 +314,9 @@ def run(self, context): "error": str(e), "error_id": error_id, } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) generic_error_url = self.config.get("ERROR_URL") if generic_error_url: @@ -302,7 +330,9 @@ def run(self, context): "error": str(e), "error_id": error_id, } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) generic_error_url = self.config.get("ERROR_URL") if generic_error_url: @@ -316,13 +346,17 @@ def run(self, context): "error": str(e), "error_id": error_id, } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) generic_error_url = self.config.get("ERROR_URL") if generic_error_url: redirect_url = f"{generic_error_url}?errorid={error_id}" return Redirect(generic_error_url) - return NotFound("The Service or Identity Provider you requested could not be found.") + return NotFound( + "The Service or Identity Provider you requested could not be found." + ) except SATOSAError as e: error_id = uuid.uuid4().urn msg = { @@ -330,7 +364,9 @@ def run(self, context): "error": str(e), "error_id": error_id, } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) generic_error_url = self.config.get("ERROR_URL") if generic_error_url: @@ -344,7 +380,9 @@ def run(self, context): "error": str(e), "error_id": error_id, } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) generic_error_url = self.config.get("ERROR_URL") if generic_error_url: @@ -358,7 +396,9 @@ def run(self, context): "error": str(e), "error_id": error_id, } - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) generic_error_url = self.config.get("ERROR_URL") if generic_error_url: @@ -369,16 +409,16 @@ def run(self, context): class SAMLBaseModule(object): - KEY_ENTITYID_ENDPOINT = 'entityid_endpoint' - KEY_ENABLE_METADATA_RELOAD = 'enable_metadata_reload' - KEY_ATTRIBUTE_PROFILE = 'attribute_profile' - KEY_ACR_MAPPING = 'acr_mapping' - VALUE_ATTRIBUTE_PROFILE_DEFAULT = 'saml' + KEY_ENTITYID_ENDPOINT = "entityid_endpoint" + KEY_ENABLE_METADATA_RELOAD = "enable_metadata_reload" + KEY_ATTRIBUTE_PROFILE = "attribute_profile" + KEY_ACR_MAPPING = "acr_mapping" + VALUE_ATTRIBUTE_PROFILE_DEFAULT = "saml" def init_config(self, config): self.attribute_profile = config.get( - self.KEY_ATTRIBUTE_PROFILE, - self.VALUE_ATTRIBUTE_PROFILE_DEFAULT) + self.KEY_ATTRIBUTE_PROFILE, self.VALUE_ATTRIBUTE_PROFILE_DEFAULT + ) self.acr_mapping = config.get(self.KEY_ACR_MAPPING) return config @@ -397,13 +437,13 @@ def enable_metadata_reload(self): class SAMLEIDASBaseModule(SAMLBaseModule): - VALUE_ATTRIBUTE_PROFILE_DEFAULT = 'eidas' + VALUE_ATTRIBUTE_PROFILE_DEFAULT = "eidas" def init_config(self, config): config = super().init_config(config) spec_eidas = { - 'entityid_endpoint': True, + "entityid_endpoint": True, } return util.check_set_dict_defaults(config, spec_eidas) diff --git a/src/satosa/context.py b/src/satosa/context.py index 2cd8243ac..2efb72a49 100644 --- a/src/satosa/context.py +++ b/src/satosa/context.py @@ -7,13 +7,14 @@ class Context(object): """ Holds methods for sharing proxy data through the current request """ - KEY_METADATA_STORE = 'metadata_store' - KEY_TARGET_ENTITYID = 'target_entity_id' - KEY_FORCE_AUTHN = 'force_authn' - KEY_MEMORIZED_IDP = 'memorized_idp' - KEY_REQUESTER_METADATA = 'requester_metadata' - KEY_AUTHN_CONTEXT_CLASS_REF = 'authn_context_class_ref' - KEY_TARGET_AUTHN_CONTEXT_CLASS_REF = 'target_authn_context_class_ref' + + KEY_METADATA_STORE = "metadata_store" + KEY_TARGET_ENTITYID = "target_entity_id" + KEY_FORCE_AUTHN = "force_authn" + KEY_MEMORIZED_IDP = "memorized_idp" + KEY_REQUESTER_METADATA = "requester_metadata" + KEY_AUTHN_CONTEXT_CLASS_REF = "authn_context_class_ref" + KEY_TARGET_AUTHN_CONTEXT_CLASS_REF = "target_authn_context_class_ref" def __init__(self): self._path = None @@ -66,7 +67,7 @@ def path(self, p): """ if not p: raise ValueError("path can't be set to None") - elif p.startswith('/'): + elif p.startswith("/"): raise ValueError("path can't start with '/'") self._path = p diff --git a/src/satosa/cookies.py b/src/satosa/cookies.py index 718fdb784..9b82cdef7 100644 --- a/src/satosa/cookies.py +++ b/src/satosa/cookies.py @@ -1,6 +1,5 @@ import http.cookies as _http_cookies - _http_cookies.Morsel._reserved["samesite"] = "SameSite" SimpleCookie = _http_cookies.SimpleCookie diff --git a/src/satosa/exception.py b/src/satosa/exception.py index 770d26283..424098b79 100644 --- a/src/satosa/exception.py +++ b/src/satosa/exception.py @@ -7,6 +7,7 @@ class SATOSAError(Exception): """ Base SATOSA exception """ + pass @@ -14,6 +15,7 @@ class SATOSAConfigurationError(SATOSAError): """ SATOSA configuration error """ + pass @@ -21,6 +23,7 @@ class SATOSAStateError(SATOSAError): """ SATOSA state error. """ + pass @@ -28,6 +31,7 @@ class SATOSACriticalError(SATOSAError): """ SATOSA critical error """ + pass @@ -35,6 +39,7 @@ class SATOSAUnknownError(SATOSAError): """ SATOSA unknown error """ + pass @@ -73,6 +78,7 @@ class SATOSABasicError(SATOSAError): """ eduTEAMS error """ + def __init__(self, error): self.error = error @@ -85,6 +91,7 @@ class SATOSAMissingStateError(SATOSABasicError): an authentication flow and while the session state cookie is expected for that step, it is not included in the request """ + pass @@ -96,6 +103,7 @@ class SATOSAAuthenticationFlowError(SATOSABasicError): be serviced because previous steps in the authentication flow for that session cannot be found """ + pass @@ -105,6 +113,7 @@ class SATOSABadRequestError(SATOSABasicError): This exception should be raised when we want to return an HTTP 400 Bad Request """ + pass @@ -112,6 +121,7 @@ class SATOSABadContextError(SATOSAError): """ Raise this exception if validating the Context and failing. """ + pass @@ -119,4 +129,5 @@ class SATOSANoBoundEndpointError(SATOSAError): """ Raised when a given url path is not bound to any endpoint function """ + pass diff --git a/src/satosa/frontends/openid_connect.py b/src/satosa/frontends/openid_connect.py index 88041b373..336b38de5 100644 --- a/src/satosa/frontends/openid_connect.py +++ b/src/satosa/frontends/openid_connect.py @@ -7,41 +7,42 @@ from collections import defaultdict from urllib.parse import urlencode, urlparse -from jwkest.jwk import rsa_load, RSAKey - +from jwkest.jwk import RSAKey, rsa_load from oic.oic import scope2claims -from oic.oic.message import AuthorizationRequest -from oic.oic.message import AuthorizationErrorResponse -from oic.oic.message import TokenErrorResponse -from oic.oic.message import UserInfoErrorResponse -from oic.oic.provider import RegistrationEndpoint -from oic.oic.provider import AuthorizationEndpoint -from oic.oic.provider import TokenEndpoint -from oic.oic.provider import UserinfoEndpoint - +from oic.oic.message import ( + AuthorizationErrorResponse, + AuthorizationRequest, + TokenErrorResponse, + UserInfoErrorResponse, +) +from oic.oic.provider import ( + AuthorizationEndpoint, + RegistrationEndpoint, + TokenEndpoint, + UserinfoEndpoint, +) from pyop.access_token import AccessToken from pyop.authz_state import AuthorizationState -from pyop.exceptions import InvalidAuthenticationRequest -from pyop.exceptions import InvalidClientRegistrationRequest -from pyop.exceptions import InvalidClientAuthentication -from pyop.exceptions import OAuthError -from pyop.exceptions import BearerTokenError -from pyop.exceptions import InvalidAccessToken +from pyop.exceptions import ( + BearerTokenError, + InvalidAccessToken, + InvalidAuthenticationRequest, + InvalidClientAuthentication, + InvalidClientRegistrationRequest, + OAuthError, +) from pyop.provider import Provider from pyop.storage import StorageBase from pyop.subject_identifier import HashBasedSubjectIdentifierFactory from pyop.userinfo import Userinfo from pyop.util import should_fragment_encode -from .base import FrontendModule -from ..response import BadRequest, Created -from ..response import SeeOther, Response -from ..response import Unauthorized -from ..util import rndstr - import satosa.logging_util as lu from satosa.internal import InternalData +from ..response import BadRequest, Created, Response, SeeOther, Unauthorized +from ..util import rndstr +from .base import FrontendModule logger = logging.getLogger(__name__) @@ -56,7 +57,9 @@ class OpenIDConnectFrontend(FrontendModule): A OpenID Connect frontend module """ - def __init__(self, auth_req_callback_func, internal_attributes, conf, base_url, name): + def __init__( + self, auth_req_callback_func, internal_attributes, conf, base_url, name + ): _validate_config(conf) super().__init__(auth_req_callback_func, internal_attributes, base_url, name) @@ -135,15 +138,18 @@ def handle_authn_response(self, context, internal_resp): auth_resp = self.provider.authorize( auth_req, internal_resp.subject_id, - extra_id_token_claims=lambda user_id, client_id: - self._get_extra_id_token_claims(user_id, client_id), + extra_id_token_claims=lambda user_id, client_id: self._get_extra_id_token_claims( + user_id, client_id + ), ) if self.stateless: del self.user_db[internal_resp.subject_id] del context.state[self.name] - http_response = auth_resp.request(auth_req["redirect_uri"], should_fragment_encode(auth_req)) + http_response = auth_resp.request( + auth_req["redirect_uri"], should_fragment_encode(auth_req) + ) return SeeOther(http_response) def handle_backend_error(self, exception): @@ -154,17 +160,24 @@ def handle_backend_error(self, exception): """ auth_req = self._get_authn_request_from_state(exception.state) # If the client sent us a state parameter, we should reflect it back according to the spec - if 'state' in auth_req: - error_resp = AuthorizationErrorResponse(error="access_denied", - error_description=exception.message, - state=auth_req['state']) + if "state" in auth_req: + error_resp = AuthorizationErrorResponse( + error="access_denied", + error_description=exception.message, + state=auth_req["state"], + ) else: - error_resp = AuthorizationErrorResponse(error="access_denied", - error_description=exception.message) + error_resp = AuthorizationErrorResponse( + error="access_denied", error_description=exception.message + ) msg = exception.message logline = lu.LOG_FMT.format(id=lu.get_session_id(exception.state), message=msg) logger.debug(logline) - return SeeOther(error_resp.request(auth_req["redirect_uri"], should_fragment_encode(auth_req))) + return SeeOther( + error_resp.request( + auth_req["redirect_uri"], should_fragment_encode(auth_req) + ) + ) def register_endpoints(self, backend_names): """ @@ -194,8 +207,12 @@ def register_endpoints(self, backend_names): if backend_name: # if there is only one backend, include its name in the path so the default routing can work - auth_endpoint = "{}/{}/{}/{}".format(self.base_url, backend_name, self.name, AuthorizationEndpoint.url) - self.provider.configuration_information["authorization_endpoint"] = auth_endpoint + auth_endpoint = "{}/{}/{}/{}".format( + self.base_url, backend_name, self.name, AuthorizationEndpoint.url + ) + self.provider.configuration_information[ + "authorization_endpoint" + ] = auth_endpoint auth_path = urlparse(auth_endpoint).path.lstrip("/") else: auth_path = "{}/{}".format(self.name, AuthorizationEndpoint.url) @@ -203,20 +220,25 @@ def register_endpoints(self, backend_names): authentication = ("^{}$".format(auth_path), self.handle_authn_request) url_map = [provider_config, jwks_uri, authentication] - if any("code" in v for v in self.provider.configuration_information["response_types_supported"]): + if any( + "code" in v + for v in self.provider.configuration_information["response_types_supported"] + ): self.provider.configuration_information["token_endpoint"] = "{}/{}".format( self.endpoint_baseurl, TokenEndpoint.url ) token_endpoint = ( - "^{}/{}".format(self.name, TokenEndpoint.url), self.token_endpoint + "^{}/{}".format(self.name, TokenEndpoint.url), + self.token_endpoint, ) url_map.append(token_endpoint) - self.provider.configuration_information["userinfo_endpoint"] = ( - "{}/{}".format(self.endpoint_baseurl, UserinfoEndpoint.url) - ) + self.provider.configuration_information[ + "userinfo_endpoint" + ] = "{}/{}".format(self.endpoint_baseurl, UserinfoEndpoint.url) userinfo_endpoint = ( - "^{}/{}".format(self.name, UserinfoEndpoint.url), self.userinfo_endpoint + "^{}/{}".format(self.name, UserinfoEndpoint.url), + self.userinfo_endpoint, ) url_map.append(userinfo_endpoint) @@ -250,7 +272,9 @@ def client_registration(self, context): :return: HTTP response to the client """ try: - resp = self.provider.handle_client_registration_request(json.dumps(context.request)) + resp = self.provider.handle_client_registration_request( + json.dumps(context.request) + ) return Created(resp.to_json(), content="application/json") except InvalidClientRegistrationRequest as e: return BadRequest(e.to_json(), content="application/json") @@ -264,7 +288,9 @@ def provider_config(self, context): :param context: the current context :return: HTTP response to the client """ - return Response(self.provider.provider_configuration.to_json(), content="application/json") + return Response( + self.provider.provider_configuration.to_json(), content="application/json" + ) def _get_approved_attributes(self, provider_supported_claims, authn_req): requested_claims = list( @@ -296,7 +322,9 @@ def _handle_authn_request(self, context): authn_req = self.provider.parse_authentication_request(request) except InvalidAuthenticationRequest as e: msg = "Error in authn req: {}".format(str(e)) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) error_url = e.to_error_url() @@ -321,8 +349,11 @@ def _handle_authn_request(self, context): ) internal_req.attributes = self.converter.to_internal_filter( - "openid", self._get_approved_attributes(self.provider.configuration_information["claims_supported"], - authn_req)) + "openid", + self._get_approved_attributes( + self.provider.configuration_information["claims_supported"], authn_req + ), + ) return internal_req def handle_authn_request(self, context): @@ -364,19 +395,29 @@ def token_endpoint(self, context): response = self.provider.handle_token_request( urlencode(context.request), headers, - lambda user_id, client_id: self._get_extra_id_token_claims(user_id, client_id)) + lambda user_id, client_id: self._get_extra_id_token_claims( + user_id, client_id + ), + ) return Response(response.to_json(), content="application/json") except InvalidClientAuthentication as e: logline = "invalid client authentication at token endpoint" logger.debug(logline, exc_info=True) - error_resp = TokenErrorResponse(error='invalid_client', error_description=str(e)) - response = Unauthorized(error_resp.to_json(), headers=[("WWW-Authenticate", "Basic")], - content="application/json") + error_resp = TokenErrorResponse( + error="invalid_client", error_description=str(e) + ) + response = Unauthorized( + error_resp.to_json(), + headers=[("WWW-Authenticate", "Basic")], + content="application/json", + ) return response except OAuthError as e: logline = "invalid request: {}".format(str(e)) logger.debug(logline, exc_info=True) - error_resp = TokenErrorResponse(error=e.oauth_error, error_description=str(e)) + error_resp = TokenErrorResponse( + error=e.oauth_error, error_description=str(e) + ) return BadRequest(error_resp.to_json(), content="application/json") def userinfo_endpoint(self, context): @@ -389,9 +430,14 @@ def userinfo_endpoint(self, context): ) return Response(response.to_json(), content="application/json") except (BearerTokenError, InvalidAccessToken) as e: - error_resp = UserInfoErrorResponse(error='invalid_token', error_description=str(e)) - response = Unauthorized(error_resp.to_json(), headers=[("WWW-Authenticate", AccessToken.BEARER_TOKEN_TYPE)], - content="application/json") + error_resp = UserInfoErrorResponse( + error="invalid_token", error_description=str(e) + ) + response = Unauthorized( + error_resp.to_json(), + headers=[("WWW-Authenticate", AccessToken.BEARER_TOKEN_TYPE)], + content="application/json", + ) return response @@ -406,11 +452,16 @@ def _validate_config(config): for k in {"signing_key_path", "provider"}: if k not in config: - raise ValueError("Missing configuration parameter '{}' for OpenID Connect frontend.".format(k)) + raise ValueError( + "Missing configuration parameter '{}' for OpenID Connect frontend.".format( + k + ) + ) if "signing_key_id" in config and type(config["signing_key_id"]) is not str: raise ValueError( - "The configuration parameter 'signing_key_id' is not defined as a string for OpenID Connect frontend.") + "The configuration parameter 'signing_key_id' is not defined as a string for OpenID Connect frontend." + ) def _create_provider( @@ -422,13 +473,19 @@ def _create_provider( user_db, cdb, ): - response_types_supported = provider_config.get("response_types_supported", ["id_token"]) - subject_types_supported = provider_config.get("subject_types_supported", ["pairwise"]) + response_types_supported = provider_config.get( + "response_types_supported", ["id_token"] + ) + subject_types_supported = provider_config.get( + "subject_types_supported", ["pairwise"] + ) scopes_supported = provider_config.get("scopes_supported", ["openid"]) extra_scopes = provider_config.get("extra_scopes") capabilities = { "issuer": provider_config["issuer"], - "authorization_endpoint": "{}/{}".format(endpoint_baseurl, AuthorizationEndpoint.url), + "authorization_endpoint": "{}/{}".format( + endpoint_baseurl, AuthorizationEndpoint.url + ), "jwks_uri": "{}/jwks".format(endpoint_baseurl), "response_types_supported": response_types_supported, "id_token_signing_alg_values_supported": [signing_key.alg], @@ -443,10 +500,10 @@ def _create_provider( ], "request_parameter_supported": False, "request_uri_parameter_supported": False, - "scopes_supported": scopes_supported + "scopes_supported": scopes_supported, } - if 'code' in response_types_supported: + if "code" in response_types_supported: capabilities["token_endpoint"] = "{}/{}".format( endpoint_baseurl, TokenEndpoint.url ) @@ -566,7 +623,6 @@ def combine_join_by_space(values): def combine_claim_values(claim_items): claims = ( - (name, combine_values_by_claim[name](values)) - for name, values in claim_items + (name, combine_values_by_claim[name](values)) for name, values in claim_items ) return claims diff --git a/src/satosa/frontends/ping.py b/src/satosa/frontends/ping.py index 27fec279c..28ec9fac7 100644 --- a/src/satosa/frontends/ping.py +++ b/src/satosa/frontends/ping.py @@ -4,7 +4,6 @@ from satosa.frontends.base import FrontendModule from satosa.response import Response - logger = logging.getLogger(__name__) @@ -14,7 +13,9 @@ class PingFrontend(FrontendModule): 200 OK, intended to be used as a simple heartbeat monitor. """ - def __init__(self, auth_req_callback_func, internal_attributes, config, base_url, name): + def __init__( + self, auth_req_callback_func, internal_attributes, config, base_url, name + ): super().__init__(auth_req_callback_func, internal_attributes, base_url, name) self.config = config diff --git a/src/satosa/frontends/saml2.py b/src/satosa/frontends/saml2.py index cecd533db..78e47c45a 100644 --- a/src/satosa/frontends/saml2.py +++ b/src/satosa/frontends/saml2.py @@ -6,41 +6,34 @@ import json import logging import re -from base64 import urlsafe_b64decode -from base64 import urlsafe_b64encode -from urllib.parse import quote -from urllib.parse import quote_plus -from urllib.parse import unquote -from urllib.parse import unquote_plus -from urllib.parse import urlparse +from base64 import urlsafe_b64decode, urlsafe_b64encode from http.cookies import SimpleCookie +from urllib.parse import quote, quote_plus, unquote, unquote_plus, urlparse from saml2 import SAMLError, xmldsig from saml2.config import IdPConfig from saml2.extension.mdui import NAMESPACE as UI_NAMESPACE from saml2.metadata import create_metadata_string -from saml2.saml import NameID -from saml2.saml import NAMEID_FORMAT_TRANSIENT -from saml2.saml import NAMEID_FORMAT_PERSISTENT -from saml2.saml import NAMEID_FORMAT_EMAILADDRESS -from saml2.saml import NAMEID_FORMAT_UNSPECIFIED +from saml2.saml import ( + NAMEID_FORMAT_EMAILADDRESS, + NAMEID_FORMAT_PERSISTENT, + NAMEID_FORMAT_TRANSIENT, + NAMEID_FORMAT_UNSPECIFIED, + NameID, +) from saml2.samlp import name_id_policy_from_string from saml2.server import Server +import satosa.logging_util as lu +import satosa.util as util from satosa.base import SAMLBaseModule from satosa.context import Context -from .base import FrontendModule -from ..response import Response -from ..response import ServiceError -from ..saml_util import make_saml_response -from satosa.exception import SATOSAError -from satosa.exception import SATOSABadRequestError -from satosa.exception import SATOSAMissingStateError -import satosa.util as util - -import satosa.logging_util as lu +from satosa.exception import SATOSABadRequestError, SATOSAError, SATOSAMissingStateError from satosa.internal import InternalData +from ..response import Response, ServiceError +from ..saml_util import make_saml_response +from .base import FrontendModule logger = logging.getLogger(__name__) @@ -63,19 +56,21 @@ class SAMLFrontend(FrontendModule, SAMLBaseModule): """ A pysaml2 frontend module """ - KEY_CUSTOM_ATTR_RELEASE = 'custom_attribute_release' - KEY_ENDPOINTS = 'endpoints' - KEY_IDP_CONFIG = 'idp_config' - def __init__(self, auth_req_callback_func, internal_attributes, config, base_url, name): + KEY_CUSTOM_ATTR_RELEASE = "custom_attribute_release" + KEY_ENDPOINTS = "endpoints" + KEY_IDP_CONFIG = "idp_config" + + def __init__( + self, auth_req_callback_func, internal_attributes, config, base_url, name + ): self._validate_config(config) super().__init__(auth_req_callback_func, internal_attributes, base_url, name) self.config = self.init_config(config) self.endpoints = config[self.KEY_ENDPOINTS] - self.custom_attribute_release = config.get( - self.KEY_CUSTOM_ATTR_RELEASE) + self.custom_attribute_release = config.get(self.KEY_CUSTOM_ATTR_RELEASE) self.idp = None def handle_authn_response(self, context, internal_response): @@ -119,10 +114,12 @@ def register_endpoints(self, backend_names): if self.enable_metadata_reload(): url_map.append( - ("^%s/%s$" % (self.name, "reload-metadata"), self._reload_metadata)) + ("^%s/%s$" % (self.name, "reload-metadata"), self._reload_metadata) + ) self.idp_config = self._build_idp_config_endpoints( - self.config[self.KEY_IDP_CONFIG], backend_names) + self.config[self.KEY_IDP_CONFIG], backend_names + ) # Create the idp idp_config = IdPConfig().load(copy.deepcopy(self.idp_config)) self.idp = Server(config=idp_config) @@ -143,7 +140,9 @@ def _create_state_data(self, context, resp_args, relay_state): :return: A state as a dict """ if "name_id_policy" in resp_args and resp_args["name_id_policy"] is not None: - resp_args["name_id_policy"] = resp_args["name_id_policy"].to_string().decode("utf-8") + resp_args["name_id_policy"] = ( + resp_args["name_id_policy"].to_string().decode("utf-8") + ) return {"resp_args": resp_args, "relay_state": relay_state} def load_state(self, state): @@ -168,12 +167,13 @@ def load_state(self, state): of their browser and resend the authentication response, but without the SATOSA session cookie """ - error = "Received AuthN response without a SATOSA session cookie" + error = "Received AuthN response without a SATOSA session cookie" raise SATOSAMissingStateError(error) if isinstance(state_data["resp_args"]["name_id_policy"], str): state_data["resp_args"]["name_id_policy"] = name_id_policy_from_string( - state_data["resp_args"]["name_id_policy"]) + state_data["resp_args"]["name_id_policy"] + ) return state_data def _validate_config(self, config): @@ -210,13 +210,17 @@ def _handle_authn_request(self, context, binding_in, idp): """ try: - req_info = idp.parse_authn_request(context.request["SAMLRequest"], binding_in) + req_info = idp.parse_authn_request( + context.request["SAMLRequest"], binding_in + ) except KeyError: """ HTTP clients that call the SSO endpoint without sending SAML AuthN request will receive a "400 Bad Request" response """ - raise SATOSABadRequestError("HTTP request does not include a SAML AuthN request") + raise SATOSABadRequestError( + "HTTP request does not include a SAML AuthN request" + ) authn_req = req_info.message msg = "{}".format(authn_req) @@ -230,13 +234,16 @@ def _handle_authn_request(self, context, binding_in, idp): resp_args = idp.response_args(authn_req) except SAMLError as e: msg = "Could not find necessary info about entity: {}".format(e) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) return ServiceError("Incorrect request from requester: %s" % e) requester = resp_args["sp_entity_id"] - context.state[self.name] = self._create_state_data(context, idp.response_args(authn_req), - context.request.get("RelayState")) + context.state[self.name] = self._create_state_data( + context, idp.response_args(authn_req), context.request.get("RelayState") + ) subject = authn_req.subject name_id_value = subject.name_id.text if subject else None @@ -279,7 +286,7 @@ def _handle_authn_request(self, context, binding_in, idp): ) authn_context_class_ref_nodes = getattr( - authn_req.requested_authn_context, 'authn_context_class_ref', [] + authn_req.requested_authn_context, "authn_context_class_ref", [] ) authn_context = [ref.text for ref in authn_context_class_ref_nodes] context.decorate(Context.KEY_AUTHN_CONTEXT_CLASS_REF, authn_context) @@ -309,15 +316,24 @@ def _get_approved_attributes(self, idp, idp_policy, sp_entity_id, state): for aconv in attrconvs: if aconv.name_format == name_format: all_attributes = {v: None for v in aconv._fro.values()} - attribute_filter = list(idp_policy.restrict(all_attributes, sp_entity_id).keys()) + attribute_filter = list( + idp_policy.restrict(all_attributes, sp_entity_id).keys() + ) break - attribute_filter = self.converter.to_internal_filter(self.attribute_profile, attribute_filter) + attribute_filter = self.converter.to_internal_filter( + self.attribute_profile, attribute_filter + ) msg = "Filter: {}".format(attribute_filter) logline = lu.LOG_FMT.format(id=lu.get_session_id(state), message=msg) logger.debug(logline) return attribute_filter - def _filter_attributes(self, idp, internal_response, context,): + def _filter_attributes( + self, + idp, + internal_response, + context, + ): idp_policy = idp.config.getattr("policy", "idp") attributes = {} if idp_policy: @@ -350,14 +366,17 @@ def _handle_authn_response(self, context, internal_response, idp): resp_args = request_state["resp_args"] sp_entity_id = resp_args["sp_entity_id"] internal_response.attributes = self._filter_attributes( - idp, internal_response, context) + idp, internal_response, context + ) ava = self.converter.from_internal( - self.attribute_profile, internal_response.attributes) + self.attribute_profile, internal_response.attributes + ) auth_info = {} if self.acr_mapping: auth_info["class_ref"] = self.acr_mapping.get( - internal_response.auth_info.issuer, self.acr_mapping[""]) + internal_response.auth_info.issuer, self.acr_mapping[""] + ) else: auth_info["class_ref"] = internal_response.auth_info.auth_class_ref @@ -367,7 +386,8 @@ def _handle_authn_response(self, context, internal_response, idp): custom_release = util.get_dict_defaults( self.custom_attribute_release, internal_response.auth_info.issuer, - sp_entity_id) + sp_entity_id, + ) attributes_to_remove = custom_release.get("exclude", []) for k in attributes_to_remove: ava.pop(k, None) @@ -382,31 +402,37 @@ def _handle_authn_response(self, context, internal_response, idp): # Instead pass None as the name name_id to the IdP server # instance and it will use its configured policy to construct # a , with the default to create a transient . - name_id = None if not nameid_value else NameID( - text=nameid_value, - format=nameid_format, - sp_name_qualifier=None, - name_qualifier=None, + name_id = ( + None + if not nameid_value + else NameID( + text=nameid_value, + format=nameid_format, + sp_name_qualifier=None, + name_qualifier=None, + ) ) msg = "returning attributes {}".format(json.dumps(ava)) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.debug(logline) - idp_conf = self.idp_config.get('service', {}).get('idp', {}) - policies = idp_conf.get('policy', {}) - sp_policy = policies.get('default', {}) + idp_conf = self.idp_config.get("service", {}).get("idp", {}) + policies = idp_conf.get("policy", {}) + sp_policy = policies.get("default", {}) sp_policy.update(policies.get(sp_entity_id, {})) - sign_assertion = sp_policy.get('sign_assertion', False) - sign_response = sp_policy.get('sign_response', True) - encrypt_assertion = sp_policy.get('encrypt_assertion', False) - encrypted_advice_attributes = sp_policy.get('encrypted_advice_attributes', False) + sign_assertion = sp_policy.get("sign_assertion", False) + sign_response = sp_policy.get("sign_response", True) + encrypt_assertion = sp_policy.get("encrypt_assertion", False) + encrypted_advice_attributes = sp_policy.get( + "encrypted_advice_attributes", False + ) - signing_algorithm = idp_conf.get('signing_algorithm') - digest_algorithm = idp_conf.get('digest_algorithm') - sign_alg_attr = sp_policy.get('sign_alg', 'SIG_RSA_SHA256') - digest_alg_attr = sp_policy.get('digest_alg', 'DIGEST_SHA256') + signing_algorithm = idp_conf.get("signing_algorithm") + digest_algorithm = idp_conf.get("digest_algorithm") + sign_alg_attr = sp_policy.get("sign_alg", "SIG_RSA_SHA256") + digest_alg_attr = sp_policy.get("digest_alg", "DIGEST_SHA256") # Construct arguments for method create_authn_response # on IdP Server instance @@ -414,40 +440,44 @@ def _handle_authn_response(self, context, internal_response, idp): # Add the SP details **resp_args, # AuthnResponse data - 'identity': ava, - 'name_id': name_id, - 'authn': auth_info, - 'sign_response': sign_response, - 'sign_assertion': sign_assertion, - 'encrypt_assertion': encrypt_assertion, - 'encrypted_advice_attributes': encrypted_advice_attributes, + "identity": ava, + "name_id": name_id, + "authn": auth_info, + "sign_response": sign_response, + "sign_assertion": sign_assertion, + "encrypt_assertion": encrypt_assertion, + "encrypted_advice_attributes": encrypted_advice_attributes, } - args['sign_alg'] = signing_algorithm - if not args['sign_alg']: + args["sign_alg"] = signing_algorithm + if not args["sign_alg"]: try: - args['sign_alg'] = getattr(xmldsig, sign_alg_attr) + args["sign_alg"] = getattr(xmldsig, sign_alg_attr) except AttributeError as e: msg = "Unsupported sign algorithm {}".format(sign_alg_attr) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) raise Exception(msg) from e - msg = "signing with algorithm {}".format(args['sign_alg']) + msg = "signing with algorithm {}".format(args["sign_alg"]) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.debug(logline) - args['digest_alg'] = digest_algorithm - if not args['digest_alg']: + args["digest_alg"] = digest_algorithm + if not args["digest_alg"]: try: - args['digest_alg'] = getattr(xmldsig, digest_alg_attr) + args["digest_alg"] = getattr(xmldsig, digest_alg_attr) except AttributeError as e: msg = "Unsupported digest algorithm {}".format(digest_alg_attr) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) raise Exception(msg) from e - msg = "using digest algorithm {}".format(args['digest_alg']) + msg = "using digest algorithm {}".format(args["digest_alg"]) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.debug(logline) @@ -458,16 +488,22 @@ def _handle_authn_response(self, context, internal_response, idp): "under the service/idp configuration path " "(not under policy/default)." ) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.warning(msg) resp = idp.create_authn_response(**args) http_args = idp.apply_binding( - resp_args["binding"], str(resp), resp_args["destination"], - request_state["relay_state"], response=True) + resp_args["binding"], + str(resp), + resp_args["destination"], + request_state["relay_state"], + response=True, + ) # Set the common domain cookie _saml_idp if so configured. - if self.config.get('common_domain_cookie'): + if self.config.get("common_domain_cookie"): self._set_common_domain_cookie(internal_response, http_args, context) del context.state[self.name] @@ -479,10 +515,10 @@ def _handle_authn_response(self, context, internal_response, idp): "signed response": sign_response, "signed assertion": sign_assertion, "encrypted": encrypt_assertion, - "attributes": " ".join(list(ava.keys())) + "attributes": " ".join(list(ava.keys())), } if nameid_format: - msg['name_id'] = nameid_format + msg["name_id"] = nameid_format logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.info(logline) @@ -503,11 +539,18 @@ def _handle_backend_error(self, exception, idp): loaded_state = self.load_state(exception.state) relay_state = loaded_state["relay_state"] resp_args = loaded_state["resp_args"] - error_resp = idp.create_error_response(resp_args["in_response_to"], - resp_args["destination"], - Exception(exception.message)) - http_args = idp.apply_binding(resp_args["binding"], str(error_resp), resp_args["destination"], relay_state, - response=True) + error_resp = idp.create_error_response( + resp_args["in_response_to"], + resp_args["destination"], + Exception(exception.message), + ) + http_args = idp.apply_binding( + resp_args["binding"], + str(error_resp), + resp_args["destination"], + relay_state, + response=True, + ) msg = "HTTPSards: {}".format(http_args) logline = lu.LOG_FMT.format(id=lu.get_session_id(exception.state), message=msg) @@ -523,7 +566,9 @@ def _metadata_endpoint(self, context): :param context: The current context :return: response with metadata """ - msg = "Sending metadata response for entityId = {}".format(self.idp.config.entityid) + msg = "Sending metadata response for entityId = {}".format( + self.idp.config.entityid + ) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.debug(logline) metadata_string = create_metadata_string( @@ -537,7 +582,7 @@ def _reload_metadata(self, context): """ logger.debug("Reloading metadata") res = self.idp.reload_metadata( - copy.deepcopy(self.config[SAMLFrontend.KEY_IDP_CONFIG]['metadata']) + copy.deepcopy(self.config[SAMLFrontend.KEY_IDP_CONFIG]["metadata"]) ) message = "Metadata reload %s" % ("OK" if res else "failed") status = "200 OK" if res else "500 FAILED" @@ -560,34 +605,50 @@ def _register_endpoints(self, providers): valid_providers = "{}|^{}".format(valid_providers, provider) valid_providers = valid_providers.lstrip("|") parsed_endp = urlparse(endp) - url_map.append(("(%s)/%s$" % (valid_providers, parsed_endp.path), - functools.partial(self.handle_authn_request, binding_in=binding))) + url_map.append( + ( + "(%s)/%s$" % (valid_providers, parsed_endp.path), + functools.partial( + self.handle_authn_request, binding_in=binding + ), + ) + ) if self.expose_entityid_endpoint(): - logger.debug("Exposing frontend entity endpoint = {}".format(self.idp.config.entityid)) + logger.debug( + "Exposing frontend entity endpoint = {}".format( + self.idp.config.entityid + ) + ) parsed_entity_id = urlparse(self.idp.config.entityid) - url_map.append(("^{0}".format(parsed_entity_id.path[1:]), - self._metadata_endpoint)) + url_map.append( + ("^{0}".format(parsed_entity_id.path[1:]), self._metadata_endpoint) + ) return url_map def _set_common_domain_cookie(self, internal_response, http_args, context): - """ - """ + """ """ # Find any existing common domain cookie and deconsruct it to # obtain the list of IdPs. cookie = SimpleCookie(context.cookie) - if '_saml_idp' in cookie: - common_domain_cookie = cookie['_saml_idp'] + if "_saml_idp" in cookie: + common_domain_cookie = cookie["_saml_idp"] msg = "Found existing common domain cookie {}".format(common_domain_cookie) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) space_separated_b64_idp_string = unquote(common_domain_cookie.value) b64_idp_list = space_separated_b64_idp_string.split() - idp_list = [urlsafe_b64decode(b64_idp).decode('utf-8') for b64_idp in b64_idp_list] + idp_list = [ + urlsafe_b64decode(b64_idp).decode("utf-8") for b64_idp in b64_idp_list + ] else: msg = "No existing common domain cookie found" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) idp_list = [] @@ -611,33 +672,37 @@ def _set_common_domain_cookie(self, internal_response, http_args, context): logger.debug(logline) # Construct the cookie. - b64_idp_list = [urlsafe_b64encode(idp.encode()).decode("utf-8") for idp in idp_list] + b64_idp_list = [ + urlsafe_b64encode(idp.encode()).decode("utf-8") for idp in idp_list + ] space_separated_b64_idp_string = " ".join(b64_idp_list) - url_encoded_space_separated_b64_idp_string = quote(space_separated_b64_idp_string) + url_encoded_space_separated_b64_idp_string = quote( + space_separated_b64_idp_string + ) cookie = SimpleCookie() - cookie['_saml_idp'] = url_encoded_space_separated_b64_idp_string - cookie['_saml_idp']['path'] = '/' + cookie["_saml_idp"] = url_encoded_space_separated_b64_idp_string + cookie["_saml_idp"]["path"] = "/" # Use the domain from configuration if present else use the domain # from the base URL for the front end. domain = urlparse(self.base_url).netloc - if isinstance(self.config['common_domain_cookie'], dict): - if 'domain' in self.config['common_domain_cookie']: - domain = self.config['common_domain_cookie']['domain'] + if isinstance(self.config["common_domain_cookie"], dict): + if "domain" in self.config["common_domain_cookie"]: + domain = self.config["common_domain_cookie"]["domain"] # Ensure that the domain begins with a '.' - if domain[0] != '.': - domain = '.' + domain + if domain[0] != ".": + domain = "." + domain - cookie['_saml_idp']['domain'] = domain - cookie['_saml_idp']['secure'] = True + cookie["_saml_idp"]["domain"] = domain + cookie["_saml_idp"]["secure"] = True # Set the cookie. msg = "Setting common domain cookie with {}".format(cookie.output()) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.debug(logline) - http_args['headers'].append(tuple(cookie.output().split(": ", 1))) + http_args["headers"].append(tuple(cookie.output().split(": ", 1))) def _build_idp_config_endpoints(self, config, providers): """ @@ -657,14 +722,17 @@ def _build_idp_config_endpoints(self, config, providers): for func, endpoint in self.endpoints[endp_category].items(): for provider in providers: _endpoint = "{base}/{provider}/{endpoint}".format( - base=self.base_url, provider=provider, endpoint=endpoint) + base=self.base_url, provider=provider, endpoint=endpoint + ) idp_endpoints.append((_endpoint, func)) config["service"]["idp"]["endpoints"][endp_category] = idp_endpoints return config def _get_sp_display_name(self, idp, entity_id): - extensions = idp.metadata.extension(entity_id, "spsso_descriptor", "{}&UIInfo".format(UI_NAMESPACE)) + extensions = idp.metadata.extension( + entity_id, "spsso_descriptor", "{}&UIInfo".format(UI_NAMESPACE) + ) if not extensions: return None @@ -700,8 +768,11 @@ def _load_endpoints_to_config(self, provider, target_entity_id, config=None): idp_endpoints = [] for binding, path in endpoint.items(): url = "{base}/{provider}/{target_id}/{path}".format( - base=self.base_url, provider=provider, - target_id=target_entity_id, path=path) + base=self.base_url, + provider=provider, + target_id=target_entity_id, + path=path, + ) idp_endpoints.append((url, binding)) idp_conf["service"]["idp"]["endpoints"][service] = idp_endpoints return idp_conf @@ -718,7 +789,9 @@ def _load_idp_dynamic_endpoints(self, context): :return: An idp server """ target_entity_id = context.target_entity_id_from_path() - idp_conf_file = self._load_endpoints_to_config(context.target_backend, target_entity_id) + idp_conf_file = self._load_endpoints_to_config( + context.target_backend, target_entity_id + ) idp_config = IdPConfig().load(idp_conf_file) return Server(config=idp_config) @@ -734,7 +807,9 @@ def _load_idp_dynamic_entity_id(self, state): """ # Change the idp entity id dynamically idp_config_file = copy.deepcopy(self.idp_config) - idp_config_file["entityid"] = "{}/{}".format(self.idp_config["entityid"], state[self.name]["target_entity_id"]) + idp_config_file["entityid"] = "{}/{}".format( + self.idp_config["entityid"], state[self.name]["target_entity_id"] + ) idp_config = IdPConfig().load(idp_config_file) return Server(config=idp_config) @@ -807,7 +882,9 @@ def _register_endpoints(self, providers): url_map.append( ( r"(^{})/\S+/{}".format(valid_providers, parsed_endp.path), - functools.partial(self.handle_authn_request, binding_in=binding) + functools.partial( + self.handle_authn_request, binding_in=binding + ), ) ) @@ -819,19 +896,24 @@ class SAMLVirtualCoFrontend(SAMLFrontend): Frontend module that exposes multiple virtual SAML identity providers, each representing a collaborative organization or CO. """ - KEY_CO = 'collaborative_organizations' - KEY_CO_NAME = 'co_name' - KEY_CO_ENTITY_ID = 'co_entity_id' - KEY_CO_ATTRIBUTES = 'co_static_saml_attributes' - KEY_CO_ATTRIBUTE_SCOPE = 'co_attribute_scope' - KEY_CONTACT_PERSON = 'contact_person' - KEY_ENCODEABLE_NAME = 'encodeable_name' - KEY_ORGANIZATION = 'organization' - KEY_ORGANIZATION_KEYS = ['display_name', 'name', 'url'] - - def __init__(self, auth_req_callback_func, internal_attributes, config, base_url, name): + + KEY_CO = "collaborative_organizations" + KEY_CO_NAME = "co_name" + KEY_CO_ENTITY_ID = "co_entity_id" + KEY_CO_ATTRIBUTES = "co_static_saml_attributes" + KEY_CO_ATTRIBUTE_SCOPE = "co_attribute_scope" + KEY_CONTACT_PERSON = "contact_person" + KEY_ENCODEABLE_NAME = "encodeable_name" + KEY_ORGANIZATION = "organization" + KEY_ORGANIZATION_KEYS = ["display_name", "name", "url"] + + def __init__( + self, auth_req_callback_func, internal_attributes, config, base_url, name + ): self.has_multiple_backends = False - super().__init__(auth_req_callback_func, internal_attributes, config, base_url, name) + super().__init__( + auth_req_callback_func, internal_attributes, config, base_url, name + ) def handle_authn_request(self, context, binding_in): """ @@ -860,8 +942,7 @@ def handle_authn_response(self, context, internal_response): return self._handle_authn_response(context, internal_response) def _handle_authn_response(self, context, internal_response): - """ - """ + """ """ # Using the context of the current request and saved state from the # authentication request dynamically create an IdP instance. idp = self._create_co_virtual_idp(context) @@ -880,7 +961,7 @@ def _handle_authn_response(self, context, internal_response): else: attributes[attribute] = [value] except TypeError: - attributes[attribute] = [value] + attributes[attribute] = [value] # Handle the authentication response. return super()._handle_authn_response(context, internal_response, idp) @@ -897,13 +978,12 @@ def _create_state_data(self, context, resp_args, relay_state): """ state = super()._create_state_data(context, resp_args, relay_state) state[self.KEY_CO_NAME] = context.get_decoration(self.KEY_CO_NAME) - state[self.KEY_CO_ENTITY_ID] = context.get_decoration( - self.KEY_CO_ENTITY_ID) + state[self.KEY_CO_ENTITY_ID] = context.get_decoration(self.KEY_CO_ENTITY_ID) co_config = self._get_co_config(context) state[self.KEY_CO_ATTRIBUTE_SCOPE] = co_config.get( - self.KEY_CO_ATTRIBUTE_SCOPE, - None) + self.KEY_CO_ATTRIBUTE_SCOPE, None + ) return state @@ -992,10 +1072,11 @@ def _add_endpoints_to_config(self, config, co_name, backend_name): idp_endpoints = [] for binding, path in endpoint.items(): url = "{base}/{backend}/{co_name}/{path}".format( - base=self.base_url, - backend=backend_name, - co_name=quote_plus(co_name), - path=path) + base=self.base_url, + backend=backend_name, + co_name=quote_plus(co_name), + path=path, + ) mapping = (url, binding) idp_endpoints.append(mapping) @@ -1028,19 +1109,19 @@ def _add_entity_id(self, config, co_name, backend_name): :return: config with updated entity ID """ - base_entity_id = config['entityid'] + base_entity_id = config["entityid"] # If not using template for entityId and does not has multiple backends, then for backward compatibility append co_name at end if "" not in base_entity_id and not self.has_multiple_backends: base_entity_id = "{}/{}".format(base_entity_id, "") replace = [ ("", quote_plus(backend_name)), - ("", quote_plus(co_name)) + ("", quote_plus(co_name)), ] for _replace in replace: base_entity_id = base_entity_id.replace(_replace[0], _replace[1]) - config['entityid'] = base_entity_id + config["entityid"] = base_entity_id return config @@ -1061,8 +1142,7 @@ def _overlay_for_saml_metadata(self, config, co_name): """ all_co_configs = self.config[self.KEY_CO] co_config = next( - item for item in all_co_configs - if item[self.KEY_ENCODEABLE_NAME] == co_name + item for item in all_co_configs if item[self.KEY_ENCODEABLE_NAME] == co_name ) key = self.KEY_ORGANIZATION @@ -1088,8 +1168,7 @@ def _co_names_from_config(self): :return: list of CO names """ - co_names = [co[self.KEY_ENCODEABLE_NAME] for - co in self.config[self.KEY_CO]] + co_names = [co[self.KEY_ENCODEABLE_NAME] for co in self.config[self.KEY_CO]] return co_names @@ -1113,9 +1192,10 @@ def _create_co_virtual_idp(self, context, co_name=None): # endpoints is relaxed. co_names = self._co_names_from_config() if co_name not in co_names: - msg = "CO {} not in configured list of COs {}".format(co_name, - co_names) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + msg = "CO {} not in configured list of COs {}".format(co_name, co_names) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.warn(logline) raise SATOSAError(msg) # Make a copy of the general IdP config that we will then overwrite @@ -1123,11 +1203,9 @@ def _create_co_virtual_idp(self, context, co_name=None): # and the entityID for the CO virtual IdP. backend_name = context.target_backend idp_config = copy.deepcopy(self.idp_config) - idp_config = self._add_endpoints_to_config( - idp_config, co_name, backend_name - ) + idp_config = self._add_endpoints_to_config(idp_config, co_name, backend_name) idp_config = self._add_entity_id(idp_config, co_name, backend_name) - context.decorate(self.KEY_CO_ENTITY_ID, idp_config['entityid']) + context.decorate(self.KEY_CO_ENTITY_ID, idp_config["entityid"]) # Use the overwritten IdP config to generate a pysaml2 config object # and from it a server object. @@ -1165,17 +1243,22 @@ def _register_endpoints(self, backend_names): all_entity_ids = [] for backend_name in backend_names: for co_name in co_names: - all_entity_ids.append(self._add_entity_id(copy.deepcopy(self.idp_config), co_name, backend_name)['entityid']) + all_entity_ids.append( + self._add_entity_id( + copy.deepcopy(self.idp_config), co_name, backend_name + )["entityid"] + ) if len(all_entity_ids) != len(set(all_entity_ids)): - raise ValueError("Duplicate entities ids would be created for co-frontends, please make sure to make entity ids unique. " - "You can use and to achieve it. See example yaml file.") + raise ValueError( + "Duplicate entities ids would be created for co-frontends, please make sure to make entity ids unique. " + "You can use and to achieve it. See example yaml file." + ) # Create a regex pattern that will match any of the CO names. We # escape special characters like '+' and '.' that are valid # characters in an URL encoded string. - url_encoded_co_names = [re.escape(quote_plus(name)) for name in - co_names] + url_encoded_co_names = [re.escape(quote_plus(name)) for name in co_names] co_name_pattern = "|".join(url_encoded_co_names) # Create a regex pattern that will match any of the backend names. @@ -1209,15 +1292,15 @@ def _register_endpoints(self, backend_names): # a regex that will match and that includes a pattern for # matching the URL encoded CO name. regex_pattern = "(^{})/({})/{}".format( - backend_url_pattern, - co_name_pattern, - endpoint_path) + backend_url_pattern, co_name_pattern, endpoint_path + ) logline = "Created URL regex {}".format(regex_pattern) logger.debug(logline) # Map the regex pattern to a callable. - the_callable = functools.partial(self.handle_authn_request, - binding_in=binding) + the_callable = functools.partial( + self.handle_authn_request, binding_in=binding + ) logger.debug("Created callable {}".format(the_callable)) mapping = (regex_pattern, the_callable) @@ -1228,12 +1311,18 @@ def _register_endpoints(self, backend_names): if self.expose_entityid_endpoint(): for backend_name in backend_names: for co_name in co_names: - idp_config = self._add_entity_id(copy.deepcopy(self.idp_config), co_name, backend_name) - entity_id = idp_config['entityid'] - logger.debug("Exposing frontend entity endpoint = {}".format(entity_id)) + idp_config = self._add_entity_id( + copy.deepcopy(self.idp_config), co_name, backend_name + ) + entity_id = idp_config["entityid"] + logger.debug( + "Exposing frontend entity endpoint = {}".format(entity_id) + ) parsed_entity_id = urlparse(entity_id) metadata_endpoint = "^{0}".format(parsed_entity_id.path[1:]) - the_callable = functools.partial(self._metadata_endpoint, co_name=co_name) + the_callable = functools.partial( + self._metadata_endpoint, co_name=co_name + ) url_to_callable_mappings.append((metadata_endpoint, the_callable)) return url_to_callable_mappings @@ -1250,4 +1339,4 @@ def _metadata_endpoint(self, context, co_name): # Using the context of the current request and saved state from the # authentication request dynamically create an IdP instance. self.idp = self._create_co_virtual_idp(context, co_name=co_name) - return super()._metadata_endpoint(context=context); + return super()._metadata_endpoint(context=context) diff --git a/src/satosa/metadata_creation/saml_metadata.py b/src/satosa/metadata_creation/saml_metadata.py index f88bbaaec..ecf081f43 100644 --- a/src/satosa/metadata_creation/saml_metadata.py +++ b/src/satosa/metadata_creation/saml_metadata.py @@ -3,21 +3,27 @@ from collections import defaultdict from saml2.config import Config -from saml2.metadata import entity_descriptor, entities_descriptor, sign_entity_descriptor +from saml2.metadata import ( + entities_descriptor, + entity_descriptor, + sign_entity_descriptor, +) from saml2.time_util import in_a_while from saml2.validate import valid_instance from ..backends.saml2 import SAMLBackend -from ..frontends.saml2 import SAMLFrontend -from ..frontends.saml2 import SAMLMirrorFrontend -from ..frontends.saml2 import SAMLVirtualCoFrontend -from ..plugin_loader import load_frontends, load_backends +from ..frontends.saml2 import SAMLFrontend, SAMLMirrorFrontend, SAMLVirtualCoFrontend +from ..plugin_loader import load_backends, load_frontends logger = logging.getLogger(__name__) def _create_entity_descriptor(entity_config): - cnf = entity_config if isinstance(entity_config, Config) else Config().load(copy.deepcopy(entity_config)) + cnf = ( + entity_config + if isinstance(entity_config, Config) + else Config().load(copy.deepcopy(entity_config)) + ) return entity_descriptor(cnf) @@ -28,12 +34,16 @@ def _create_backend_metadata(backend_modules): if isinstance(plugin_module, SAMLBackend): logline = "Generating SAML backend '{}' metadata".format(plugin_module.name) logger.info(logline) - backend_metadata[plugin_module.name] = [_create_entity_descriptor(plugin_module.sp.config)] + backend_metadata[plugin_module.name] = [ + _create_entity_descriptor(plugin_module.sp.config) + ] return backend_metadata -def _create_mirrored_entity_config(frontend_instance, target_metadata_info, backend_name): +def _create_mirrored_entity_config( + frontend_instance, target_metadata_info, backend_name +): def _merge_dicts(a, b): for key, value in b.items(): if key in ["organization", "contact_person"]: @@ -46,12 +56,17 @@ def _merge_dicts(a, b): return a - merged_conf = _merge_dicts(copy.deepcopy(frontend_instance.config["idp_config"]), target_metadata_info) - full_config = frontend_instance._load_endpoints_to_config(backend_name, target_metadata_info["entityid"], - config=merged_conf) + merged_conf = _merge_dicts( + copy.deepcopy(frontend_instance.config["idp_config"]), target_metadata_info + ) + full_config = frontend_instance._load_endpoints_to_config( + backend_name, target_metadata_info["entityid"], config=merged_conf + ) proxy_entity_id = frontend_instance.config["idp_config"]["entityid"] - full_config["entityid"] = "{}/{}".format(proxy_entity_id, target_metadata_info["entityid"]) + full_config["entityid"] = "{}/{}".format( + proxy_entity_id, target_metadata_info["entityid"] + ) return full_config @@ -68,7 +83,10 @@ def _create_frontend_metadata(frontend_modules, backend_modules): meta_desc = backend.get_metadata_desc() for desc in meta_desc: entity_desc = _create_entity_descriptor( - _create_mirrored_entity_config(frontend, desc.to_dict(), backend.name)) + _create_mirrored_entity_config( + frontend, desc.to_dict(), backend.name + ) + ) frontend_metadata[frontend.name].append(entity_desc) elif isinstance(frontend, SAMLVirtualCoFrontend): @@ -79,15 +97,20 @@ def _create_frontend_metadata(frontend_modules, backend_modules): logline = "Creating metadata for CO {}".format(co_name) logger.info(logline) idp_config = copy.deepcopy(frontend.config["idp_config"]) - idp_config = frontend._add_endpoints_to_config(idp_config, co_name, backend.name) - idp_config = frontend._add_entity_id(idp_config, co_name, backend.name) - idp_config = frontend._overlay_for_saml_metadata(idp_config, co_name) + idp_config = frontend._add_endpoints_to_config( + idp_config, co_name, backend.name + ) + idp_config = frontend._add_entity_id( + idp_config, co_name, backend.name + ) + idp_config = frontend._overlay_for_saml_metadata( + idp_config, co_name + ) entity_desc = _create_entity_descriptor(idp_config) frontend_metadata[frontend.name].append(entity_desc) elif isinstance(frontend, SAMLFrontend): - frontend.register_endpoints([backend.name for - backend in backend_modules]) + frontend.register_endpoints([backend.name for backend in backend_modules]) entity_desc = _create_entity_descriptor(frontend.idp_config) frontend_metadata[frontend.name].append(entity_desc) @@ -104,10 +127,22 @@ def create_entity_descriptors(satosa_config): :type satosa_config: satosa.satosa_config.SATOSAConfig :rtype: Tuple[str, str] """ - frontend_modules = load_frontends(satosa_config, None, satosa_config["INTERNAL_ATTRIBUTES"]) - backend_modules = load_backends(satosa_config, None, satosa_config["INTERNAL_ATTRIBUTES"]) - logger.info("Loaded frontend plugins: {}".format([frontend.name for frontend in frontend_modules])) - logger.info("Loaded backend plugins: {}".format([backend.name for backend in backend_modules])) + frontend_modules = load_frontends( + satosa_config, None, satosa_config["INTERNAL_ATTRIBUTES"] + ) + backend_modules = load_backends( + satosa_config, None, satosa_config["INTERNAL_ATTRIBUTES"] + ) + logger.info( + "Loaded frontend plugins: {}".format( + [frontend.name for frontend in frontend_modules] + ) + ) + logger.info( + "Loaded backend plugins: {}".format( + [backend.name for backend in backend_modules] + ) + ) backend_metadata = _create_backend_metadata(backend_modules) frontend_metadata = _create_frontend_metadata(frontend_modules, backend_modules) @@ -115,7 +150,9 @@ def create_entity_descriptors(satosa_config): return frontend_metadata, backend_metadata -def create_signed_entities_descriptor(entity_descriptors, security_context, valid_for=None): +def create_signed_entities_descriptor( + entity_descriptors, security_context, valid_for=None +): """ :param entity_descriptors: the entity descriptors to put in in an EntitiesDescriptor tag and sign :param security_context: security context for the signature @@ -126,15 +163,23 @@ def create_signed_entities_descriptor(entity_descriptors, security_context, vali :type security_context: saml2.sigver.SecurityContext :type valid_for: Optional[int] """ - entities_desc, xmldoc = entities_descriptor(entity_descriptors, valid_for=valid_for, name=None, ident=None, - sign=True, secc=security_context) + entities_desc, xmldoc = entities_descriptor( + entity_descriptors, + valid_for=valid_for, + name=None, + ident=None, + sign=True, + secc=security_context, + ) if not valid_instance(entities_desc): raise ValueError("Could not construct valid EntitiesDescriptor tag") return xmldoc -def create_signed_entity_descriptor(entity_descriptor, security_context, valid_for=None): +def create_signed_entity_descriptor( + entity_descriptor, security_context, valid_for=None +): """ :param entity_descriptor: the entity descriptor to sign :param security_context: security context for the signature @@ -148,7 +193,9 @@ def create_signed_entity_descriptor(entity_descriptor, security_context, valid_f if valid_for: entity_descriptor.valid_until = in_a_while(hours=valid_for) - entity_desc, xmldoc = sign_entity_descriptor(entity_descriptor, None, security_context) + entity_desc, xmldoc = sign_entity_descriptor( + entity_descriptor, None, security_context + ) if not valid_instance(entity_desc): raise ValueError("Could not construct valid EntityDescriptor tag") diff --git a/src/satosa/micro_services/account_linking.py b/src/satosa/micro_services/account_linking.py index 7305c3d79..39d0352bf 100644 --- a/src/satosa/micro_services/account_linking.py +++ b/src/satosa/micro_services/account_linking.py @@ -5,15 +5,16 @@ import logging import requests -from jwkest.jwk import rsa_load, RSAKey +from jwkest.jwk import RSAKey, rsa_load from jwkest.jws import JWS +import satosa.logging_util as lu from satosa.internal import InternalData + from ..exception import SATOSAAuthenticationError from ..micro_services.base import ResponseMicroService from ..response import Redirect -import satosa.logging_util as lu logger = logging.getLogger(__name__) @@ -30,7 +31,9 @@ def __init__(self, config, *args, **kwargs): super().__init__(*args, **kwargs) self.api_url = config["api_url"] self.redirect_url = config["redirect_url"] - self.signing_key = RSAKey(key=rsa_load(config["sign_key"]), use="sig", alg="RS256") + self.signing_key = RSAKey( + key=rsa_load(config["sign_key"]), use="sig", alg="RS256" + ) self.endpoint = "/handle_account_linking" self.id_to_attr = config.get("id_to_attr", None) logger.info("Account linking is active") @@ -49,12 +52,18 @@ def _handle_al_response(self, context): saved_state = context.state[self.name] internal_response = InternalData.from_dict(saved_state) - #subject_id here is the linked id , not the facebook one, Figure out what to do - status_code, message = self._get_uuid(context, internal_response.auth_info.issuer, internal_response.attributes['issuer_user_id']) + # subject_id here is the linked id , not the facebook one, Figure out what to do + status_code, message = self._get_uuid( + context, + internal_response.auth_info.issuer, + internal_response.attributes["issuer_user_id"], + ) if status_code == 200: msg = "issuer/id pair is linked in AL service" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) internal_response.subject_id = message if self.id_to_attr: @@ -66,12 +75,13 @@ def _handle_al_response(self, context): # User selected not to link their accounts, so the internal.response.subject_id is based on the # issuers id/sub which is fine msg = "User selected to not link their identity in AL service" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) del context.state[self.name] return super().process(context, internal_response) - def process(self, context, internal_response): """ Manage account linking and recovery @@ -86,29 +96,38 @@ def process(self, context, internal_response): : """ - status_code, message = self._get_uuid(context, internal_response.auth_info.issuer, internal_response.subject_id) + status_code, message = self._get_uuid( + context, internal_response.auth_info.issuer, internal_response.subject_id + ) data = { "issuer": internal_response.auth_info.issuer, - "redirect_endpoint": "%s/account_linking%s" % (self.base_url, self.endpoint) + "redirect_endpoint": "%s/account_linking%s" + % (self.base_url, self.endpoint), } # Store the issuer subject_id/sub because we'll need it in handle_al_response - internal_response.attributes['issuer_user_id'] = internal_response.subject_id + internal_response.attributes["issuer_user_id"] = internal_response.subject_id if status_code == 200: msg = "issuer/id pair is linked in AL service" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) internal_response.subject_id = message - data['user_id'] = message + data["user_id"] = message if self.id_to_attr: internal_response.attributes[self.id_to_attr] = [message] else: msg = "issuer/id pair is not linked in AL service. Got a ticket" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) - data['ticket'] = message - jws = JWS(json.dumps(data), alg=self.signing_key.alg).sign_compact([self.signing_key]) + data["ticket"] = message + jws = JWS(json.dumps(data), alg=self.signing_key.alg).sign_compact( + [self.signing_key] + ) context.state[self.name] = internal_response.to_dict() return Redirect("%s/%s" % (self.redirect_url, jws)) @@ -132,22 +151,31 @@ def _get_uuid(self, context, issuer, id): data = { "idp": issuer, "id": id, - "redirect_endpoint": "%s/account_linking%s" % (self.base_url, self.endpoint) + "redirect_endpoint": "%s/account_linking%s" + % (self.base_url, self.endpoint), } - jws = JWS(json.dumps(data), alg=self.signing_key.alg).sign_compact([self.signing_key]) + jws = JWS(json.dumps(data), alg=self.signing_key.alg).sign_compact( + [self.signing_key] + ) try: request = "{}/get_id?jwt={}".format(self.api_url, jws) response = requests.get(request) except Exception as con_exc: msg = "Could not connect to account linking service" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.critical(logline) raise SATOSAAuthenticationError(context.state, msg) from con_exc if response.status_code not in [200, 404]: - msg = "Got status code '{}' from account linking service".format(response.status_code) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + msg = "Got status code '{}' from account linking service".format( + response.status_code + ) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.critical(logline) raise SATOSAAuthenticationError(context.state, msg) diff --git a/src/satosa/micro_services/attribute_authorization.py b/src/satosa/micro_services/attribute_authorization.py index 60f4afe4b..d6d9f30ec 100644 --- a/src/satosa/micro_services/attribute_authorization.py +++ b/src/satosa/micro_services/attribute_authorization.py @@ -1,8 +1,9 @@ import re -from .base import ResponseMicroService from ..exception import SATOSAAuthenticationError from ..util import get_dict_defaults +from .base import ResponseMicroService + class AttributeAuthorization(ResponseMicroService): """ @@ -56,11 +57,17 @@ def __init__(self, config, *args, **kwargs): super().__init__(*args, **kwargs) self.attribute_allow = config.get("attribute_allow", {}) self.attribute_deny = config.get("attribute_deny", {}) - self.force_attributes_presence_on_allow = config.get("force_attributes_presence_on_allow", False) - self.force_attributes_presence_on_deny = config.get("force_attributes_presence_on_deny", False) + self.force_attributes_presence_on_allow = config.get( + "force_attributes_presence_on_allow", False + ) + self.force_attributes_presence_on_deny = config.get( + "force_attributes_presence_on_deny", False + ) def _check_authz(self, context, attributes, requester, provider): - for attribute_name, attribute_filters in get_dict_defaults(self.attribute_allow, requester, provider).items(): + for attribute_name, attribute_filters in get_dict_defaults( + self.attribute_allow, requester, provider + ).items(): attr_values = attributes.get(attribute_name) if attr_values is not None: if not any( @@ -73,12 +80,18 @@ def _check_authz(self, context, attributes, requester, provider): elif self.force_attributes_presence_on_allow: raise SATOSAAuthenticationError(context.state, "Permission denied") - for attribute_name, attribute_filters in get_dict_defaults(self.attribute_deny, requester, provider).items(): + for attribute_name, attribute_filters in get_dict_defaults( + self.attribute_deny, requester, provider + ).items(): attr_values = attributes.get(attribute_name) if attr_values is not None: if any( [ - any(filter(lambda x: re.search(af, x), attributes[attribute_name])) + any( + filter( + lambda x: re.search(af, x), attributes[attribute_name] + ) + ) for af in attribute_filters ] ): @@ -87,5 +100,7 @@ def _check_authz(self, context, attributes, requester, provider): raise SATOSAAuthenticationError(context.state, "Permission denied") def process(self, context, data): - self._check_authz(context, data.attributes, data.requester, data.auth_info.issuer) + self._check_authz( + context, data.attributes, data.requester, data.auth_info.issuer + ) return super().process(context, data) diff --git a/src/satosa/micro_services/attribute_generation.py b/src/satosa/micro_services/attribute_generation.py index 907a8462d..77c77d2a5 100644 --- a/src/satosa/micro_services/attribute_generation.py +++ b/src/satosa/micro_services/attribute_generation.py @@ -1,20 +1,21 @@ import re + from chevron import render as render_mustache -from .base import ResponseMicroService from ..util import get_dict_defaults +from .base import ResponseMicroService class MustachAttrValue(object): def __init__(self, attr_name, values): self._attr_name = attr_name self._values = values - if any(['@' in v for v in values]): + if any(["@" in v for v in values]): local_parts = [] domain_parts = [] scopes = dict() for v in values: - (local_part, sep, domain_part) = v.partition('@') + (local_part, sep, domain_part) = v.partition("@") # probably not needed now... local_parts.append(local_part) domain_parts.append(domain_part) @@ -53,69 +54,69 @@ def scope(self): class AddSyntheticAttributes(ResponseMicroService): """ -A class that add generated or synthetic attributes to a response set. Attribute -generation is done using mustach (http://mustache.github.io) templates. The -following example configuration illustrates most common features: - -```yaml -module: satosa.micro_services.attribute_generation.AddSyntheticAttributes -name: AddSyntheticAttributes -config: - synthetic_attributes: - requester1: - target_provider1: - eduPersonAffiliation: member;employee - default: + A class that add generated or synthetic attributes to a response set. Attribute + generation is done using mustach (http://mustache.github.io) templates. The + following example configuration illustrates most common features: + + ```yaml + module: satosa.micro_services.attribute_generation.AddSyntheticAttributes + name: AddSyntheticAttributes + config: + synthetic_attributes: + requester1: + target_provider1: + eduPersonAffiliation: member;employee default: - schacHomeOrganization: {{eduPersonPrincipalName.scope}} - schacHomeOrganizationType: tomfoolery provider - -``` - -The use of "" and 'default' is synonymous. Attribute rules are not -overloaded or inherited. For instance a response for "requester1" -from target_provider1 in the above config will generate a (static) attribute -set of 'member' and 'employee' for the eduPersonAffiliation attribute -and nothing else. Note that synthetic attributes override existing -attributes if present. - -*Evaluating and interpreting templates* - -Attribute values are split on combinations of ';' and newline so that -a template resulting in the following text: -``` -a; -b;c -``` -results in three attribute values: 'a','b' and 'c'. Templates are -evaluated with a single context that represents the response attributes -before the microservice is processed. De-referencing the attribute -name as in '{{name}}' results in a ';'-separated list of all attribute -values. This notation is useful when you know there is only a single -attribute value in the set. - -*Special contexts* - -For treating the values as a list - eg for interating using mustach, -use the .values sub-context For instance to synthesize all first-last -name combinations do this: - -``` -{{#givenName.values}} - {{#sn.values}}{{givenName}} {{sn}}{{/sn.values}} -{{/givenName.values}} -``` - -Note that the .values sub-context behaves as if it is an iterator -over single-value context with the same key name as the original -attribute name. - -The .scope sub-context evalues to the right-hand part of any @ -sign. This is assumed to be single valued. - -The .first sub-context evalues to the first value of a context -which may be safer to use if the attribute is multivalued but -you don't care which value is used in a template. + default: + schacHomeOrganization: {{eduPersonPrincipalName.scope}} + schacHomeOrganizationType: tomfoolery provider + + ``` + + The use of "" and 'default' is synonymous. Attribute rules are not + overloaded or inherited. For instance a response for "requester1" + from target_provider1 in the above config will generate a (static) attribute + set of 'member' and 'employee' for the eduPersonAffiliation attribute + and nothing else. Note that synthetic attributes override existing + attributes if present. + + *Evaluating and interpreting templates* + + Attribute values are split on combinations of ';' and newline so that + a template resulting in the following text: + ``` + a; + b;c + ``` + results in three attribute values: 'a','b' and 'c'. Templates are + evaluated with a single context that represents the response attributes + before the microservice is processed. De-referencing the attribute + name as in '{{name}}' results in a ';'-separated list of all attribute + values. This notation is useful when you know there is only a single + attribute value in the set. + + *Special contexts* + + For treating the values as a list - eg for interating using mustach, + use the .values sub-context For instance to synthesize all first-last + name combinations do this: + + ``` + {{#givenName.values}} + {{#sn.values}}{{givenName}} {{sn}}{{/sn.values}} + {{/givenName.values}} + ``` + + Note that the .values sub-context behaves as if it is an iterator + over single-value context with the same key name as the original + attribute name. + + The .scope sub-context evalues to the right-hand part of any @ + sign. This is assumed to be single valued. + + The .first sub-context evalues to the first value of a context + which may be safer to use if the attribute is multivalued but + you don't care which value is used in a template. """ def __init__(self, config, *args, **kwargs): @@ -141,11 +142,13 @@ def _synthesize(self, attributes, requester, provider): syn_attributes[attr_name] = [ value for token in re.split("[;\n]+", render_mustache(fmt, context)) - for value in [token.strip().strip(';')] + for value in [token.strip().strip(";")] if value ] return syn_attributes def process(self, context, data): - data.attributes.update(self._synthesize(data.attributes, data.requester, data.auth_info.issuer)) + data.attributes.update( + self._synthesize(data.attributes, data.requester, data.auth_info.issuer) + ) return super().process(context, data) diff --git a/src/satosa/micro_services/attribute_modifications.py b/src/satosa/micro_services/attribute_modifications.py index bb00761b4..aa0410bd2 100644 --- a/src/satosa/micro_services/attribute_modifications.py +++ b/src/satosa/micro_services/attribute_modifications.py @@ -1,12 +1,13 @@ -import re import logging +import re -from .base import ResponseMicroService from ..context import Context from ..exception import SATOSAError +from .base import ResponseMicroService logger = logging.getLogger(__name__) + class AddStaticAttributes(ResponseMicroService): """ Add static attributes to the responses. @@ -34,17 +35,25 @@ def process(self, context, data): # apply default filters provider_filters = self.attribute_filters.get("", {}) target_provider = data.auth_info.issuer - self._apply_requester_filters(data.attributes, provider_filters, data.requester, context, target_provider) + self._apply_requester_filters( + data.attributes, provider_filters, data.requester, context, target_provider + ) # apply target provider specific filters provider_filters = self.attribute_filters.get(target_provider, {}) - self._apply_requester_filters(data.attributes, provider_filters, data.requester, context, target_provider) + self._apply_requester_filters( + data.attributes, provider_filters, data.requester, context, target_provider + ) return super().process(context, data) - def _apply_requester_filters(self, attributes, provider_filters, requester, context, target_provider): + def _apply_requester_filters( + self, attributes, provider_filters, requester, context, target_provider + ): # apply default requester filters default_requester_filters = provider_filters.get("", {}) - self._apply_filters(attributes, default_requester_filters, context, target_provider) + self._apply_filters( + attributes, default_requester_filters, context, target_provider + ) # apply requester specific filters requester_filters = provider_filters.get(requester, {}) @@ -54,39 +63,54 @@ def _apply_filters(self, attributes, attribute_filters, context, target_provider for attribute_name, attribute_filters in attribute_filters.items(): if type(attribute_filters) == str: # convert simple notation to filter list - attribute_filters = {'regexp': attribute_filters} + attribute_filters = {"regexp": attribute_filters} for filter_type, filter_value in attribute_filters.items(): - if filter_type == "regexp": filter_func = re.compile(filter_value).search elif filter_type == "shibmdscope_match_scope": mdstore = context.get_decoration(Context.KEY_METADATA_STORE) - md_scopes = list(mdstore.shibmd_scopes(target_provider,"idpsso_descriptor")) if mdstore else [] + md_scopes = ( + list( + mdstore.shibmd_scopes(target_provider, "idpsso_descriptor") + ) + if mdstore + else [] + ) filter_func = lambda v: self._shibmdscope_match_scope(v, md_scopes) elif filter_type == "shibmdscope_match_value": mdstore = context.get_decoration(Context.KEY_METADATA_STORE) - md_scopes = list(mdstore.shibmd_scopes(target_provider,"idpsso_descriptor")) if mdstore else [] + md_scopes = ( + list( + mdstore.shibmd_scopes(target_provider, "idpsso_descriptor") + ) + if mdstore + else [] + ) filter_func = lambda v: self._shibmdscope_match_value(v, md_scopes) else: raise SATOSAError("Unknown filter type") if attribute_name == "": # default filter for all attributes for attribute, values in attributes.items(): - attributes[attribute] = list(filter(filter_func, attributes[attribute])) + attributes[attribute] = list( + filter(filter_func, attributes[attribute]) + ) elif attribute_name in attributes: - attributes[attribute_name] = list(filter(filter_func, attributes[attribute_name])) + attributes[attribute_name] = list( + filter(filter_func, attributes[attribute_name]) + ) def _shibmdscope_match_value(self, value, md_scopes): for md_scope in md_scopes: - if not md_scope['regexp'] and md_scope['text'] == value: + if not md_scope["regexp"] and md_scope["text"] == value: return True - elif md_scope['regexp'] and re.fullmatch(md_scope['text'], value): + elif md_scope["regexp"] and re.fullmatch(md_scope["text"], value): return True return False def _shibmdscope_match_scope(self, value, md_scopes): - split_value = value.split('@') + split_value = value.split("@") if len(split_value) != 2: logger.info(f"Discarding invalid scoped value {value}") return False diff --git a/src/satosa/micro_services/attribute_policy.py b/src/satosa/micro_services/attribute_policy.py index 81151d0e4..9527b460a 100644 --- a/src/satosa/micro_services/attribute_policy.py +++ b/src/satosa/micro_services/attribute_policy.py @@ -26,7 +26,7 @@ def process(self, context, data): policy = self.attribute_policy.get(data.requester, {}) if "allowed" in policy: - for key in (data.attributes.keys() - set(policy["allowed"])): + for key in data.attributes.keys() - set(policy["allowed"]): del data.attributes[key] msg = "Returning data.attributes {}".format(data.attributes) diff --git a/src/satosa/micro_services/attribute_processor.py b/src/satosa/micro_services/attribute_processor.py index 7232e484e..014fdd453 100644 --- a/src/satosa/micro_services/attribute_processor.py +++ b/src/satosa/micro_services/attribute_processor.py @@ -5,14 +5,13 @@ from satosa.logging_util import satosa_logging from satosa.micro_services.base import ResponseMicroService - logger = logging.getLogger(__name__) -CONFIG_KEY_ROOT = 'process' -CONFIG_KEY_MODULE = 'module' -CONFIG_KEY_CLASSNAME = 'name' -CONFIG_KEY_ATTRIBUTE = 'attribute' -CONFIG_KEY_PROCESSORS = 'processors' +CONFIG_KEY_ROOT = "process" +CONFIG_KEY_MODULE = "module" +CONFIG_KEY_CLASSNAME = "name" +CONFIG_KEY_ATTRIBUTE = "attribute" +CONFIG_KEY_PROCESSORS = "processors" class AttributeProcessor(ResponseMicroService): @@ -38,6 +37,7 @@ class AttributeProcessor(ResponseMicroService): module: satosa.micro_services.processors.scope_processor scope: example """ + def __init__(self, config, *args, **kwargs): super().__init__(*args, **kwargs) self.config = config diff --git a/src/satosa/micro_services/base.py b/src/satosa/micro_services/base.py index 084cbea76..607c0dbbc 100644 --- a/src/satosa/micro_services/base.py +++ b/src/satosa/micro_services/base.py @@ -54,6 +54,7 @@ class ResponseMicroService(MicroService): """ Base class for response micro services """ + pass @@ -61,4 +62,5 @@ class RequestMicroService(MicroService): """ Base class for request micro services """ + pass diff --git a/src/satosa/micro_services/consent.py b/src/satosa/micro_services/consent.py index a469e2189..fd564b85f 100644 --- a/src/satosa/micro_services/consent.py +++ b/src/satosa/micro_services/consent.py @@ -7,8 +7,7 @@ from base64 import urlsafe_b64encode import requests -from jwkest.jwk import RSAKey -from jwkest.jwk import rsa_load +from jwkest.jwk import RSAKey, rsa_load from jwkest.jws import JWS from requests.exceptions import ConnectionError @@ -17,7 +16,6 @@ from satosa.micro_services.base import ResponseMicroService from satosa.response import Redirect - logger = logging.getLogger(__name__) STATE_KEY = "CONSENT" @@ -41,7 +39,9 @@ def __init__(self, config, internal_attributes, *args, **kwargs): if "user_id_to_attr" in internal_attributes: self.locked_attr = internal_attributes["user_id_to_attr"] - self.signing_key = RSAKey(key=rsa_load(config["sign_key"]), use="sig", alg="RS256") + self.signing_key = RSAKey( + key=rsa_load(config["sign_key"]), use="sig", alg="RS256" + ) self.endpoint = "/handle_consent" logger.info("Consent flow is active") @@ -58,30 +58,41 @@ def _handle_consent_response(self, context): saved_resp = consent_state["internal_resp"] internal_response = InternalData.from_dict(saved_resp) - hash_id = self._get_consent_id(internal_response.requester, internal_response.subject_id, - internal_response.attributes) + hash_id = self._get_consent_id( + internal_response.requester, + internal_response.subject_id, + internal_response.attributes, + ) try: consent_attributes = self._verify_consent(hash_id) except ConnectionError as e: msg = "Consent service is not reachable, no consent given." - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline, exc_info=e) # Send an internal_response without any attributes consent_attributes = None if consent_attributes is None: msg = "Consent was NOT given" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) # If consent was not given, then don't send any attributes consent_attributes = [] else: msg = "Consent was given" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) - internal_response.attributes = self._filter_attributes(internal_response.attributes, consent_attributes) + internal_response.attributes = self._filter_attributes( + internal_response.attributes, consent_attributes + ) return self._end_consent(context, internal_response) def _approve_new_consent(self, context, internal_response, id_hash): @@ -96,13 +107,15 @@ def _approve_new_consent(self, context, internal_response, id_hash): } if self.locked_attr: consent_args["locked_attrs"] = [self.locked_attr] - if 'requester_logo' in context.state[STATE_KEY]: - consent_args["requester_logo"] = context.state[STATE_KEY]['requester_logo'] + if "requester_logo" in context.state[STATE_KEY]: + consent_args["requester_logo"] = context.state[STATE_KEY]["requester_logo"] try: ticket = self._consent_registration(consent_args) except (ConnectionError, UnexpectedResponseError) as e: msg = "Consent request failed, no consent given: {}".format(str(e)) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) # Send an internal_response without any attributes internal_response.attributes = {} @@ -135,7 +148,9 @@ def process(self, context, internal_response): consent_attributes = self._verify_consent(id_hash) except requests.exceptions.ConnectionError as e: msg = "Consent service is not reachable, no consent is given." - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline, exc_info=e) # Send an internal_response without any attributes internal_response.attributes = {} @@ -144,9 +159,13 @@ def process(self, context, internal_response): # Previous consent was given if consent_attributes is not None: msg = "Previous consent was given" - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) - internal_response.attributes = self._filter_attributes(internal_response.attributes, consent_attributes) + internal_response.attributes = self._filter_attributes( + internal_response.attributes, consent_attributes + ) return self._end_consent(context, internal_response) # No previous consent, request consent by user @@ -175,7 +194,9 @@ def _get_consent_id(self, requester, user_id, filtered_attr): _hash_value = "".join(sorted(filtered_attr[key])) hash_str += key + _hash_value id_string = "%s%s%s" % (requester, user_id, hash_str) - return urlsafe_b64encode(hashlib.sha512(id_string.encode("utf-8")).hexdigest().encode("utf-8")).decode("utf-8") + return urlsafe_b64encode( + hashlib.sha512(id_string.encode("utf-8")).hexdigest().encode("utf-8") + ).decode("utf-8") def _consent_registration(self, consent_args): """ @@ -187,12 +208,16 @@ def _consent_registration(self, consent_args): :param consent_args: All necessary parameters for the consent request :return: Ticket received from the consent service """ - jws = JWS(json.dumps(consent_args), alg=self.signing_key.alg).sign_compact([self.signing_key]) + jws = JWS(json.dumps(consent_args), alg=self.signing_key.alg).sign_compact( + [self.signing_key] + ) request = "{}/creq/{}".format(self.api_url, jws) res = requests.get(request) if res.status_code != 200: - raise UnexpectedResponseError("Consent service error: %s %s", res.status_code, res.text) + raise UnexpectedResponseError( + "Consent service error: %s %s", res.status_code, res.text + ) return res.text diff --git a/src/satosa/micro_services/custom_logging.py b/src/satosa/micro_services/custom_logging.py index 14d435d8f..2469c70dc 100644 --- a/src/satosa/micro_services/custom_logging.py +++ b/src/satosa/micro_services/custom_logging.py @@ -7,8 +7,8 @@ import logging import satosa.logging_util as lu -from .base import ResponseMicroService +from .base import ResponseMicroService logger = logging.getLogger(__name__) @@ -17,6 +17,7 @@ class CustomLoggingService(ResponseMicroService): """ Use context and data object to create custom log output """ + logprefix = "CUSTOM_LOGGING_SERVICE:" def __init__(self, config, *args, **kwargs): @@ -37,11 +38,15 @@ def process(self, context, data): # Find the entityID for the SP that initiated the flow and target IdP try: - spEntityID = context.state.state_dict['SATOSA_BASE']['requester'] + spEntityID = context.state.state_dict["SATOSA_BASE"]["requester"] idpEntityID = data.auth_info.issuer except KeyError: - msg = "{} Unable to determine the entityID's for the IdP or SP".format(logprefix) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + msg = "{} Unable to determine the entityID's for the IdP or SP".format( + logprefix + ) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) return super().process(context, data) @@ -54,45 +59,52 @@ def process(self, context, data): # Obtain configuration details from the per-SP configuration or the default configuration try: - if 'log_target' in config: - log_target = config['log_target'] + if "log_target" in config: + log_target = config["log_target"] else: - log_target = self.config['log_target'] + log_target = self.config["log_target"] - if 'attrs' in config: - attrs = config['attrs'] + if "attrs" in config: + attrs = config["attrs"] else: - attrs = self.config['attrs'] - + attrs = self.config["attrs"] except KeyError as err: msg = "{} Configuration '{}' is missing".format(logprefix, err) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) return super().process(context, data) try: msg = "{} Using context {}".format(logprefix, context) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) msg = "{} Using data {}".format(logprefix, data.to_dict()) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) # Open log_target file msg = "{} Opening log_target file {}".format(logprefix, log_target) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) - loghandle = open(log_target,"a") + loghandle = open(log_target, "a") # This is where the logging magic happens log = {} - log['router'] = context.state.state_dict['ROUTER'] - log['timestamp'] = data.auth_info.timestamp - log['sessionid'] = context.state.state_dict['SESSION_ID'] - log['idp'] = idpEntityID - log['sp'] = spEntityID - log['attr'] = { key: data.to_dict()['attr'].get(key) for key in attrs } + log["router"] = context.state.state_dict["ROUTER"] + log["timestamp"] = data.auth_info.timestamp + log["sessionid"] = context.state.state_dict["SESSION_ID"] + log["idp"] = idpEntityID + log["sp"] = spEntityID + log["attr"] = {key: data.to_dict()["attr"].get(key) for key in attrs} print(json.dumps(log), file=loghandle, end="\n") @@ -104,7 +116,9 @@ def process(self, context, data): else: msg = "{} Closing log_target file".format(logprefix) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) # Close log_target file diff --git a/src/satosa/micro_services/custom_routing.py b/src/satosa/micro_services/custom_routing.py index 5706ce9aa..f9704bf92 100644 --- a/src/satosa/micro_services/custom_routing.py +++ b/src/satosa/micro_services/custom_routing.py @@ -4,16 +4,15 @@ from satosa.context import Context from satosa.internal import InternalData +from ..exception import SATOSAConfigurationError, SATOSAError from .base import RequestMicroService -from ..exception import SATOSAConfigurationError -from ..exception import SATOSAError - logger = logging.getLogger(__name__) class CustomRoutingError(SATOSAError): """SATOSA exception raised by CustomRouting rules""" + pass @@ -22,7 +21,7 @@ class DecideBackendByTargetIssuer(RequestMicroService): Select target backend based on the target issuer. """ - def __init__(self, config:dict, *args, **kwargs): + def __init__(self, config: dict, *args, **kwargs): """ Constructor. @@ -31,26 +30,23 @@ def __init__(self, config:dict, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.target_mapping = config['target_mapping'] - self.default_backend = config['default_backend'] + self.target_mapping = config["target_mapping"] + self.default_backend = config["default_backend"] - def process(self, context:Context, data:InternalData): + def process(self, context: Context, data: InternalData): """Set context.target_backend based on the target issuer""" target_issuer = context.get_decoration(Context.KEY_TARGET_ENTITYID) if not target_issuer: - logger.info('skipping backend decision because no target_issuer was found') + logger.info("skipping backend decision because no target_issuer was found") return super().process(context, data) - target_backend = ( - self.target_mapping.get(target_issuer) - or self.default_backend - ) + target_backend = self.target_mapping.get(target_issuer) or self.default_backend report = { - 'msg': 'decided target backend by target issuer', - 'target_issuer': target_issuer, - 'target_backend': target_backend, + "msg": "decided target backend by target issuer", + "target_issuer": target_issuer, + "target_backend": target_backend, } logger.info(report) @@ -72,8 +68,8 @@ def __init__(self, config, *args, **kwargs): :type config: Dict[str, Dict[str, str]] """ super().__init__(*args, **kwargs) - self.requester_mapping = config['requester_mapping'] - self.default_backend = config.get('default_backend') + self.requester_mapping = config["requester_mapping"] + self.default_backend = config.get("default_backend") def process(self, context, data): """ @@ -81,7 +77,9 @@ def process(self, context, data): :param context: request context :param data: the internal request """ - context.target_backend = self.requester_mapping.get(data.requester) or self.default_backend + context.target_backend = ( + self.requester_mapping.get(data.requester) or self.default_backend + ) return super().process(context, data) @@ -93,14 +91,19 @@ class DecideIfRequesterIsAllowed(RequestMicroService): Currently, a target entityid is set only when the `SAMLMirrorFrontend` is used. """ + def __init__(self, config, *args, **kwargs): super().__init__(*args, **kwargs) for target_entity, rules in config["rules"].items(): - conflicting_rules = set(rules.get("deny", [])).intersection(rules.get("allow", [])) + conflicting_rules = set(rules.get("deny", [])).intersection( + rules.get("allow", []) + ) if conflicting_rules: - raise SATOSAConfigurationError("Conflicting requester rules for DecideIfRequesterIsAllowed," - "{} is both denied and allowed".format(conflicting_rules)) + raise SATOSAConfigurationError( + "Conflicting requester rules for DecideIfRequesterIsAllowed," + "{} is both denied and allowed".format(conflicting_rules) + ) # target entity id is base64 url encoded to make it usable in URLs, # so we convert the rules the use those encoded entity id's instead @@ -121,28 +124,36 @@ def process(self, context, data): target_specific_rules = self.rules.get(target_entity_id) # default to allowing everything if there are no specific rules if not target_specific_rules: - logger.debug("Requester '{}' allowed by default to target entity '{}' due to no entity specific rules".format( - data.requester, target_entity_id - )) + logger.debug( + "Requester '{}' allowed by default to target entity '{}' due to no entity specific rules".format( + data.requester, target_entity_id + ) + ) return super().process(context, data) # deny rules takes precedence deny_rules = target_specific_rules.get("deny", []) if data.requester in deny_rules: - logger.debug("Requester '{}' is not allowed by target entity '{}' due to deny rules '{}'".format( - data.requester, target_entity_id, deny_rules - )) + logger.debug( + "Requester '{}' is not allowed by target entity '{}' due to deny rules '{}'".format( + data.requester, target_entity_id, deny_rules + ) + ) raise SATOSAError("Requester is not allowed by target provider") allow_rules = target_specific_rules.get("allow", []) allow_all = "*" in allow_rules if data.requester in allow_rules or allow_all: - logger.debug("Requester '{}' allowed by target entity '{}' due to allow rules '{}".format( - data.requester, target_entity_id, allow_rules - )) + logger.debug( + "Requester '{}' allowed by target entity '{}' due to allow rules '{}".format( + data.requester, target_entity_id, allow_rules + ) + ) return super().process(context, data) - logger.debug("Requester '{}' is not allowed by target entity '{}' due to final deny all rule in '{}'".format( - data.requester, target_entity_id, deny_rules - )) + logger.debug( + "Requester '{}' is not allowed by target entity '{}' due to final deny all rule in '{}'".format( + data.requester, target_entity_id, deny_rules + ) + ) raise SATOSAError("Requester is not allowed by target provider") diff --git a/src/satosa/micro_services/disco.py b/src/satosa/micro_services/disco.py index 274f18780..9fbbde71b 100644 --- a/src/satosa/micro_services/disco.py +++ b/src/satosa/micro_services/disco.py @@ -1,8 +1,8 @@ from satosa.context import Context from satosa.internal import InternalData -from .base import RequestMicroService from ..exception import SATOSAError +from .base import RequestMicroService class DiscoToTargetIssuerError(SATOSAError): @@ -10,17 +10,17 @@ class DiscoToTargetIssuerError(SATOSAError): class DiscoToTargetIssuer(RequestMicroService): - def __init__(self, config:dict, *args, **kwargs): + def __init__(self, config: dict, *args, **kwargs): super().__init__(*args, **kwargs) - self.disco_endpoints = config['disco_endpoints'] + self.disco_endpoints = config["disco_endpoints"] if not isinstance(self.disco_endpoints, list) or not self.disco_endpoints: - raise DiscoToTargetIssuerError('disco_endpoints must be a list of str') + raise DiscoToTargetIssuerError("disco_endpoints must be a list of str") - def process(self, context:Context, data:InternalData): + def process(self, context: Context, data: InternalData): context.state[self.name] = { - 'target_frontend': context.target_frontend, - 'internal_data': data.to_dict(), + "target_frontend": context.target_frontend, + "internal_data": data.to_dict(), } return super().process(context, data) @@ -39,18 +39,15 @@ def register_endpoints(self): [(regexp, Callable[[satosa.context.Context], satosa.response.Response]), ...] """ - return [ - (path , self._handle_disco_response) - for path in self.disco_endpoints - ] + return [(path, self._handle_disco_response) for path in self.disco_endpoints] - def _handle_disco_response(self, context:Context): - target_issuer = context.request.get('entityID') + def _handle_disco_response(self, context: Context): + target_issuer = context.request.get("entityID") if not target_issuer: - raise DiscoToTargetIssuerError('no valid entity_id in the disco response') + raise DiscoToTargetIssuerError("no valid entity_id in the disco response") - target_frontend = context.state.get(self.name, {}).get('target_frontend') - data_serialized = context.state.get(self.name, {}).get('internal_data', {}) + target_frontend = context.state.get(self.name, {}).get("target_frontend") + data_serialized = context.state.get(self.name, {}).get("internal_data", {}) data = InternalData.from_dict(data_serialized) context.target_frontend = target_frontend diff --git a/src/satosa/micro_services/hasher.py b/src/satosa/micro_services/hasher.py index 111ef8a99..e38750204 100644 --- a/src/satosa/micro_services/hasher.py +++ b/src/satosa/micro_services/hasher.py @@ -1,7 +1,6 @@ import satosa.util as util from satosa.micro_services.base import ResponseMicroService - CONFIG_KEY_SALT = "salt" CONFIG_KEY_ALG = "alg" CONFIG_KEY_SUBJID = "subject_id" @@ -77,9 +76,7 @@ def _init_config(self, config): defaults.update(config.get("", {})) if not defaults.get(CONFIG_KEY_SALT, None): - raise Exception( - "Required config key missing: {}".format(CONFIG_KEY_SALT) - ) + raise Exception("Required config key missing: {}".format(CONFIG_KEY_SALT)) for requester, conf in config.items(): defs = defaults.copy() diff --git a/src/satosa/micro_services/idp_hinting.py b/src/satosa/micro_services/idp_hinting.py index 90569d706..a0101be72 100644 --- a/src/satosa/micro_services/idp_hinting.py +++ b/src/satosa/micro_services/idp_hinting.py @@ -1,9 +1,7 @@ import logging +from ..exception import SATOSAConfigurationError, SATOSAError from .base import RequestMicroService -from ..exception import SATOSAConfigurationError -from ..exception import SATOSAError - logger = logging.getLogger(__name__) @@ -12,6 +10,7 @@ class IdpHintingError(SATOSAError): """ SATOSA exception raised by IdpHinting microservice """ + pass @@ -28,7 +27,7 @@ def __init__(self, config, *args, **kwargs): """ super().__init__(*args, **kwargs) try: - self.idp_hint_param_names = config['allowed_params'] + self.idp_hint_param_names = config["allowed_params"] except KeyError: raise SATOSAConfigurationError( f"{self.__class__.__name__} can't find allowed_params" diff --git a/src/satosa/micro_services/ldap_attribute_store.py b/src/satosa/micro_services/ldap_attribute_store.py index 6d61559b1..c3772e82c 100644 --- a/src/satosa/micro_services/ldap_attribute_store.py +++ b/src/satosa/micro_services/ldap_attribute_store.py @@ -16,12 +16,11 @@ import satosa.logging_util as lu from satosa.exception import SATOSAError +from satosa.frontends.saml2 import SAMLVirtualCoFrontend from satosa.micro_services.base import ResponseMicroService from satosa.response import Redirect -from satosa.frontends.saml2 import SAMLVirtualCoFrontend from satosa.routing import STATE_KEY as ROUTING_STATE_KEY - logger = logging.getLogger(__name__) KEY_FOUND_LDAP_RECORD = "ldap_attribute_store_found_record" @@ -274,13 +273,13 @@ def _ldap_connection_factory(self, config): "LDIF": ldap3.LDIF, "RESTARTABLE": ldap3.RESTARTABLE, "REUSABLE": ldap3.REUSABLE, - "MOCK_SYNC": ldap3.MOCK_SYNC + "MOCK_SYNC": ldap3.MOCK_SYNC, } client_strategy = client_strategy_map[client_strategy_string] - args = {'host': config["ldap_url"]} + args = {"host": config["ldap_url"]} if client_strategy == ldap3.MOCK_SYNC: - args['get_info'] = ldap3.OFFLINE_SLAPD_2_4 + args["get_info"] = ldap3.OFFLINE_SLAPD_2_4 server = ldap3.Server(**args) @@ -307,7 +306,7 @@ def _ldap_connection_factory(self, config): pool_size = config["pool_size"] pool_keepalive = config["pool_keepalive"] - pool_name = ''.join(random.sample(string.ascii_lowercase, 6)) + pool_name = "".join(random.sample(string.ascii_lowercase, 6)) if client_strategy == ldap3.REUSABLE: msg = "Using pool size {}".format(pool_size) @@ -365,14 +364,11 @@ def _populate_attributes(self, config, record): for attr, values in ldap_attributes.items(): internal_attr = ldap_to_internal_map.get(attr, None) if not internal_attr and ";" in attr: - internal_attr = ldap_to_internal_map.get(attr.split(";")[0], - None) + internal_attr = ldap_to_internal_map.get(attr.split(";")[0], None) if internal_attr and values: attributes[internal_attr] = ( - values - if isinstance(values, list) - else [values] + values if isinstance(values, list) else [values] ) msg = "Recording internal attribute {} with values {}" logline = msg.format(internal_attr, attributes[internal_attr]) @@ -413,8 +409,9 @@ def process(self, context, data): entity_ids = [requester, issuer, co_entity_id, "default"] - config, entity_id = next((self.config.get(e), e) - for e in entity_ids if self.config.get(e)) + config, entity_id = next( + (self.config.get(e), e) for e in entity_ids if self.config.get(e) + ) msg = { "message": "entityID for the involved entities", diff --git a/src/satosa/micro_services/primary_identifier.py b/src/satosa/micro_services/primary_identifier.py index 1df2479eb..16d937e01 100644 --- a/src/satosa/micro_services/primary_identifier.py +++ b/src/satosa/micro_services/primary_identifier.py @@ -12,7 +12,6 @@ import satosa.micro_services.base from satosa.response import Redirect - logger = logging.getLogger(__name__) @@ -24,6 +23,7 @@ class PrimaryIdentifier(satosa.micro_services.base.ResponseMicroService): handle the error in a configured way that may be to ignore the error or redirect to an external error handling service. """ + logprefix = "PRIMARY_IDENTIFIER:" def __init__(self, config, *args, **kwargs): @@ -48,22 +48,30 @@ def constructPrimaryIdentifier(self, data, ordered_identifier_candidates): for candidate in ordered_identifier_candidates: msg = "{} Considering candidate {}".format(logprefix, candidate) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) # Get the values asserted by the IdP for the configured list of attribute names for this candidate # and substitute None if the IdP did not assert any value for a configured attribute. - values = [ attributes.get(attribute_name, [None])[0] for attribute_name in candidate['attribute_names'] if attribute_name != 'name_id' ] + values = [ + attributes.get(attribute_name, [None])[0] + for attribute_name in candidate["attribute_names"] + if attribute_name != "name_id" + ] msg = "{} Found candidate values {}".format(logprefix, values) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) # If one of the configured attribute names is name_id then if there is also a configured # name_id_format add the value for the NameID of that format if it was asserted by the IdP # or else add the value None. - if 'name_id' in candidate['attribute_names']: + if "name_id" in candidate["attribute_names"]: candidate_nameid_value = None - candidate_name_id_format = candidate.get('name_id_format') + candidate_name_id_format = candidate.get("name_id_format") name_id_value = data.subject_id name_id_format = data.subject_type if ( @@ -72,7 +80,9 @@ def constructPrimaryIdentifier(self, data, ordered_identifier_candidates): and candidate_name_id_format == name_id_format ): msg = "{} IdP asserted NameID {}".format(logprefix, name_id_value) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) candidate_nameid_value = name_id_value @@ -85,39 +95,47 @@ def constructPrimaryIdentifier(self, data, ordered_identifier_candidates): msg = "{} Added NameID {} to candidate values".format( logprefix, candidate_nameid_value ) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) values.append(candidate_nameid_value) else: msg = "{} NameID {} value also asserted as attribute value".format( logprefix, candidate_nameid_value ) - logline = logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.warn(logline) # If no value was asserted by the IdP for one of the configured list of attribute names # for this candidate then go onto the next candidate. if None in values: msg = "{} Candidate is missing value so skipping".format(logprefix) - logline = logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) continue # All values for the configured list of attribute names are present # so we can create a primary identifer. Add a scope if configured # to do so. - if 'add_scope' in candidate: - if candidate['add_scope'] == 'issuer_entityid': + if "add_scope" in candidate: + if candidate["add_scope"] == "issuer_entityid": scope = data.auth_info.issuer else: - scope = candidate['add_scope'] + scope = candidate["add_scope"] msg = "{} Added scope {} to values".format(logprefix, scope) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) values.append(scope) # Concatenate all values to create the primary identifier. - value = ''.join(values) + value = "".join(values) break return value @@ -136,10 +154,14 @@ def process(self, context, data): # Find the entityID for the SP that initiated the flow try: - spEntityID = context.state.state_dict['SATOSA_BASE']['requester'] + spEntityID = context.state.state_dict["SATOSA_BASE"]["requester"] except KeyError: - msg = "{} Unable to determine the entityID for the SP requester".format(logprefix) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + msg = "{} Unable to determine the entityID for the SP requester".format( + logprefix + ) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) return super().process(context, data) @@ -151,71 +173,89 @@ def process(self, context, data): try: idpEntityID = data.auth_info.issuer except KeyError: - msg = "{} Unable to determine the entityID for the IdP issuer".format(logprefix) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + msg = "{} Unable to determine the entityID for the IdP issuer".format( + logprefix + ) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) return super().process(context, data) # Examine our configuration to determine if there is a per-IdP configuration if idpEntityID in self.config: config = self.config[idpEntityID] - msg = "{} For IdP {} using configuration {}".format(logprefix, idpEntityID, config) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + msg = "{} For IdP {} using configuration {}".format( + logprefix, idpEntityID, config + ) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) # Examine our configuration to determine if there is a per-SP configuration. # An SP configuration overrides an IdP configuration when there is a conflict. if spEntityID in self.config: config = self.config[spEntityID] - msg = "{} For SP {} using configuration {}".format(logprefix, spEntityID, config) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + msg = "{} For SP {} using configuration {}".format( + logprefix, spEntityID, config + ) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) # Obtain configuration details from the per-SP configuration or the default configuration try: - if 'ordered_identifier_candidates' in config: - ordered_identifier_candidates = config['ordered_identifier_candidates'] + if "ordered_identifier_candidates" in config: + ordered_identifier_candidates = config["ordered_identifier_candidates"] else: - ordered_identifier_candidates = self.config['ordered_identifier_candidates'] - if 'primary_identifier' in config: - primary_identifier = config['primary_identifier'] - elif 'primary_identifier' in self.config: - primary_identifier = self.config['primary_identifier'] + ordered_identifier_candidates = self.config[ + "ordered_identifier_candidates" + ] + if "primary_identifier" in config: + primary_identifier = config["primary_identifier"] + elif "primary_identifier" in self.config: + primary_identifier = self.config["primary_identifier"] else: - primary_identifier = 'uid' - if 'clear_input_attributes' in config: - clear_input_attributes = config['clear_input_attributes'] - elif 'clear_input_attributes' in self.config: - clear_input_attributes = self.config['clear_input_attributes'] + primary_identifier = "uid" + if "clear_input_attributes" in config: + clear_input_attributes = config["clear_input_attributes"] + elif "clear_input_attributes" in self.config: + clear_input_attributes = self.config["clear_input_attributes"] else: clear_input_attributes = False - if 'replace_subject_id' in config: - replace_subject_id = config['replace_subject_id'] - elif 'replace_subject_id' in self.config: - replace_subject_id = self.config['replace_subject_id'] + if "replace_subject_id" in config: + replace_subject_id = config["replace_subject_id"] + elif "replace_subject_id" in self.config: + replace_subject_id = self.config["replace_subject_id"] else: replace_subject_id = False - if 'ignore' in config: + if "ignore" in config: ignore = True else: ignore = False - if 'on_error' in config: - on_error = config['on_error'] - elif 'on_error' in self.config: - on_error = self.config['on_error'] + if "on_error" in config: + on_error = config["on_error"] + elif "on_error" in self.config: + on_error = self.config["on_error"] else: on_error = None except KeyError as err: msg = "{} Configuration '{}' is missing".format(logprefix, err) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.error(logline) return super().process(context, data) # Ignore this SP entirely if so configured. if ignore: msg = "{} Ignoring SP {}".format(logprefix, spEntityID) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) return super().process(context, data) @@ -223,11 +263,15 @@ def process(self, context, data): msg = "{} Constructing primary identifier".format(logprefix) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.debug(logline) - primary_identifier_val = self.constructPrimaryIdentifier(data, ordered_identifier_candidates) + primary_identifier_val = self.constructPrimaryIdentifier( + data, ordered_identifier_candidates + ) if not primary_identifier_val: msg = "{} No primary identifier found".format(logprefix) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.warn(logline) if on_error: # Redirect to the configured error handling service with @@ -235,13 +279,19 @@ def process(self, context, data): # as query string parameters (URL encoded). encodedSpEntityID = urllib.parse.quote_plus(spEntityID) encodedIdpEntityID = urllib.parse.quote_plus(data.auth_info.issuer) - url = "{}?sp={}&idp={}".format(on_error, encodedSpEntityID, encodedIdpEntityID) + url = "{}?sp={}&idp={}".format( + on_error, encodedSpEntityID, encodedIdpEntityID + ) msg = "{} Redirecting to {}".format(logprefix, url) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.info(logline) return Redirect(url) - msg = "{} Found primary identifier: {}".format(logprefix, primary_identifier_val) + msg = "{} Found primary identifier: {}".format( + logprefix, primary_identifier_val + ) logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) logger.info(logline) @@ -250,7 +300,9 @@ def process(self, context, data): msg = "{} Clearing values for these input attributes: {}".format( logprefix, data.attributes.keys() ) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) data.attributes = {} @@ -260,7 +312,9 @@ def process(self, context, data): msg = "{} Setting attribute {} to value {}".format( logprefix, primary_identifier, primary_identifier_val ) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) # Replace subject_id with the constructed primary identifier if so configured. @@ -268,7 +322,9 @@ def process(self, context, data): msg = "{} Setting subject_id to value {}".format( logprefix, primary_identifier_val ) - logline = lu.LOG_FMT.format(id=lu.get_session_id(context.state), message=msg) + logline = lu.LOG_FMT.format( + id=lu.get_session_id(context.state), message=msg + ) logger.debug(logline) data.subject_id = primary_identifier_val diff --git a/src/satosa/micro_services/processors/gender_processor.py b/src/satosa/micro_services/processors/gender_processor.py index 282625c52..0dce28c91 100644 --- a/src/satosa/micro_services/processors/gender_processor.py +++ b/src/satosa/micro_services/processors/gender_processor.py @@ -1,13 +1,13 @@ -from .base_processor import BaseProcessor - from enum import Enum, unique +from .base_processor import BaseProcessor + @unique class Gender(Enum): - NOT_KNOWN = 0 - MALE = 1 - FEMALE = 2 + NOT_KNOWN = 0 + MALE = 1 + FEMALE = 2 NOT_SPECIFIED = 9 @@ -18,7 +18,8 @@ def process(self, internal_data, attribute, **kwargs): if value: representation = getattr( - Gender, value.upper().replace(' ', '_'), Gender.NOT_KNOWN) + Gender, value.upper().replace(" ", "_"), Gender.NOT_KNOWN + ) else: representation = Gender.NOT_SPECIFIED diff --git a/src/satosa/micro_services/processors/hash_processor.py b/src/satosa/micro_services/processors/hash_processor.py index 06ad8f928..5f28234d7 100644 --- a/src/satosa/micro_services/processors/hash_processor.py +++ b/src/satosa/micro_services/processors/hash_processor.py @@ -1,13 +1,12 @@ -from ..attribute_processor import AttributeProcessorError -from .base_processor import BaseProcessor - import hashlib +from ..attribute_processor import AttributeProcessorError +from .base_processor import BaseProcessor -CONFIG_KEY_SALT = 'salt' -CONFIG_DEFAULT_SALT = '' -CONFIG_KEY_HASHALGO = 'hash_algo' -CONFIG_DEFAULT_HASHALGO = 'sha256' +CONFIG_KEY_SALT = "salt" +CONFIG_DEFAULT_SALT = "" +CONFIG_KEY_HASHALGO = "hash_algo" +CONFIG_DEFAULT_HASHALGO = "sha256" class HashProcessor(BaseProcessor): @@ -16,16 +15,18 @@ def process(self, internal_data, attribute, **kwargs): hash_algo = kwargs.get(CONFIG_KEY_HASHALGO, CONFIG_DEFAULT_HASHALGO) if hash_algo not in hashlib.algorithms_available: raise AttributeProcessorError( - "Hash algorithm not supported: {}".format(hash_algo)) + "Hash algorithm not supported: {}".format(hash_algo) + ) attributes = internal_data.attributes value = attributes.get(attribute, [None])[0] if value is None: raise AttributeProcessorError( - "No value for attribute: {}".format(attribute)) + "No value for attribute: {}".format(attribute) + ) hasher = hashlib.new(hash_algo) - hasher.update(value.encode('utf-8')) - hasher.update(salt.encode('utf-8')) + hasher.update(value.encode("utf-8")) + hasher.update(salt.encode("utf-8")) value_hashed = hasher.hexdigest() attributes[attribute][0] = value_hashed diff --git a/src/satosa/micro_services/processors/regex_sub_processor.py b/src/satosa/micro_services/processors/regex_sub_processor.py index cb786966b..b2d8a452c 100644 --- a/src/satosa/micro_services/processors/regex_sub_processor.py +++ b/src/satosa/micro_services/processors/regex_sub_processor.py @@ -1,11 +1,14 @@ +import logging +import re + from ..attribute_processor import AttributeProcessorError, AttributeProcessorWarning from .base_processor import BaseProcessor -import re -import logging -CONFIG_KEY_MATCH_PATTERN = 'regex_sub_match_pattern' -CONFIG_KEY_REPLACE_PATTERN = 'regex_sub_replace_pattern' +CONFIG_KEY_MATCH_PATTERN = "regex_sub_match_pattern" +CONFIG_KEY_REPLACE_PATTERN = "regex_sub_replace_pattern" logger = logging.getLogger(__name__) + + class RegexSubProcessor(BaseProcessor): """ Performs a regex sub against an attribute value. @@ -24,20 +27,32 @@ class RegexSubProcessor(BaseProcessor): """ def process(self, internal_data, attribute, **kwargs): - regex_sub_match_pattern = r'{}'.format(kwargs.get(CONFIG_KEY_MATCH_PATTERN, '')) - if regex_sub_match_pattern == '': + regex_sub_match_pattern = r"{}".format(kwargs.get(CONFIG_KEY_MATCH_PATTERN, "")) + if regex_sub_match_pattern == "": raise AttributeProcessorError("The regex_sub_match_pattern needs to be set") - regex_sub_replace_pattern = r'{}'.format(kwargs.get(CONFIG_KEY_REPLACE_PATTERN, '')) - if regex_sub_replace_pattern == '': - raise AttributeProcessorError("The regex_sub_replace_pattern needs to be set") + regex_sub_replace_pattern = r"{}".format( + kwargs.get(CONFIG_KEY_REPLACE_PATTERN, "") + ) + if regex_sub_replace_pattern == "": + raise AttributeProcessorError( + "The regex_sub_replace_pattern needs to be set" + ) attributes = internal_data.attributes values = attributes.get(attribute, []) new_values = [] if not values: - raise AttributeProcessorWarning("Cannot apply regex_sub to {}, it has no values".format(attribute)) + raise AttributeProcessorWarning( + "Cannot apply regex_sub to {}, it has no values".format(attribute) + ) for value in values: - new_values.append(re.sub(r'{}'.format(regex_sub_match_pattern), r'{}'.format(regex_sub_replace_pattern), value)) - logger.debug('regex_sub new_values: {}'.format(new_values)) + new_values.append( + re.sub( + r"{}".format(regex_sub_match_pattern), + r"{}".format(regex_sub_replace_pattern), + value, + ) + ) + logger.debug("regex_sub new_values: {}".format(new_values)) attributes[attribute] = new_values diff --git a/src/satosa/micro_services/processors/scope_extractor_processor.py b/src/satosa/micro_services/processors/scope_extractor_processor.py index 863bc7740..2e64eae19 100644 --- a/src/satosa/micro_services/processors/scope_extractor_processor.py +++ b/src/satosa/micro_services/processors/scope_extractor_processor.py @@ -1,9 +1,8 @@ from ..attribute_processor import AttributeProcessorError, AttributeProcessorWarning from .base_processor import BaseProcessor - -CONFIG_KEY_MAPPEDATTRIBUTE = 'mapped_attribute' -CONFIG_DEFAULT_MAPPEDATTRIBUTE = '' +CONFIG_KEY_MAPPEDATTRIBUTE = "mapped_attribute" +CONFIG_DEFAULT_MAPPEDATTRIBUTE = "" class ScopeExtractorProcessor(BaseProcessor): @@ -22,21 +21,30 @@ class ScopeExtractorProcessor(BaseProcessor): module: satosa.micro_services.processors.scope_extractor_processor mapped_attribute: domain """ + def process(self, internal_data, attribute, **kwargs): - mapped_attribute = kwargs.get(CONFIG_KEY_MAPPEDATTRIBUTE, CONFIG_DEFAULT_MAPPEDATTRIBUTE) - if mapped_attribute is None or mapped_attribute == '': + mapped_attribute = kwargs.get( + CONFIG_KEY_MAPPEDATTRIBUTE, CONFIG_DEFAULT_MAPPEDATTRIBUTE + ) + if mapped_attribute is None or mapped_attribute == "": raise AttributeProcessorError("The mapped_attribute needs to be set") attributes = internal_data.attributes values = attributes.get(attribute, []) if not values: - raise AttributeProcessorWarning("Cannot apply scope_extractor to {}, it has no values".format(attribute)) + raise AttributeProcessorWarning( + "Cannot apply scope_extractor to {}, it has no values".format(attribute) + ) if not isinstance(values, list): values = [values] - if not any('@' in val for val in values): - raise AttributeProcessorWarning("Cannot apply scope_extractor to {}, it's values are not scoped".format(attribute)) + if not any("@" in val for val in values): + raise AttributeProcessorWarning( + "Cannot apply scope_extractor to {}, it's values are not scoped".format( + attribute + ) + ) for value in values: - if '@' in value: - scope = value.split('@')[1] + if "@" in value: + scope = value.split("@")[1] attributes[mapped_attribute] = [scope] break diff --git a/src/satosa/micro_services/processors/scope_processor.py b/src/satosa/micro_services/processors/scope_processor.py index 3173c5e6a..a2d5cf77f 100644 --- a/src/satosa/micro_services/processors/scope_processor.py +++ b/src/satosa/micro_services/processors/scope_processor.py @@ -1,15 +1,14 @@ from ..attribute_processor import AttributeProcessorError from .base_processor import BaseProcessor - -CONFIG_KEY_SCOPE = 'scope' -CONFIG_DEFAULT_SCOPE = '' +CONFIG_KEY_SCOPE = "scope" +CONFIG_DEFAULT_SCOPE = "" class ScopeProcessor(BaseProcessor): def process(self, internal_data, attribute, **kwargs): scope = kwargs.get(CONFIG_KEY_SCOPE, CONFIG_DEFAULT_SCOPE) - if scope is None or scope == '': + if scope is None or scope == "": raise AttributeProcessorError("No scope set.") attributes = internal_data.attributes @@ -17,4 +16,4 @@ def process(self, internal_data, attribute, **kwargs): if not isinstance(values, list): values = [values] if values: - attributes[attribute] = list(v + '@' + scope for v in values) + attributes[attribute] = list(v + "@" + scope for v in values) diff --git a/src/satosa/micro_services/processors/scope_remover_processor.py b/src/satosa/micro_services/processors/scope_remover_processor.py index 6cf878365..fc3570377 100644 --- a/src/satosa/micro_services/processors/scope_remover_processor.py +++ b/src/satosa/micro_services/processors/scope_remover_processor.py @@ -1,18 +1,22 @@ from ..attribute_processor import AttributeProcessorWarning from .base_processor import BaseProcessor + class ScopeRemoverProcessor(BaseProcessor): """ Removes the scope from all values of a given attribute """ + def process(self, internal_data, attribute, **kwargs): attributes = internal_data.attributes new_values = [] values = attributes.get(attribute, []) if not values: - raise AttributeProcessorWarning("Attribute {} has no values".format(attribute)) + raise AttributeProcessorWarning( + "Attribute {} has no values".format(attribute) + ) for value in values: - unscoped_value = value.split('@')[0] + unscoped_value = value.split("@")[0] new_values.append(unscoped_value) attributes[attribute] = new_values diff --git a/src/satosa/plugin_loader.py b/src/satosa/plugin_loader.py index b7eb4cf46..47857a6fb 100644 --- a/src/satosa/plugin_loader.py +++ b/src/satosa/plugin_loader.py @@ -7,13 +7,13 @@ from contextlib import contextmanager from pydoc import locate -from satosa.yaml import load as yaml_load from satosa.yaml import YAMLError +from satosa.yaml import load as yaml_load from .backends.base import BackendModule from .exception import SATOSAConfigurationError from .frontends.base import FrontendModule -from .micro_services.base import (MicroService, RequestMicroService, ResponseMicroService) +from .micro_services.base import MicroService, RequestMicroService, ResponseMicroService logger = logging.getLogger(__name__) @@ -21,10 +21,12 @@ @contextmanager def prepend_to_import_path(import_paths): import_paths = import_paths or [] - for p in reversed(import_paths): # insert the specified plugin paths in the same order + for p in reversed( + import_paths + ): # insert the specified plugin paths in the same order sys.path.insert(0, p) yield - del sys.path[0:len(import_paths)] # restore sys.path + del sys.path[0 : len(import_paths)] # restore sys.path def load_backends(config, callback, internal_attributes): @@ -44,9 +46,14 @@ def load_backends(config, callback, internal_attributes): backend_modules = _load_plugins( config.get("CUSTOM_PLUGIN_MODULE_PATHS"), config["BACKEND_MODULES"], - backend_filter, config["BASE"], - internal_attributes, callback) - logger.info("Setup backends: {}".format([backend.name for backend in backend_modules])) + backend_filter, + config["BASE"], + internal_attributes, + callback, + ) + logger.info( + "Setup backends: {}".format([backend.name for backend in backend_modules]) + ) return backend_modules @@ -65,9 +72,17 @@ def load_frontends(config, callback, internal_attributes): has been processed. :return: A list of frontend modules """ - frontend_modules = _load_plugins(config.get("CUSTOM_PLUGIN_MODULE_PATHS"), config["FRONTEND_MODULES"], - frontend_filter, config["BASE"], internal_attributes, callback) - logger.info("Setup frontends: {}".format([frontend.name for frontend in frontend_modules])) + frontend_modules = _load_plugins( + config.get("CUSTOM_PLUGIN_MODULE_PATHS"), + config["FRONTEND_MODULES"], + frontend_filter, + config["BASE"], + internal_attributes, + callback, + ) + logger.info( + "Setup frontends: {}".format([frontend.name for frontend in frontend_modules]) + ) return frontend_modules @@ -109,7 +124,11 @@ def _micro_service_filter(cls): :return: True if match, else false """ is_microservice_module = issubclass(cls, MicroService) - is_correct_subclass = cls != MicroService and cls != ResponseMicroService and cls != RequestMicroService + is_correct_subclass = ( + cls != MicroService + and cls != ResponseMicroService + and cls != RequestMicroService + ) return is_microservice_module and is_correct_subclass @@ -145,13 +164,19 @@ def _load_plugin_config(config): try: return yaml_load(config) except YAMLError as exc: - if hasattr(exc, 'problem_mark'): + if hasattr(exc, "problem_mark"): mark = exc.problem_mark - logger.error("Error position: ({line}:{column})".format(line=mark.line + 1, column=mark.column + 1)) + logger.error( + "Error position: ({line}:{column})".format( + line=mark.line + 1, column=mark.column + 1 + ) + ) raise SATOSAConfigurationError("The configuration is corrupt.") from exc -def _load_plugins(plugin_paths, plugins, plugin_filter, base_url, internal_attributes, callback): +def _load_plugins( + plugin_paths, plugins, plugin_filter, base_url, internal_attributes, callback +): """ Loads endpoint plugins @@ -173,13 +198,21 @@ def _load_plugins(plugin_paths, plugins, plugin_filter, base_url, internal_attri try: module_class = _load_endpoint_module(plugin_config, plugin_filter) except SATOSAConfigurationError as e: - raise SATOSAConfigurationError("Configuration error in {}".format(json.dumps(plugin_config))) from e + raise SATOSAConfigurationError( + "Configuration error in {}".format(json.dumps(plugin_config)) + ) from e if module_class: - module_config = _replace_variables_in_plugin_module_config(plugin_config["config"], base_url, - plugin_config["name"]) - instance = module_class(callback, internal_attributes, module_config, base_url, - plugin_config["name"]) + module_config = _replace_variables_in_plugin_module_config( + plugin_config["config"], base_url, plugin_config["name"] + ) + instance = module_class( + callback, + internal_attributes, + module_config, + base_url, + plugin_config["name"], + ) loaded_plugin_modules.append(instance) return loaded_plugin_modules @@ -188,7 +221,10 @@ def _load_endpoint_module(plugin_config, plugin_filter): _mandatory_params = ("name", "module", "config") if not all(k in plugin_config for k in _mandatory_params): raise SATOSAConfigurationError( - "Missing mandatory plugin configuration parameter: {}".format(_mandatory_params)) + "Missing mandatory plugin configuration parameter: {}".format( + _mandatory_params + ) + ) return _load_plugin_module(plugin_config, plugin_filter) @@ -207,23 +243,34 @@ def _load_microservice(plugin_config, plugin_filter): _mandatory_params = ("name", "module") if not all(k in plugin_config for k in _mandatory_params): raise SATOSAConfigurationError( - "Missing mandatory plugin configuration parameter: {}".format(_mandatory_params)) + "Missing mandatory plugin configuration parameter: {}".format( + _mandatory_params + ) + ) return _load_plugin_module(plugin_config, plugin_filter) -def _load_microservices(plugin_paths, plugins, plugin_filter, internal_attributes, base_url): +def _load_microservices( + plugin_paths, plugins, plugin_filter, internal_attributes, base_url +): loaded_plugin_modules = [] with prepend_to_import_path(plugin_paths): for plugin_config in plugins: try: module_class = _load_microservice(plugin_config, plugin_filter) except SATOSAConfigurationError as e: - raise SATOSAConfigurationError("Configuration error in {}".format(json.dumps(plugin_config))) from e + raise SATOSAConfigurationError( + "Configuration error in {}".format(json.dumps(plugin_config)) + ) from e if module_class: - instance = module_class(internal_attributes=internal_attributes, config=plugin_config.get("config"), - name=plugin_config["name"], base_url=base_url) + instance = module_class( + internal_attributes=internal_attributes, + config=plugin_config.get("config"), + name=plugin_config["name"], + base_url=base_url, + ) loaded_plugin_modules.append(instance) return loaded_plugin_modules @@ -231,10 +278,7 @@ def _load_microservices(plugin_paths, plugins, plugin_filter, internal_attribute def _replace_variables_in_plugin_module_config(module_config, base_url, name): config = json.dumps(module_config) - replace = [ - ("", base_url), - ("", name) - ] + replace = [("", base_url), ("", name)] for _replace in replace: config = config.replace(_replace[0], _replace[1]) return json.loads(config) @@ -255,9 +299,18 @@ def load_request_microservices(plugin_path, plugins, internal_attributes, base_u :param: base_url: base url of the SATOSA server :return: Request micro service """ - request_services = _load_microservices(plugin_path, plugins, _request_micro_service_filter, internal_attributes, - base_url) - logger.info("Loaded request micro services: {}".format([type(k).__name__ for k in request_services])) + request_services = _load_microservices( + plugin_path, + plugins, + _request_micro_service_filter, + internal_attributes, + base_url, + ) + logger.info( + "Loaded request micro services: {}".format( + [type(k).__name__ for k in request_services] + ) + ) return request_services @@ -276,7 +329,16 @@ def load_response_microservices(plugin_path, plugins, internal_attributes, base_ :param: base_url: base url of the SATOSA server :return: Response micro service """ - response_services = _load_microservices(plugin_path, plugins, _response_micro_service_filter, internal_attributes, - base_url) - logger.info("Loaded response micro services:{}".format([type(k).__name__ for k in response_services])) + response_services = _load_microservices( + plugin_path, + plugins, + _response_micro_service_filter, + internal_attributes, + base_url, + ) + logger.info( + "Loaded response micro services:{}".format( + [type(k).__name__ for k in response_services] + ) + ) return response_services diff --git a/src/satosa/proxy_server.py b/src/satosa/proxy_server.py index e23be1418..4f312b9db 100644 --- a/src/satosa/proxy_server.py +++ b/src/satosa/proxy_server.py @@ -10,9 +10,7 @@ from .base import SATOSABase from .context import Context -from .response import ServiceError -from .response import NotFound - +from .response import NotFound, ServiceError logger = logging.getLogger(__name__) @@ -21,7 +19,9 @@ def parse_query_string(data): query_param_pairs = _parse_query_string(data) query_param_dict = dict(query_param_pairs) if "resource" in query_param_dict: - query_param_dict["resource"] = [t[1] for t in query_param_pairs if t[0] == "resource"] + query_param_dict["resource"] = [ + t[1] for t in query_param_pairs if t[0] == "resource" + ] return query_param_dict @@ -40,7 +40,7 @@ def unpack_post(environ, content_length): :param environ: whiskey application environment. :return: A dictionary with parameters. """ - post_body = environ['wsgi.input'].read(content_length).decode("utf-8") + post_body = environ["wsgi.input"].read(content_length).decode("utf-8") data = None if "application/x-www-form-urlencoded" in environ["CONTENT_TYPE"]: data = parse_query_string(post_body) @@ -82,10 +82,7 @@ def collect_http_headers(environ): headers = { header_name: header_value for header_name, header_value in environ.items() - if ( - header_name.startswith("HTTP_") - or header_name.startswith("REMOTE_") - ) + if (header_name.startswith("HTTP_") or header_name.startswith("REMOTE_")) } return headers @@ -119,7 +116,7 @@ def __init__(self, config): super().__init__(config) def __call__(self, environ, start_response, debug=False): - path = environ.get('PATH_INFO', '').lstrip('/') + path = environ.get("PATH_INFO", "").lstrip("/") if ".." in path or path == "": resp = NotFound("Couldn't find the page you asked for!") return resp(environ, start_response) @@ -129,9 +126,9 @@ def __call__(self, environ, start_response, debug=False): # copy wsgi.input stream to allow it to be re-read later by satosa plugins # see: http://stackoverflow.com/questions/1783383/how-do-i-copy-wsgi-input-if-i-want-to-process-post-data-more-than-once - content_length = int(environ.get('CONTENT_LENGTH', '0') or '0') - body = BytesIO(environ['wsgi.input'].read(content_length)) - environ['wsgi.input'] = body + content_length = int(environ.get("CONTENT_LENGTH", "0") or "0") + body = BytesIO(environ["wsgi.input"].read(content_length)) + environ["wsgi.input"] = body context.request = unpack_request(environ, content_length) context.request_uri = environ.get("REQUEST_URI") @@ -140,9 +137,11 @@ def __call__(self, environ, start_response, debug=False): context.server = collect_server_headers(environ) context.http_headers = collect_http_headers(environ) context.cookie = context.http_headers.get("HTTP_COOKIE", "") - context.request_authorization = context.http_headers.get("HTTP_AUTHORIZATION", "") + context.request_authorization = context.http_headers.get( + "HTTP_AUTHORIZATION", "" + ) - environ['wsgi.input'].seek(0) + environ["wsgi.input"].seek(0) logline = { "message": "Proxy server received request", diff --git a/src/satosa/response.py b/src/satosa/response.py index b672b0e40..7d1910d7e 100644 --- a/src/satosa/response.py +++ b/src/satosa/response.py @@ -7,6 +7,7 @@ class Response(object): """ A response object """ + # _template = None _status = "200 OK" _content_type = "text/html" @@ -29,7 +30,9 @@ def __init__(self, message=None, status=None, headers=None, content=None): self.headers = headers if headers is not None else [] self.message = message - should_add_content_type = not any(header[0].lower() == "content-type" for header in self.headers) + should_add_content_type = not any( + header[0].lower() == "content-type" for header in self.headers + ) if should_add_content_type: self.headers.append(("Content-Type", _content_type)) @@ -53,6 +56,7 @@ class Redirect(Response): """ A Redirect response """ + _status = "302 Found" def __init__(self, redirect_url, headers=None, content=None): @@ -75,6 +79,7 @@ class SeeOther(Redirect): """ A SeeOther response """ + _status = "303 See Other" def __init__(self, redirect_url, headers=None, content=None): diff --git a/src/satosa/routing.py b/src/satosa/routing.py index 015cffb23..9b43c0c31 100644 --- a/src/satosa/routing.py +++ b/src/satosa/routing.py @@ -4,11 +4,8 @@ import logging import re -from satosa.exception import SATOSABadContextError -from satosa.exception import SATOSANoBoundEndpointError - import satosa.logging_util as lu - +from satosa.exception import SATOSABadContextError, SATOSANoBoundEndpointError logger = logging.getLogger(__name__) @@ -42,15 +39,29 @@ def __init__(self, frontends, backends, micro_services): raise ValueError("Need at least one frontend and one backend") backend_names = [backend.name for backend in backends] - self.frontends = {instance.name: {"instance": instance, - "endpoints": instance.register_endpoints(backend_names)} - for instance in frontends} - self.backends = {instance.name: {"instance": instance, "endpoints": instance.register_endpoints()} - for instance in backends} + self.frontends = { + instance.name: { + "instance": instance, + "endpoints": instance.register_endpoints(backend_names), + } + for instance in frontends + } + self.backends = { + instance.name: { + "instance": instance, + "endpoints": instance.register_endpoints(), + } + for instance in backends + } if micro_services: - self.micro_services = {instance.name: {"instance": instance, "endpoints": instance.register_endpoints()} - for instance in micro_services} + self.micro_services = { + instance.name: { + "instance": instance, + "endpoints": instance.register_endpoints(), + } + for instance in micro_services + } else: self.micro_services = {} @@ -110,7 +121,9 @@ def _find_registered_endpoint_for_module(self, module, context): return None def _find_registered_backend_endpoint(self, context): - return self._find_registered_endpoint_for_module(self.backends[context.target_backend], context) + return self._find_registered_endpoint_for_module( + self.backends[context.target_backend], context + ) def _find_registered_endpoint(self, context, modules): for module in modules.values(): @@ -154,7 +167,9 @@ def endpoint_routing(self, context): logger.debug(logline) try: - name, frontend_endpoint = self._find_registered_endpoint(context, self.frontends) + name, frontend_endpoint = self._find_registered_endpoint( + context, self.frontends + ) except ModuleRouter.UnknownEndpoint: pass else: @@ -162,7 +177,9 @@ def endpoint_routing(self, context): return frontend_endpoint try: - name, micro_service_endpoint = self._find_registered_endpoint(context, self.micro_services) + name, micro_service_endpoint = self._find_registered_endpoint( + context, self.micro_services + ) except ModuleRouter.UnknownEndpoint: pass else: @@ -174,4 +191,6 @@ def endpoint_routing(self, context): if backend_endpoint: return backend_endpoint - raise SATOSANoBoundEndpointError("'{}' not bound to any function".format(context.path)) + raise SATOSANoBoundEndpointError( + "'{}' not bound to any function".format(context.path) + ) diff --git a/src/satosa/saml_util.py b/src/satosa/saml_util.py index fced07568..569fb78ac 100644 --- a/src/satosa/saml_util.py +++ b/src/satosa/saml_util.py @@ -1,6 +1,6 @@ from saml2 import BINDING_HTTP_REDIRECT -from .response import SeeOther, Response +from .response import Response, SeeOther def make_saml_response(binding, http_args): diff --git a/src/satosa/satosa_config.py b/src/satosa/satosa_config.py index d45280c41..87f356c8d 100644 --- a/src/satosa/satosa_config.py +++ b/src/satosa/satosa_config.py @@ -6,9 +6,8 @@ import os.path from satosa.exception import SATOSAConfigurationError -from satosa.yaml import load as yaml_load from satosa.yaml import YAMLError - +from satosa.yaml import load as yaml_load logger = logging.getLogger(__name__) @@ -18,9 +17,15 @@ class SATOSAConfig(object): A configuration class for the satosa proxy. Verifies that the given config holds all the necessary parameters. """ + sensitive_dict_keys = ["STATE_ENCRYPTION_KEY"] - mandatory_dict_keys = ["BASE", "BACKEND_MODULES", "FRONTEND_MODULES", - "INTERNAL_ATTRIBUTES", "COOKIE_STATE_NAME"] + mandatory_dict_keys = [ + "BASE", + "BACKEND_MODULES", + "FRONTEND_MODULES", + "INTERNAL_ATTRIBUTES", + "COOKIE_STATE_NAME", + ] def __init__(self, config): """ @@ -56,7 +61,9 @@ def __init__(self, config): plugin_configs.append(plugin_config) break else: - raise SATOSAConfigurationError(f"Failed to load plugin config '{config}'") + raise SATOSAConfigurationError( + f"Failed to load plugin config '{config}'" + ) self._config[key] = plugin_configs for parser in parsers: @@ -65,7 +72,9 @@ def __init__(self, config): self._config["INTERNAL_ATTRIBUTES"] = _internal_attributes break if not self._config["INTERNAL_ATTRIBUTES"]: - raise SATOSAConfigurationError("Could not load attribute mapping from 'INTERNAL_ATTRIBUTES.") + raise SATOSAConfigurationError( + "Could not load attribute mapping from 'INTERNAL_ATTRIBUTES." + ) def _verify_dict(self, conf): """ @@ -87,7 +96,9 @@ def _verify_dict(self, conf): for key in SATOSAConfig.sensitive_dict_keys: if key not in conf and f"SATOSA_{key}" not in os.environ: - raise SATOSAConfigurationError(f"Missing key '{key}' from config and ENVIRONMENT") + raise SATOSAConfigurationError( + f"Missing key '{key}' from config and ENVIRONMENT" + ) def __getitem__(self, item): """ @@ -151,9 +162,13 @@ def _load_yaml(self, config_file): return yaml_load(f.read()) except YAMLError as exc: logger.error("Could not parse config as YAML: {}".format(exc)) - if hasattr(exc, 'problem_mark'): + if hasattr(exc, "problem_mark"): mark = exc.problem_mark - logger.error("Error position: ({line}:{column})".format(line=mark.line + 1, column=mark.column + 1)) + logger.error( + "Error position: ({line}:{column})".format( + line=mark.line + 1, column=mark.column + 1 + ) + ) except IOError as e: logger.error("Could not open config file: {}".format(e)) diff --git a/src/satosa/scripts/satosa_saml_metadata.py b/src/satosa/scripts/satosa_saml_metadata.py index c0638d8b7..65c21023a 100644 --- a/src/satosa/scripts/satosa_saml_metadata.py +++ b/src/satosa/scripts/satosa_saml_metadata.py @@ -4,9 +4,11 @@ from saml2.config import Config from saml2.sigver import security_context -from ..metadata_creation.saml_metadata import create_entity_descriptors -from ..metadata_creation.saml_metadata import create_entity_descriptor_metadata -from ..metadata_creation.saml_metadata import create_signed_entity_descriptor +from ..metadata_creation.saml_metadata import ( + create_entity_descriptor_metadata, + create_entity_descriptors, + create_signed_entity_descriptor, +) from ..satosa_config import SATOSAConfig @@ -45,8 +47,16 @@ def _create_merged_entities_descriptors(entities, secc, valid, name, sign=True): return output -def create_and_write_saml_metadata(proxy_conf, key, cert, dir, valid, split_frontend_metadata=False, - split_backend_metadata=False, sign=True): +def create_and_write_saml_metadata( + proxy_conf, + key, + cert, + dir, + valid, + split_frontend_metadata=False, + split_backend_metadata=False, + sign=True, +): """ Generates SAML metadata for the given PROXY_CONF, signed with the given KEY and associated CERT. """ @@ -61,14 +71,26 @@ def create_and_write_saml_metadata(proxy_conf, key, cert, dir, valid, split_fron output = [] if frontend_entities: if split_frontend_metadata: - output.extend(_create_split_entity_descriptors(frontend_entities, secc, valid, sign)) + output.extend( + _create_split_entity_descriptors(frontend_entities, secc, valid, sign) + ) else: - output.extend(_create_merged_entities_descriptors(frontend_entities, secc, valid, "frontend.xml", sign)) + output.extend( + _create_merged_entities_descriptors( + frontend_entities, secc, valid, "frontend.xml", sign + ) + ) if backend_entities: if split_backend_metadata: - output.extend(_create_split_entity_descriptors(backend_entities, secc, valid, sign)) + output.extend( + _create_split_entity_descriptors(backend_entities, secc, valid, sign) + ) else: - output.extend(_create_merged_entities_descriptors(backend_entities, secc, valid, "backend.xml", sign)) + output.extend( + _create_merged_entities_descriptors( + backend_entities, secc, valid, "backend.xml", sign + ) + ) for metadata, filename in output: path = os.path.join(dir, filename) @@ -81,16 +103,49 @@ def create_and_write_saml_metadata(proxy_conf, key, cert, dir, valid, split_fron @click.argument("proxy_conf") @click.argument("key", required=False) @click.argument("cert", required=False) -@click.option("--dir", - type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True, readable=False, - resolve_path=False), - default=".", help="Where the output files should be written.") -@click.option("--valid", type=click.INT, default=None, help="Number of hours the metadata should be valid.") -@click.option("--split-frontend", is_flag=True, type=click.BOOL, default=False, - help="Create one entity descriptor per file for the frontend metadata") -@click.option("--split-backend", is_flag=True, type=click.BOOL, default=False, - help="Create one entity descriptor per file for the backend metadata") -@click.option("--sign/--no-sign", is_flag=True, type=click.BOOL, default=True, - help="Sign the generated metadata") -def construct_saml_metadata(proxy_conf, key, cert, dir, valid, split_frontend, split_backend, sign): - create_and_write_saml_metadata(proxy_conf, key, cert, dir, valid, split_frontend, split_backend, sign) +@click.option( + "--dir", + type=click.Path( + exists=True, + file_okay=False, + dir_okay=True, + writable=True, + readable=False, + resolve_path=False, + ), + default=".", + help="Where the output files should be written.", +) +@click.option( + "--valid", + type=click.INT, + default=None, + help="Number of hours the metadata should be valid.", +) +@click.option( + "--split-frontend", + is_flag=True, + type=click.BOOL, + default=False, + help="Create one entity descriptor per file for the frontend metadata", +) +@click.option( + "--split-backend", + is_flag=True, + type=click.BOOL, + default=False, + help="Create one entity descriptor per file for the backend metadata", +) +@click.option( + "--sign/--no-sign", + is_flag=True, + type=click.BOOL, + default=True, + help="Sign the generated metadata", +) +def construct_saml_metadata( + proxy_conf, key, cert, dir, valid, split_frontend, split_backend, sign +): + create_and_write_saml_metadata( + proxy_conf, key, cert, dir, valid, split_frontend, split_backend, sign + ) diff --git a/src/satosa/state.py b/src/satosa/state.py index 1fc768425..c9e59556f 100644 --- a/src/satosa/state.py +++ b/src/satosa/state.py @@ -8,18 +8,16 @@ import json import logging from collections import UserDict -from satosa.cookies import SimpleCookie -from uuid import uuid4 - from lzma import LZMACompressor, LZMADecompressor +from uuid import uuid4 from Cryptodome import Random from Cryptodome.Cipher import AES import satosa.logging_util as lu +from satosa.cookies import SimpleCookie from satosa.exception import SATOSAStateError - logger = logging.getLogger(__name__) _SESSION_ID_KEY = "SESSION_ID" @@ -49,7 +47,9 @@ def __init__(self, urlstate_data=None, encryption_key=None): urlstate_data = {} if urlstate_data is None else urlstate_data if urlstate_data and not encryption_key: - raise ValueError("If an 'urlstate_data' is supplied 'encrypt_key' must be specified.") + raise ValueError( + "If an 'urlstate_data' is supplied 'encrypt_key' must be specified." + ) if urlstate_data: try: @@ -61,7 +61,9 @@ def __init__(self, urlstate_data=None, encryption_key=None): urlstate_data_decompressed ) lzma = LZMADecompressor() - urlstate_data_decrypted_decompressed = lzma.decompress(urlstate_data_decrypted) + urlstate_data_decrypted_decompressed = lzma.decompress( + urlstate_data_decrypted + ) urlstate_data_obj = json.loads(urlstate_data_decrypted_decompressed) except Exception as e: error_context = { @@ -162,11 +164,7 @@ def state_to_cookie( cookie[name]["httponly"] = httponly if httponly is not None else "" cookie[name]["samesite"] = samesite if samesite is not None else "None" cookie[name]["max-age"] = ( - 0 - if state.delete - else max_age - if max_age is not None - else "" + 0 if state.delete else max_age if max_age is not None else "" ) msg = "Saved state in cookie {name} with properties {props}".format( @@ -196,10 +194,10 @@ def cookie_to_state(cookie_str: str, name: str, encryption_key: str) -> State: cookie = SimpleCookie(cookie_str) state = State(cookie[name].value, encryption_key) except KeyError as e: - msg = f'No cookie named {name} in {cookie_str}' + msg = f"No cookie named {name} in {cookie_str}" raise SATOSAStateError(msg) from e except ValueError as e: - msg_tmpl = 'Failed to process {name} from {data}' + msg_tmpl = "Failed to process {name} from {data}" msg = msg_tmpl.format(name=name, data=cookie_str) raise SATOSAStateError(msg) from e else: @@ -251,9 +249,9 @@ def decrypt(self, enc): :return: The decrypted value. """ enc = base64.urlsafe_b64decode(enc) - iv = enc[:AES.block_size] + iv = enc[: AES.block_size] cipher = AES.new(self.key, AES.MODE_CBC, iv) - return self._unpad(cipher.decrypt(enc[AES.block_size:])) + return self._unpad(cipher.decrypt(enc[AES.block_size :])) def _pad(self, b): """ @@ -262,7 +260,9 @@ def _pad(self, b): :type b: bytes :rtype: bytes """ - return b + (self.bs - len(b) % self.bs) * chr(self.bs - len(b) % self.bs).encode("UTF-8") + return b + (self.bs - len(b) % self.bs) * chr( + self.bs - len(b) % self.bs + ).encode("UTF-8") @staticmethod def _unpad(b): @@ -272,4 +272,4 @@ def _unpad(b): :type b: bytes :rtype: bytes """ - return b[:-ord(b[len(b) - 1:])] + return b[: -ord(b[len(b) - 1 :])] diff --git a/src/satosa/util.py b/src/satosa/util.py index 9b5d63fc1..f9e15d77f 100644 --- a/src/satosa/util.py +++ b/src/satosa/util.py @@ -6,7 +6,6 @@ import random import string - logger = logging.getLogger(__name__) @@ -22,17 +21,17 @@ def hash_data(salt, value, hash_alg=None): :param value: value to hash together with the salt :return: hashed value """ - hash_alg = hash_alg or 'sha512' + hash_alg = hash_alg or "sha512" hasher = hashlib.new(hash_alg) - hasher.update(value.encode('utf-8')) - hasher.update(salt.encode('utf-8')) + hasher.update(value.encode("utf-8")) + hasher.update(salt.encode("utf-8")) value_hashed = hasher.hexdigest() return value_hashed def check_set_dict_defaults(dic, spec): for path, value in spec.items(): - keys = path.split('.') + keys = path.split(".") try: _val = dict_get_nested(dic, keys) except KeyError: diff --git a/src/satosa/version.py b/src/satosa/version.py index cac85faf0..1140a8eb4 100644 --- a/src/satosa/version.py +++ b/src/satosa/version.py @@ -1,7 +1,9 @@ try: from importlib.metadata import version as _resolve_package_version except ImportError: - from importlib_metadata import version as _resolve_package_version # type: ignore[no-redef] + from importlib_metadata import ( + version as _resolve_package_version, # type: ignore[no-redef] + ) def _parse_version(): diff --git a/src/satosa/wsgi.py b/src/satosa/wsgi.py index 86220eb06..a4d117820 100644 --- a/src/satosa/wsgi.py +++ b/src/satosa/wsgi.py @@ -27,9 +27,7 @@ def main(): sys.exit(1) ssl_context = ( - (args.certfile, args.keyfile) - if args.keyfile and args.certfile - else None + (args.certfile, args.keyfile) if args.keyfile and args.certfile else None ) host = args.host or "localhost" run_simple(host, args.port, app, ssl_context=ssl_context) diff --git a/tests/conftest.py b/tests/conftest.py index f0602a028..c8492f783 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,14 +2,18 @@ import os import pytest -from saml2 import BINDING_HTTP_REDIRECT, BINDING_HTTP_POST +from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT from saml2.extension.idpdisc import BINDING_DISCO -from saml2.saml import NAME_FORMAT_URI, NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_PERSISTENT +from saml2.saml import ( + NAME_FORMAT_URI, + NAMEID_FORMAT_PERSISTENT, + NAMEID_FORMAT_TRANSIENT, +) from satosa.context import Context from satosa.state import State -from .util import create_metadata_from_config_dict -from .util import generate_cert, write_cert + +from .util import create_metadata_from_config_dict, generate_cert, write_cert BASE_URL = "https://test-proxy.com" @@ -47,11 +51,11 @@ def sp_conf(cert_and_key): "assertion_consumer_service": [ ("%s/acs/redirect" % sp_base, BINDING_HTTP_REDIRECT) ], - "discovery_response": [("%s/disco" % sp_base, BINDING_DISCO)] + "discovery_response": [("%s/disco" % sp_base, BINDING_DISCO)], }, "want_response_signed": False, "allow_unsolicited": True, - "name_id_format": [NAMEID_FORMAT_PERSISTENT] + "name_id_format": [NAMEID_FORMAT_PERSISTENT], }, }, "cert_file": cert_and_key[0], @@ -82,18 +86,25 @@ def idp_conf(cert_and_key): "lifetime": {"minutes": 15}, "attribute_restrictions": None, # means all I have "name_form": NAME_FORMAT_URI, - "fail_on_missing_requested": False + "fail_on_missing_requested": False, }, }, "subject_data": {}, - "name_id_format": [NAMEID_FORMAT_TRANSIENT, - NAMEID_FORMAT_PERSISTENT], + "name_id_format": [NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_PERSISTENT], "want_authn_requests_signed": False, "ui_info": { "display_name": [{"text": "SATOSA Test IdP", "lang": "en"}], - "description": [{"text": "Test IdP for SATOSA unit tests.", "lang": "en"}], - "logo": [{"text": "https://idp.example.com/static/logo.png", "width": "120", "height": "60", - "lang": "en"}], + "description": [ + {"text": "Test IdP for SATOSA unit tests.", "lang": "en"} + ], + "logo": [ + { + "text": "https://idp.example.com/static/logo.png", + "width": "120", + "height": "60", + "lang": "en", + } + ], }, }, }, @@ -103,15 +114,22 @@ def idp_conf(cert_and_key): "organization": { "name": [["Test IdP Org.", "en"]], "display_name": [["Test IdP", "en"]], - "url": [["https://idp.example.com/about", "en"]] + "url": [["https://idp.example.com/about", "en"]], }, "contact_person": [ - {"given_name": "Test IdP", "sur_name": "Support", "email_address": ["help@idp.example.com"], - "contact_type": "support" - }, - {"given_name": "Test IdP", "sur_name": "Tech support", - "email_address": ["tech@idp.example.com"], "contact_type": "technical"} - ] + { + "given_name": "Test IdP", + "sur_name": "Support", + "email_address": ["help@idp.example.com"], + "contact_type": "support", + }, + { + "given_name": "Test IdP", + "sur_name": "Tech support", + "email_address": ["tech@idp.example.com"], + "contact_type": "technical", + }, + ], } return idpconfig @@ -125,8 +143,12 @@ def context(): @pytest.fixture -def satosa_config_dict(backend_plugin_config, frontend_plugin_config, request_microservice_config, - response_microservice_config): +def satosa_config_dict( + backend_plugin_config, + frontend_plugin_config, + request_microservice_config, + response_microservice_config, +): config = { "BASE": BASE_URL, "COOKIE_STATE_NAME": "TEST_STATE", @@ -136,28 +158,20 @@ def satosa_config_dict(backend_plugin_config, frontend_plugin_config, request_mi "BACKEND_MODULES": [backend_plugin_config], "FRONTEND_MODULES": [frontend_plugin_config], "MICRO_SERVICES": [request_microservice_config, response_microservice_config], - "LOGGING": {"version": 1} + "LOGGING": {"version": 1}, } return config @pytest.fixture def backend_plugin_config(): - data = { - "module": "util.TestBackend", - "name": "backend", - "config": {"foo": "bar"} - } + data = {"module": "util.TestBackend", "name": "backend", "config": {"foo": "bar"}} return data @pytest.fixture def frontend_plugin_config(): - data = { - "module": "util.TestFrontend", - "name": "frontend", - "config": {"abc": "xyz"} - } + data = {"module": "util.TestFrontend", "name": "frontend", "config": {"abc": "xyz"}} return data @@ -175,7 +189,7 @@ def response_microservice_config(): data = { "module": "util.TestResponseMicroservice", "name": "response-microservice", - "config": {"qwe": "rty"} + "config": {"qwe": "rty"}, } return data @@ -190,9 +204,7 @@ def saml_frontend_config(cert_and_key, sp_conf): "entityid": "frontend-entity_id", "service": { "idp": { - "endpoints": { - "single_sign_on_service": [] - }, + "endpoints": {"single_sign_on_service": []}, "name": "Frontend IdP", "name_id_format": NAMEID_FORMAT_TRANSIENT, "policy": { @@ -200,9 +212,9 @@ def saml_frontend_config(cert_and_key, sp_conf): "attribute_restrictions": None, "fail_on_missing_requested": False, "lifetime": {"minutes": 15}, - "name_form": NAME_FORMAT_URI + "name_form": NAME_FORMAT_URI, } - } + }, } }, "cert_file": cert_and_key[0], @@ -211,23 +223,30 @@ def saml_frontend_config(cert_and_key, sp_conf): "organization": { "name": [["SATOSA Org.", "en"]], "display_name": [["SATOSA", "en"]], - "url": [["https://satosa.example.com/about", "en"]] + "url": [["https://satosa.example.com/about", "en"]], }, "contact_person": [ - {"given_name": "SATOSA", "sur_name": "Support", "email_address": ["help@satosa.example.com"], - "contact_type": "support" - }, - {"given_name": "SATOSA", "sur_name": "Tech Support", "email_address": ["tech@satosa.example.com"], - "contact_type": "technical" - } - ] + { + "given_name": "SATOSA", + "sur_name": "Support", + "email_address": ["help@satosa.example.com"], + "contact_type": "support", + }, + { + "given_name": "SATOSA", + "sur_name": "Tech Support", + "email_address": ["tech@satosa.example.com"], + "contact_type": "technical", + }, + ], }, - "endpoints": { - "single_sign_on_service": {BINDING_HTTP_POST: "sso/post", - BINDING_HTTP_REDIRECT: "sso/redirect"} - } - } + "single_sign_on_service": { + BINDING_HTTP_POST: "sso/post", + BINDING_HTTP_REDIRECT: "sso/redirect", + } + }, + }, } return data @@ -242,12 +261,22 @@ def saml_backend_config(idp_conf): "config": { "sp_config": { "entityid": "backend-entity_id", - "organization": {"display_name": "Example Identities", "name": "Test Identities Org.", - "url": "http://www.example.com"}, + "organization": { + "display_name": "Example Identities", + "name": "Test Identities Org.", + "url": "http://www.example.com", + }, "contact_person": [ - {"contact_type": "technical", "email_address": "technical@example.com", - "given_name": "Technical"}, - {"contact_type": "support", "email_address": "support@example.com", "given_name": "Support"} + { + "contact_type": "technical", + "email_address": "technical@example.com", + "given_name": "Technical", + }, + { + "contact_type": "support", + "email_address": "support@example.com", + "given_name": "Support", + }, ], "service": { "sp": { @@ -255,15 +284,18 @@ def saml_backend_config(idp_conf): "allow_unsolicited": True, "endpoints": { "assertion_consumer_service": [ - ("{}/{}/acs/redirect".format(BASE_URL, name), BINDING_HTTP_REDIRECT)], - "discovery_response": [("{}/disco", BINDING_DISCO)] - - } + ( + "{}/{}/acs/redirect".format(BASE_URL, name), + BINDING_HTTP_REDIRECT, + ) + ], + "discovery_response": [("{}/disco", BINDING_DISCO)], + }, } }, - "metadata": {"inline": [create_metadata_from_config_dict(idp_conf)]} + "metadata": {"inline": [create_metadata_from_config_dict(idp_conf)]}, } - } + }, } return data @@ -284,12 +316,12 @@ def oidc_backend_config(): "config": { "provider_metadata": { "issuer": "https://op.example.com", - "authorization_endpoint": "https://example.com/authorization" + "authorization_endpoint": "https://example.com/authorization", }, "client": { "auth_req_params": { "response_type": "code", - "scope": "openid, profile, email, address, phone" + "scope": "openid, profile, email, address, phone", }, "client_metadata": { "client_id": "backend_client", @@ -298,30 +330,39 @@ def oidc_backend_config(): "contacts": ["suppert@example.com"], "redirect_uris": ["http://example.com/OIDCBackend"], "subject_type": "public", - } + }, }, "entity_info": { - "contact_person": [{ - "contact_type": "technical", - "email_address": ["technical_test@example.com", "support_test@example.com"], - "given_name": "Test", - "sur_name": "OP" - }, { - "contact_type": "support", - "email_address": ["support_test@example.com"], - "given_name": "Support_test" - }], + "contact_person": [ + { + "contact_type": "technical", + "email_address": [ + "technical_test@example.com", + "support_test@example.com", + ], + "given_name": "Test", + "sur_name": "OP", + }, + { + "contact_type": "support", + "email_address": ["support_test@example.com"], + "given_name": "Support_test", + }, + ], "organization": { "display_name": ["OP Identities", "en"], "name": [["En test-OP", "se"], ["A test OP", "en"]], - "url": [["http://www.example.com", "en"], ["http://www.example.se", "se"]], + "url": [ + ["http://www.example.com", "en"], + ["http://www.example.se", "se"], + ], "ui_info": { "description": [["This is a test OP", "en"]], - "display_name": [["OP - TEST", "en"]] - } - } - } - } + "display_name": [["OP - TEST", "en"]], + }, + }, + }, + }, } return data @@ -336,7 +377,7 @@ def account_linking_module_config(signing_key_path): "api_url": "http://account.example.com/api", "redirect_url": "http://account.example.com/redirect", "sign_key": signing_key_path, - } + }, } return account_linking_config @@ -350,6 +391,6 @@ def consent_module_config(signing_key_path): "api_url": "http://consent.example.com/api", "redirect_url": "http://consent.example.com/redirect", "sign_key": signing_key_path, - } + }, } return consent_config diff --git a/tests/flows/test_account_linking.py b/tests/flows/test_account_linking.py index 94f53a431..5aece588c 100644 --- a/tests/flows/test_account_linking.py +++ b/tests/flows/test_account_linking.py @@ -18,22 +18,32 @@ def test_full_flow(self, satosa_config_dict, account_linking_module_config): test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # incoming auth req - http_resp = test_client.get("/{}/{}/request".format(satosa_config_dict["BACKEND_MODULES"][0]["name"], - satosa_config_dict["FRONTEND_MODULES"][0]["name"])) + http_resp = test_client.get( + "/{}/{}/request".format( + satosa_config_dict["BACKEND_MODULES"][0]["name"], + satosa_config_dict["FRONTEND_MODULES"][0]["name"], + ) + ) assert http_resp.status_code == 200 with responses.RequestsMock() as rsps: # fake no previous account linking - rsps.add(responses.GET, "{}/get_id".format(api_url), "test_ticket", status=404) + rsps.add( + responses.GET, "{}/get_id".format(api_url), "test_ticket", status=404 + ) # incoming auth resp - http_resp = test_client.get("/{}/response".format(satosa_config_dict["BACKEND_MODULES"][0]["name"])) + http_resp = test_client.get( + "/{}/response".format(satosa_config_dict["BACKEND_MODULES"][0]["name"]) + ) assert http_resp.status_code == 302 assert http_resp.headers["Location"].startswith(redirect_url) with responses.RequestsMock() as rsps: # fake previous account linking - rsps.add(responses.GET, "{}/get_id".format(api_url), "test_userid", status=200) + rsps.add( + responses.GET, "{}/get_id".format(api_url), "test_userid", status=200 + ) # incoming account linking response http_resp = test_client.get("/account_linking/handle_account_linking") diff --git a/tests/flows/test_consent.py b/tests/flows/test_consent.py index 76dff496b..4b0bc9253 100644 --- a/tests/flows/test_consent.py +++ b/tests/flows/test_consent.py @@ -21,8 +21,12 @@ def test_full_flow(self, satosa_config_dict, consent_module_config): test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # incoming auth req - http_resp = test_client.get("/{}/{}/request".format(satosa_config_dict["BACKEND_MODULES"][0]["name"], - satosa_config_dict["FRONTEND_MODULES"][0]["name"])) + http_resp = test_client.get( + "/{}/{}/request".format( + satosa_config_dict["BACKEND_MODULES"][0]["name"], + satosa_config_dict["FRONTEND_MODULES"][0]["name"], + ) + ) assert http_resp.status_code == 200 verify_url_re = re.compile(r"{}/verify/\w+".format(api_url)) @@ -34,13 +38,17 @@ def test_full_flow(self, satosa_config_dict, consent_module_config): rsps.add(responses.GET, consent_request_url_re, "test_ticket", status=200) # incoming auth resp - http_resp = test_client.get("/{}/response".format(satosa_config_dict["BACKEND_MODULES"][0]["name"])) + http_resp = test_client.get( + "/{}/response".format(satosa_config_dict["BACKEND_MODULES"][0]["name"]) + ) assert http_resp.status_code == 302 assert http_resp.headers["Location"].startswith(redirect_url) with responses.RequestsMock() as rsps: # fake consent - rsps.add(responses.GET, verify_url_re, json.dumps({"foo": "bar"}), status=200) + rsps.add( + responses.GET, verify_url_re, json.dumps({"foo": "bar"}), status=200 + ) # incoming consent response http_resp = test_client.get("/consent/handle_consent") diff --git a/tests/flows/test_oidc-saml.py b/tests/flows/test_oidc-saml.py index 2a299bfef..aa038eda4 100644 --- a/tests/flows/test_oidc-saml.py +++ b/tests/flows/test_oidc-saml.py @@ -1,13 +1,13 @@ -import os -import json import base64 -from urllib.parse import urlparse, urlencode, parse_qsl +import json +import os +from urllib.parse import parse_qsl, urlencode, urlparse import mongomock import pytest -from jwkest.jwk import rsa_load, RSAKey +from jwkest.jwk import RSAKey, rsa_load from jwkest.jws import JWS -from oic.oic.message import ClaimsRequest, Claims +from oic.oic.message import Claims, ClaimsRequest from pyop.storage import StorageBase from saml2 import BINDING_HTTP_REDIRECT from saml2.config import IdPConfig @@ -17,17 +17,16 @@ from satosa.metadata_creation.saml_metadata import create_entity_descriptors from satosa.proxy_server import make_app from satosa.satosa_config import SATOSAConfig -from tests.users import USERS -from tests.users import OIDC_USERS +from tests.users import OIDC_USERS, USERS from tests.util import FakeIdP - CLIENT_ID = "client1" CLIENT_SECRET = "secret" CLIENT_REDIRECT_URI = "https://client.example.com/cb" REDIRECT_URI = "https://client.example.com/cb" DB_URI = "mongodb://localhost/satosa" + @pytest.fixture(scope="session") def client_db_path(tmpdir_factory): tmpdir = str(tmpdir_factory.getbasetemp()) @@ -35,10 +34,8 @@ def client_db_path(tmpdir_factory): cdb_json = { CLIENT_ID: { "response_types": ["id_token", "code"], - "redirect_uris": [ - CLIENT_REDIRECT_URI - ], - "client_secret": CLIENT_SECRET + "redirect_uris": [CLIENT_REDIRECT_URI], + "client_secret": CLIENT_SECRET, } } with open(path, "w") as f: @@ -46,6 +43,7 @@ def client_db_path(tmpdir_factory): return path + @pytest.fixture def oidc_frontend_config(signing_key_path): data = { @@ -56,8 +54,8 @@ def oidc_frontend_config(signing_key_path): "signing_key_path": signing_key_path, "provider": {"response_types_supported": ["id_token"]}, "client_db_uri": DB_URI, # use mongodb for integration testing - "db_uri": DB_URI # use mongodb for integration testing - } + "db_uri": DB_URI, # use mongodb for integration testing + }, } return data @@ -73,16 +71,14 @@ def oidc_stateless_frontend_config(signing_key_path, client_db_path): "signing_key_path": signing_key_path, "client_db_path": client_db_path, "db_uri": "stateless://user:abc123@localhost", - "provider": { - "response_types_supported": ["id_token", "code"] - } - } + "provider": {"response_types_supported": ["id_token", "code"]}, + }, } return data -@mongomock.patch(servers=(('localhost', 27017),)) +@mongomock.patch(servers=(("localhost", 27017),)) class TestOIDCToSAML: def _client_setup(self): """Insert client in mongodb.""" @@ -91,33 +87,51 @@ def _client_setup(self): ) self._cdb[CLIENT_ID] = { "redirect_uris": [REDIRECT_URI], - "response_types": ["id_token"] + "response_types": ["id_token"], } - def test_full_flow(self, satosa_config_dict, oidc_frontend_config, saml_backend_config, idp_conf): + def test_full_flow( + self, satosa_config_dict, oidc_frontend_config, saml_backend_config, idp_conf + ): self._client_setup() subject_id = "testuser1" # proxy config satosa_config_dict["FRONTEND_MODULES"] = [oidc_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] - satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = {attr_name: {"openid": [attr_name], - "saml": [attr_name]} - for attr_name in USERS[subject_id]} - _, backend_metadata = create_entity_descriptors(SATOSAConfig(satosa_config_dict)) + satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = { + attr_name: {"openid": [attr_name], "saml": [attr_name]} + for attr_name in USERS[subject_id] + } + _, backend_metadata = create_entity_descriptors( + SATOSAConfig(satosa_config_dict) + ) # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # get frontend OP config info - provider_config = json.loads(test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) + provider_config = json.loads( + test_client.get("/.well-known/openid-configuration").data.decode("utf-8") + ) # create auth req - claims_request = ClaimsRequest(id_token=Claims(**{k: None for k in USERS[subject_id]})) - req_args = {"scope": "openid", "response_type": "id_token", "client_id": CLIENT_ID, - "redirect_uri": REDIRECT_URI, "nonce": "nonce", - "claims": claims_request.to_json()} - auth_req = urlparse(provider_config["authorization_endpoint"]).path + "?" + urlencode(req_args) + claims_request = ClaimsRequest( + id_token=Claims(**{k: None for k in USERS[subject_id]}) + ) + req_args = { + "scope": "openid", + "response_type": "id_token", + "client_id": CLIENT_ID, + "redirect_uri": REDIRECT_URI, + "nonce": "nonce", + "claims": claims_request.to_json(), + } + auth_req = ( + urlparse(provider_config["authorization_endpoint"]).path + + "?" + + urlencode(req_args) + ) # make auth req to proxy proxied_auth_req = test_client.get(auth_req) @@ -129,13 +143,16 @@ def test_full_flow(self, satosa_config_dict, oidc_frontend_config, saml_backend_ fakeidp = FakeIdP(USERS, config=IdPConfig().load(idp_conf)) # create auth resp - req_params = dict(parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query)) + req_params = dict( + parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query) + ) url, authn_resp = fakeidp.handle_auth_req( req_params["SAMLRequest"], req_params["RelayState"], BINDING_HTTP_REDIRECT, subject_id, - response_binding=BINDING_HTTP_REDIRECT) + response_binding=BINDING_HTTP_REDIRECT, + ) # make auth resp to proxy authn_resp_req = urlparse(url).path + "?" + urlencode(authn_resp) @@ -144,38 +161,65 @@ def test_full_flow(self, satosa_config_dict, oidc_frontend_config, saml_backend_ # verify auth resp from proxy resp_dict = dict(parse_qsl(urlparse(authn_resp.data.decode("utf-8")).fragment)) - signing_key = RSAKey(key=rsa_load(oidc_frontend_config["config"]["signing_key_path"]), - use="sig", alg="RS256") - id_token_claims = JWS().verify_compact(resp_dict["id_token"], keys=[signing_key]) + signing_key = RSAKey( + key=rsa_load(oidc_frontend_config["config"]["signing_key_path"]), + use="sig", + alg="RS256", + ) + id_token_claims = JWS().verify_compact( + resp_dict["id_token"], keys=[signing_key] + ) assert all( (name, values) in id_token_claims.items() for name, values in OIDC_USERS[subject_id].items() ) - def test_full_stateless_id_token_flow(self, satosa_config_dict, oidc_stateless_frontend_config, saml_backend_config, idp_conf): + def test_full_stateless_id_token_flow( + self, + satosa_config_dict, + oidc_stateless_frontend_config, + saml_backend_config, + idp_conf, + ): subject_id = "testuser1" # proxy config satosa_config_dict["FRONTEND_MODULES"] = [oidc_stateless_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] - satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = {attr_name: {"openid": [attr_name], - "saml": [attr_name]} - for attr_name in USERS[subject_id]} - _, backend_metadata = create_entity_descriptors(SATOSAConfig(satosa_config_dict)) + satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = { + attr_name: {"openid": [attr_name], "saml": [attr_name]} + for attr_name in USERS[subject_id] + } + _, backend_metadata = create_entity_descriptors( + SATOSAConfig(satosa_config_dict) + ) # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # get frontend OP config info - provider_config = json.loads(test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) + provider_config = json.loads( + test_client.get("/.well-known/openid-configuration").data.decode("utf-8") + ) # create auth req - claims_request = ClaimsRequest(id_token=Claims(**{k: None for k in USERS[subject_id]})) - req_args = {"scope": "openid", "response_type": "id_token", "client_id": CLIENT_ID, - "redirect_uri": REDIRECT_URI, "nonce": "nonce", - "claims": claims_request.to_json()} - auth_req = urlparse(provider_config["authorization_endpoint"]).path + "?" + urlencode(req_args) + claims_request = ClaimsRequest( + id_token=Claims(**{k: None for k in USERS[subject_id]}) + ) + req_args = { + "scope": "openid", + "response_type": "id_token", + "client_id": CLIENT_ID, + "redirect_uri": REDIRECT_URI, + "nonce": "nonce", + "claims": claims_request.to_json(), + } + auth_req = ( + urlparse(provider_config["authorization_endpoint"]).path + + "?" + + urlencode(req_args) + ) # make auth req to proxy proxied_auth_req = test_client.get(auth_req) @@ -187,13 +231,16 @@ def test_full_stateless_id_token_flow(self, satosa_config_dict, oidc_stateless_f fakeidp = FakeIdP(USERS, config=IdPConfig().load(idp_conf)) # create auth resp - req_params = dict(parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query)) + req_params = dict( + parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query) + ) url, authn_resp = fakeidp.handle_auth_req( req_params["SAMLRequest"], req_params["RelayState"], BINDING_HTTP_REDIRECT, subject_id, - response_binding=BINDING_HTTP_REDIRECT) + response_binding=BINDING_HTTP_REDIRECT, + ) # make auth resp to proxy authn_resp_req = urlparse(url).path + "?" + urlencode(authn_resp) @@ -202,38 +249,65 @@ def test_full_stateless_id_token_flow(self, satosa_config_dict, oidc_stateless_f # verify auth resp from proxy resp_dict = dict(parse_qsl(urlparse(authn_resp.data.decode("utf-8")).fragment)) - signing_key = RSAKey(key=rsa_load(oidc_stateless_frontend_config["config"]["signing_key_path"]), - use="sig", alg="RS256") - id_token_claims = JWS().verify_compact(resp_dict["id_token"], keys=[signing_key]) + signing_key = RSAKey( + key=rsa_load(oidc_stateless_frontend_config["config"]["signing_key_path"]), + use="sig", + alg="RS256", + ) + id_token_claims = JWS().verify_compact( + resp_dict["id_token"], keys=[signing_key] + ) assert all( (name, values) in id_token_claims.items() for name, values in OIDC_USERS[subject_id].items() ) - def test_full_stateless_code_flow(self, satosa_config_dict, oidc_stateless_frontend_config, saml_backend_config, idp_conf): + def test_full_stateless_code_flow( + self, + satosa_config_dict, + oidc_stateless_frontend_config, + saml_backend_config, + idp_conf, + ): subject_id = "testuser1" # proxy config satosa_config_dict["FRONTEND_MODULES"] = [oidc_stateless_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] - satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = {attr_name: {"openid": [attr_name], - "saml": [attr_name]} - for attr_name in USERS[subject_id]} - _, backend_metadata = create_entity_descriptors(SATOSAConfig(satosa_config_dict)) + satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = { + attr_name: {"openid": [attr_name], "saml": [attr_name]} + for attr_name in USERS[subject_id] + } + _, backend_metadata = create_entity_descriptors( + SATOSAConfig(satosa_config_dict) + ) # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # get frontend OP config info - provider_config = json.loads(test_client.get("/.well-known/openid-configuration").data.decode("utf-8")) + provider_config = json.loads( + test_client.get("/.well-known/openid-configuration").data.decode("utf-8") + ) # create auth req - claims_request = ClaimsRequest(id_token=Claims(**{k: None for k in USERS[subject_id]})) - req_args = {"scope": "openid", "response_type": "code", "client_id": CLIENT_ID, - "redirect_uri": REDIRECT_URI, "nonce": "nonce", - "claims": claims_request.to_json()} - auth_req = urlparse(provider_config["authorization_endpoint"]).path + "?" + urlencode(req_args) + claims_request = ClaimsRequest( + id_token=Claims(**{k: None for k in USERS[subject_id]}) + ) + req_args = { + "scope": "openid", + "response_type": "code", + "client_id": CLIENT_ID, + "redirect_uri": REDIRECT_URI, + "nonce": "nonce", + "claims": claims_request.to_json(), + } + auth_req = ( + urlparse(provider_config["authorization_endpoint"]).path + + "?" + + urlencode(req_args) + ) # make auth req to proxy proxied_auth_req = test_client.get(auth_req) @@ -245,13 +319,16 @@ def test_full_stateless_code_flow(self, satosa_config_dict, oidc_stateless_front fakeidp = FakeIdP(USERS, config=IdPConfig().load(idp_conf)) # create auth resp - req_params = dict(parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query)) + req_params = dict( + parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query) + ) url, authn_resp = fakeidp.handle_auth_req( req_params["SAMLRequest"], req_params["RelayState"], BINDING_HTTP_REDIRECT, subject_id, - response_binding=BINDING_HTTP_REDIRECT) + response_binding=BINDING_HTTP_REDIRECT, + ) # make auth resp to proxy authn_resp_req = urlparse(url).path + "?" + urlencode(authn_resp) @@ -261,23 +338,32 @@ def test_full_stateless_code_flow(self, satosa_config_dict, oidc_stateless_front resp_dict = dict(parse_qsl(urlparse(authn_resp.data.decode("utf-8")).query)) code = resp_dict["code"] client_id_secret_str = CLIENT_ID + ":" + CLIENT_SECRET - auth_header = "Basic %s" % base64.b64encode(client_id_secret_str.encode()).decode() + auth_header = ( + "Basic %s" % base64.b64encode(client_id_secret_str.encode()).decode() + ) - authn_resp = test_client.post(provider_config["token_endpoint"], - data={ - "code": code, - "grant_type": "authorization_code", - "redirect_uri": CLIENT_REDIRECT_URI - }, - headers={'Authorization': auth_header}) + authn_resp = test_client.post( + provider_config["token_endpoint"], + data={ + "code": code, + "grant_type": "authorization_code", + "redirect_uri": CLIENT_REDIRECT_URI, + }, + headers={"Authorization": auth_header}, + ) assert authn_resp.status == "200 OK" # verify auth resp from proxy resp_dict = json.loads(authn_resp.data.decode("utf-8")) - signing_key = RSAKey(key=rsa_load(oidc_stateless_frontend_config["config"]["signing_key_path"]), - use="sig", alg="RS256") - id_token_claims = JWS().verify_compact(resp_dict["id_token"], keys=[signing_key]) + signing_key = RSAKey( + key=rsa_load(oidc_stateless_frontend_config["config"]["signing_key_path"]), + use="sig", + alg="RS256", + ) + id_token_claims = JWS().verify_compact( + resp_dict["id_token"], keys=[signing_key] + ) assert all( (name, values) in id_token_claims.items() diff --git a/tests/flows/test_saml-oidc.py b/tests/flows/test_saml-oidc.py index bc41acfe1..75acc7656 100644 --- a/tests/flows/test_saml-oidc.py +++ b/tests/flows/test_saml-oidc.py @@ -1,5 +1,5 @@ import time -from urllib.parse import urlparse, parse_qsl, urlencode +from urllib.parse import parse_qsl, urlencode, urlparse from oic.oic.message import IdToken from saml2 import BINDING_HTTP_REDIRECT @@ -10,21 +10,25 @@ from satosa.metadata_creation.saml_metadata import create_entity_descriptors from satosa.proxy_server import make_app from satosa.satosa_config import SATOSAConfig -from tests.users import USERS -from tests.users import OIDC_USERS +from tests.users import OIDC_USERS, USERS from tests.util import FakeSP class TestSAMLToOIDC: - def run_test(self, satosa_config_dict, sp_conf, oidc_backend_config, frontend_config): + def run_test( + self, satosa_config_dict, sp_conf, oidc_backend_config, frontend_config + ): subject_id = "testuser1" # proxy config satosa_config_dict["FRONTEND_MODULES"] = [frontend_config] satosa_config_dict["BACKEND_MODULES"] = [oidc_backend_config] - satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = {attr_name: {"openid": [attr_name], - "saml": [attr_name]} - for attr_name in USERS[subject_id]} - frontend_metadata, backend_metadata = create_entity_descriptors(SATOSAConfig(satosa_config_dict)) + satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = { + attr_name: {"openid": [attr_name], "saml": [attr_name]} + for attr_name in USERS[subject_id] + } + frontend_metadata, backend_metadata = create_entity_descriptors( + SATOSAConfig(satosa_config_dict) + ) # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) @@ -35,13 +39,17 @@ def run_test(self, satosa_config_dict, sp_conf, oidc_backend_config, frontend_co fakesp = FakeSP(SPConfig().load(sp_conf)) # create auth req - destination, req_args = fakesp.make_auth_req(frontend_metadata[frontend_config["name"]][0].entity_id) + destination, req_args = fakesp.make_auth_req( + frontend_metadata[frontend_config["name"]][0].entity_id + ) auth_req = urlparse(destination).path + "?" + urlencode(req_args) # make auth req to proxy proxied_auth_req = test_client.get(auth_req) assert proxied_auth_req.status == "302 Found" - parsed_auth_req = dict(parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query)) + parsed_auth_req = dict( + parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query) + ) # create auth resp id_token_claims = {k: v for k, v in OIDC_USERS[subject_id].items()} @@ -49,24 +57,37 @@ def run_test(self, satosa_config_dict, sp_conf, oidc_backend_config, frontend_co id_token_claims["iat"] = time.time() id_token_claims["exp"] = time.time() + 3600 id_token_claims["iss"] = "https://op.example.com" - id_token_claims["aud"] = oidc_backend_config["config"]["client"]["client_metadata"]["client_id"] + id_token_claims["aud"] = oidc_backend_config["config"]["client"][ + "client_metadata" + ]["client_id"] id_token_claims["nonce"] = parsed_auth_req["nonce"] id_token = IdToken(**id_token_claims).to_jwt() authn_resp = {"state": parsed_auth_req["state"], "id_token": id_token} # make auth resp to proxy redirect_uri_path = urlparse( - oidc_backend_config["config"]["client"]["client_metadata"]["redirect_uris"][0]).path + oidc_backend_config["config"]["client"]["client_metadata"]["redirect_uris"][ + 0 + ] + ).path authn_resp_req = redirect_uri_path + "?" + urlencode(authn_resp) authn_resp = test_client.get(authn_resp_req) assert authn_resp.status == "303 See Other" # verify auth resp from proxy resp_dict = dict(parse_qsl(urlparse(authn_resp.data.decode("utf-8")).query)) - auth_resp = fakesp.parse_authn_request_response(resp_dict["SAMLResponse"], BINDING_HTTP_REDIRECT) + auth_resp = fakesp.parse_authn_request_response( + resp_dict["SAMLResponse"], BINDING_HTTP_REDIRECT + ) assert auth_resp.ava == USERS[subject_id] - def test_full_flow(self, satosa_config_dict, sp_conf, oidc_backend_config, - saml_frontend_config, saml_mirror_frontend_config): + def test_full_flow( + self, + satosa_config_dict, + sp_conf, + oidc_backend_config, + saml_frontend_config, + saml_mirror_frontend_config, + ): for conf in [saml_frontend_config, saml_mirror_frontend_config]: self.run_test(satosa_config_dict, sp_conf, oidc_backend_config, conf) diff --git a/tests/flows/test_saml-saml.py b/tests/flows/test_saml-saml.py index 91c350495..060d4744a 100644 --- a/tests/flows/test_saml-saml.py +++ b/tests/flows/test_saml-saml.py @@ -1,7 +1,7 @@ -from urllib.parse import parse_qsl, urlparse, urlencode +from urllib.parse import parse_qsl, urlencode, urlparse from saml2 import BINDING_HTTP_REDIRECT -from saml2.config import SPConfig, IdPConfig +from saml2.config import IdPConfig, SPConfig from werkzeug.test import Client from werkzeug.wrappers import Response @@ -9,18 +9,28 @@ from satosa.proxy_server import make_app from satosa.satosa_config import SATOSAConfig from tests.users import USERS -from tests.util import FakeSP, FakeIdP +from tests.util import FakeIdP, FakeSP class TestSAMLToSAML: - def run_test(self, satosa_config_dict, sp_conf, idp_conf, saml_backend_config, frontend_config): + def run_test( + self, + satosa_config_dict, + sp_conf, + idp_conf, + saml_backend_config, + frontend_config, + ): subject_id = "testuser1" # proxy config satosa_config_dict["FRONTEND_MODULES"] = [frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] - satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = {attr_name: {"saml": [attr_name]} for attr_name in - USERS[subject_id]} - frontend_metadata, backend_metadata = create_entity_descriptors(SATOSAConfig(satosa_config_dict)) + satosa_config_dict["INTERNAL_ATTRIBUTES"]["attributes"] = { + attr_name: {"saml": [attr_name]} for attr_name in USERS[subject_id] + } + frontend_metadata, backend_metadata = create_entity_descriptors( + SATOSAConfig(satosa_config_dict) + ) # application test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) @@ -31,7 +41,9 @@ def run_test(self, satosa_config_dict, sp_conf, idp_conf, saml_backend_config, f fakesp = FakeSP(SPConfig().load(sp_conf)) # create auth req - destination, req_args = fakesp.make_auth_req(frontend_metadata[frontend_config["name"]][0].entity_id) + destination, req_args = fakesp.make_auth_req( + frontend_metadata[frontend_config["name"]][0].entity_id + ) auth_req = urlparse(destination).path + "?" + urlencode(req_args) # make auth req to proxy @@ -44,13 +56,16 @@ def run_test(self, satosa_config_dict, sp_conf, idp_conf, saml_backend_config, f fakeidp = FakeIdP(USERS, config=IdPConfig().load(idp_conf)) # create auth resp - req_params = dict(parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query)) + req_params = dict( + parse_qsl(urlparse(proxied_auth_req.data.decode("utf-8")).query) + ) url, authn_resp = fakeidp.handle_auth_req( req_params["SAMLRequest"], req_params["RelayState"], BINDING_HTTP_REDIRECT, subject_id, - response_binding=BINDING_HTTP_REDIRECT) + response_binding=BINDING_HTTP_REDIRECT, + ) # make auth resp to proxy authn_resp_req = urlparse(url).path + "?" + urlencode(authn_resp) @@ -59,10 +74,21 @@ def run_test(self, satosa_config_dict, sp_conf, idp_conf, saml_backend_config, f # verify auth resp from proxy resp_dict = dict(parse_qsl(urlparse(authn_resp.data.decode("utf-8")).query)) - auth_resp = fakesp.parse_authn_request_response(resp_dict["SAMLResponse"], BINDING_HTTP_REDIRECT) + auth_resp = fakesp.parse_authn_request_response( + resp_dict["SAMLResponse"], BINDING_HTTP_REDIRECT + ) assert auth_resp.ava == USERS[subject_id] - def test_full_flow(self, satosa_config_dict, sp_conf, idp_conf, saml_backend_config, - saml_frontend_config, saml_mirror_frontend_config): + def test_full_flow( + self, + satosa_config_dict, + sp_conf, + idp_conf, + saml_backend_config, + saml_frontend_config, + saml_mirror_frontend_config, + ): for conf in [saml_frontend_config, saml_mirror_frontend_config]: - self.run_test(satosa_config_dict, sp_conf, idp_conf, saml_backend_config, conf) + self.run_test( + satosa_config_dict, sp_conf, idp_conf, saml_backend_config, conf + ) diff --git a/tests/flows/test_wsgi_flow.py b/tests/flows/test_wsgi_flow.py index ab9d636f5..3f2298046 100644 --- a/tests/flows/test_wsgi_flow.py +++ b/tests/flows/test_wsgi_flow.py @@ -22,18 +22,24 @@ def test_flow(self, satosa_config_dict): test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) # Make request to frontend - resp = test_client.get('/{}/{}/request'.format("backend", "frontend")) - assert resp.status == '200 OK' + resp = test_client.get("/{}/{}/request".format("backend", "frontend")) + assert resp.status == "200 OK" headers = dict(resp.headers) assert headers["Set-Cookie"] # Fake response coming in to backend - resp = test_client.get('/{}/response'.format("backend"), headers=[("Cookie", headers["Set-Cookie"])]) - assert resp.status == '200 OK' - assert resp.data.decode('utf-8') == "Auth response received, passed to test frontend" + resp = test_client.get( + "/{}/response".format("backend"), + headers=[("Cookie", headers["Set-Cookie"])], + ) + assert resp.status == "200 OK" + assert ( + resp.data.decode("utf-8") + == "Auth response received, passed to test frontend" + ) def test_unknown_request_path(self, satosa_config_dict): test_client = Client(make_app(SATOSAConfig(satosa_config_dict)), Response) - resp = test_client.get('/unknown') + resp = test_client.get("/unknown") assert resp.status == NotFound._status diff --git a/tests/satosa/backends/test_bitbucket.py b/tests/satosa/backends/test_bitbucket.py index d6cf25bac..cfe397624 100644 --- a/tests/satosa/backends/test_bitbucket.py +++ b/tests/satosa/backends/test_bitbucket.py @@ -1,10 +1,9 @@ import json from unittest.mock import Mock -from urllib.parse import urlparse, parse_qsl +from urllib.parse import parse_qsl, urlparse import pytest import responses - from saml2.saml import NAMEID_FORMAT_TRANSIENT from satosa.backends.bitbucket import BitBucketBackend @@ -17,54 +16,48 @@ "nickname": "bb_username", "display_name": "bb_first_name bb_last_name", "has_2fa_enabled": False, - "created_on": "2019-10-12T09:14:00+0000" + "created_on": "2019-10-12T09:14:00+0000", } BB_USER_EMAIL_RESPONSE = { "values": [ - { - "email": "bb_username@example.com", - "is_confirmed": True, - "is_primary": True - }, - { + {"email": "bb_username@example.com", "is_confirmed": True, "is_primary": True}, + { "email": "bb_username_1@example.com", "is_confirmed": True, - "is_primary": False + "is_primary": False, }, - { + { "email": "bb_username_2@example.com", "is_confirmed": False, - "is_primary": False - } + "is_primary": False, + }, ] } BASE_URL = "https://client.example.com" -AUTHZ_PAGE = 'bitbucket' +AUTHZ_PAGE = "bitbucket" CLIENT_ID = "bitbucket_client_id" BB_CONFIG = { - 'server_info': { - 'authorization_endpoint': - 'https://bitbucket.org/site/oauth2/authorize', - 'token_endpoint': 'https://bitbucket.org/site/oauth2/access_token', - 'user_endpoint': 'https://api.bitbucket.org/2.0/user' + "server_info": { + "authorization_endpoint": "https://bitbucket.org/site/oauth2/authorize", + "token_endpoint": "https://bitbucket.org/site/oauth2/access_token", + "user_endpoint": "https://api.bitbucket.org/2.0/user", }, - 'client_secret': 'bitbucket_secret', - 'base_url': BASE_URL, - 'state_encryption_ key': 'state_encryption_key', - 'encryption_key': 'encryption_key', - 'authz_page': AUTHZ_PAGE, - 'client_config': {'client_id': CLIENT_ID}, - 'scope': ["account", "email"] - + "client_secret": "bitbucket_secret", + "base_url": BASE_URL, + "state_encryption_ key": "state_encryption_key", + "encryption_key": "encryption_key", + "authz_page": AUTHZ_PAGE, + "client_config": {"client_id": CLIENT_ID}, + "scope": ["account", "email"], } BB_RESPONSE_CODE = "the_bb_code" INTERNAL_ATTRIBUTES = { - 'attributes': { - 'mail': {'bitbucket': ['email']}, - 'subject-id': {'bitbucket': ['account_id']}, - 'displayname': {'bitbucket': ['display_name']}, - 'name': {'bitbucket': ['display_name']}, + "attributes": { + "mail": {"bitbucket": ["email"]}, + "subject-id": {"bitbucket": ["account_id"]}, + "displayname": {"bitbucket": ["display_name"]}, + "name": {"bitbucket": ["display_name"]}, } } @@ -74,34 +67,39 @@ class TestBitBucketBackend(object): @pytest.fixture(autouse=True) def create_backend(self): - self.bb_backend = BitBucketBackend(Mock(), INTERNAL_ATTRIBUTES, - BB_CONFIG, "base_url", "bitbucket") + self.bb_backend = BitBucketBackend( + Mock(), INTERNAL_ATTRIBUTES, BB_CONFIG, "base_url", "bitbucket" + ) @pytest.fixture def incoming_authn_response(self, context): - context.path = 'bitbucket/sso/redirect' + context.path = "bitbucket/sso/redirect" state_data = dict(state=mock_get_state.return_value) context.state[self.bb_backend.name] = state_data context.request = { "code": BB_RESPONSE_CODE, - "state": mock_get_state.return_value + "state": mock_get_state.return_value, } return context def setup_bitbucket_response(self): - _user_endpoint = BB_CONFIG['server_info']['user_endpoint'] - responses.add(responses.GET, - _user_endpoint, - body=json.dumps(BB_USER_RESPONSE), - status=200, - content_type='application/json') - - responses.add(responses.GET, - '{}/emails'.format(_user_endpoint), - body=json.dumps(BB_USER_EMAIL_RESPONSE), - status=200, - content_type='application/json') + _user_endpoint = BB_CONFIG["server_info"]["user_endpoint"] + responses.add( + responses.GET, + _user_endpoint, + body=json.dumps(BB_USER_RESPONSE), + status=200, + content_type="application/json", + ) + + responses.add( + responses.GET, + "{}/emails".format(_user_endpoint), + body=json.dumps(BB_USER_EMAIL_RESPONSE), + status=200, + content_type="application/json", + ) def assert_expected_attributes(self): expected_attributes = { @@ -111,9 +109,7 @@ def assert_expected_attributes(self): "mail": [BB_USER_EMAIL_RESPONSE["values"][0]["email"]], } - context, internal_resp = self.bb_backend \ - .auth_callback_func \ - .call_args[0] + context, internal_resp = self.bb_backend.auth_callback_func.call_args[0] assert internal_resp.attributes == expected_attributes def assert_token_request(self, request_args, state, **kwargs): @@ -124,27 +120,24 @@ def assert_token_request(self, request_args, state, **kwargs): def test_register_endpoints(self): url_map = self.bb_backend.register_endpoints() - expected_url_map = [('^bitbucket$', self.bb_backend._authn_response)] + expected_url_map = [("^bitbucket$", self.bb_backend._authn_response)] assert url_map == expected_url_map def test_start_auth(self, context): - context.path = 'bitbucket/sso/redirect' + context.path = "bitbucket/sso/redirect" internal_request = InternalData( - subject_type=NAMEID_FORMAT_TRANSIENT, requester='test_requester' + subject_type=NAMEID_FORMAT_TRANSIENT, requester="test_requester" ) - resp = self.bb_backend.start_auth(context, - internal_request, - mock_get_state) + resp = self.bb_backend.start_auth(context, internal_request, mock_get_state) login_url = resp.message - assert login_url.startswith( - BB_CONFIG["server_info"]["authorization_endpoint"]) + assert login_url.startswith(BB_CONFIG["server_info"]["authorization_endpoint"]) expected_params = { "client_id": CLIENT_ID, "state": mock_get_state.return_value, "response_type": "code", "scope": " ".join(BB_CONFIG["scope"]), - "redirect_uri": "%s/%s" % (BASE_URL, AUTHZ_PAGE) + "redirect_uri": "%s/%s" % (BASE_URL, AUTHZ_PAGE), } actual_params = dict(parse_qsl(urlparse(login_url).query)) assert actual_params == expected_params @@ -154,9 +147,9 @@ def test_authn_response(self, incoming_authn_response): self.setup_bitbucket_response() mock_do_access_token_request = Mock( - return_value={"access_token": "bb access token"}) - self.bb_backend.consumer.do_access_token_request = \ - mock_do_access_token_request + return_value={"access_token": "bb access token"} + ) + self.bb_backend.consumer.do_access_token_request = mock_do_access_token_request self.bb_backend._authn_response(incoming_authn_response) @@ -169,24 +162,30 @@ def test_entire_flow(self, context): Tests start of authentication (incoming auth req) and receiving auth response. """ - responses.add(responses.POST, - BB_CONFIG["server_info"]["token_endpoint"], - body=json.dumps({"access_token": "qwerty", - "token_type": "bearer", - "expires_in": 9999999999999}), - status=200, - content_type='application/json') + responses.add( + responses.POST, + BB_CONFIG["server_info"]["token_endpoint"], + body=json.dumps( + { + "access_token": "qwerty", + "token_type": "bearer", + "expires_in": 9999999999999, + } + ), + status=200, + content_type="application/json", + ) self.setup_bitbucket_response() - context.path = 'bitbucket/sso/redirect' + context.path = "bitbucket/sso/redirect" internal_request = InternalData( - subject_type=NAMEID_FORMAT_TRANSIENT, requester='test_requester' + subject_type=NAMEID_FORMAT_TRANSIENT, requester="test_requester" ) self.bb_backend.start_auth(context, internal_request, mock_get_state) context.request = { "code": BB_RESPONSE_CODE, - "state": mock_get_state.return_value + "state": mock_get_state.return_value, } self.bb_backend._authn_response(context) self.assert_expected_attributes() diff --git a/tests/satosa/backends/test_idpy_oidc.py b/tests/satosa/backends/test_idpy_oidc.py index 373f59365..b3d75a00b 100644 --- a/tests/satosa/backends/test_idpy_oidc.py +++ b/tests/satosa/backends/test_idpy_oidc.py @@ -3,17 +3,15 @@ import time from datetime import datetime from unittest.mock import Mock -from urllib.parse import parse_qsl -from urllib.parse import urlparse +from urllib.parse import parse_qsl, urlparse +import pytest +import responses from cryptojwt.key_jar import build_keyjar from idpyoidc.client.defaults import DEFAULT_KEY_DEFS from idpyoidc.client.oauth2.stand_alone_client import StandAloneClient -from idpyoidc.message.oidc import AuthorizationResponse -from idpyoidc.message.oidc import IdToken +from idpyoidc.message.oidc import AuthorizationResponse, IdToken from oic.oic import AuthorizationRequest -import pytest -import responses from satosa.backends.idpy_oidc import IdpyOIDCBackend from satosa.context import Context @@ -48,8 +46,8 @@ def backend_config(self): "authorization_endpoint": f"{ISSUER}/authn", "token_endpoint": f"{ISSUER}/token", "userinfo_endpoint": f"{ISSUER}/user", - "jwks_uri": f"{ISSUER}/static/jwks" - } + "jwks_uri": f"{ISSUER}/static/jwks", + }, } } @@ -60,25 +58,27 @@ def internal_attributes(self): "givenname": {"openid": ["given_name"]}, "mail": {"openid": ["email"]}, "edupersontargetedid": {"openid": ["sub"]}, - "surname": {"openid": ["family_name"]} + "surname": {"openid": ["family_name"]}, } } @pytest.fixture(autouse=True) @responses.activate def create_backend(self, internal_attributes, backend_config): - base_url = backend_config['client']['base_url'] + base_url = backend_config["client"]["base_url"] self.issuer_keys = build_keyjar(DEFAULT_KEY_DEFS) with responses.RequestsMock() as rsps: rsps.add( responses.GET, - backend_config['client']['provider_info']['jwks_uri'], + backend_config["client"]["provider_info"]["jwks_uri"], body=self.issuer_keys.export_jwks_as_json(), status=200, - content_type="application/json") + content_type="application/json", + ) - self.oidc_backend = IdpyOIDCBackend(Mock(), internal_attributes, backend_config, - base_url, "oidc") + self.oidc_backend = IdpyOIDCBackend( + Mock(), internal_attributes, backend_config, base_url, "oidc" + ) @pytest.fixture def userinfo(self): @@ -86,13 +86,13 @@ def userinfo(self): "given_name": "Test", "family_name": "Devsson", "email": "test_dev@example.com", - "sub": "username" + "sub": "username", } @pytest.fixture def id_token(self, userinfo): issuer_keys = build_keyjar(DEFAULT_KEY_DEFS) - signing_key = issuer_keys.get_signing_key(key_type='RSA')[0] + signing_key = issuer_keys.get_signing_key(key_type="RSA")[0] signing_key.alg = "RS256" auth_time = int(datetime.utcnow().timestamp()) id_token_claims = { @@ -116,7 +116,10 @@ def test_client(self, backend_config): assert isinstance(self.oidc_backend.client, StandAloneClient) # 3 signing keys. One RSA, one EC and one symmetric assert len(self.oidc_backend.client.context.keyjar.get_signing_key()) == 3 - assert self.oidc_backend.client.context.jwks_uri == backend_config['client']['jwks_uri'] + assert ( + self.oidc_backend.client.context.jwks_uri + == backend_config["client"]["jwks_uri"] + ) def assert_expected_attributes(self, attr_map, user_claims, actual_attributes): expected_attributes = { @@ -127,7 +130,7 @@ def assert_expected_attributes(self, attr_map, user_claims, actual_attributes): def setup_token_endpoint(self, userinfo): _client = self.oidc_backend.client - signing_key = self.issuer_keys.get_signing_key(key_type='RSA')[0] + signing_key = self.issuer_keys.get_signing_key(key_type="RSA")[0] signing_key.alg = "RS256" id_token_claims = { "iss": ISSUER, @@ -135,28 +138,34 @@ def setup_token_endpoint(self, userinfo): "aud": CLIENT_ID, "nonce": NONCE, "exp": time.time() + 3600, - "iat": time.time() + "iat": time.time(), } - id_token = IdToken(**id_token_claims).to_jwt([signing_key], algorithm=signing_key.alg) + id_token = IdToken(**id_token_claims).to_jwt( + [signing_key], algorithm=signing_key.alg + ) token_response = { "access_token": "SlAV32hkKG", "token_type": "Bearer", "refresh_token": "8xLOxBtZp8", "expires_in": 3600, - "id_token": id_token + "id_token": id_token, } - responses.add(responses.POST, - _client.context.provider_info['token_endpoint'], - body=json.dumps(token_response), - status=200, - content_type="application/json") + responses.add( + responses.POST, + _client.context.provider_info["token_endpoint"], + body=json.dumps(token_response), + status=200, + content_type="application/json", + ) def setup_userinfo_endpoint(self, userinfo): - responses.add(responses.GET, - self.oidc_backend.client.context.provider_info['userinfo_endpoint'], - body=json.dumps(userinfo), - status=200, - content_type="application/json") + responses.add( + responses.GET, + self.oidc_backend.client.context.provider_info["userinfo_endpoint"], + body=json.dumps(userinfo), + status=200, + content_type="application/json", + ) @pytest.fixture def incoming_authn_response(self): @@ -168,7 +177,7 @@ def incoming_authn_response(self): response_type="code", client_id=_context.get_client_id(), scope=_context.claims.get_usage("scope"), - nonce=NONCE + nonce=NONCE, ) _context.cstate.set(oidc_state, {"iss": _context.issuer}) _context.cstate.bind_key(NONCE, oidc_state) @@ -178,20 +187,22 @@ def incoming_authn_response(self): code="F+R4uWbN46U+Bq9moQPC4lEvRd2De4o=", state=oidc_state, iss=_context.issuer, - nonce=NONCE + nonce=NONCE, ) return response.to_dict() def test_register_endpoints(self): _uri = self.oidc_backend.client.context.claims.get_usage("redirect_uris")[0] - redirect_uri_path = urlparse(_uri).path.lstrip('/') + redirect_uri_path = urlparse(_uri).path.lstrip("/") url_map = self.oidc_backend.register_endpoints() regex, callback = url_map[0] assert re.search(regex, redirect_uri_path) assert callback == self.oidc_backend.response_endpoint def test_translate_response_to_internal_response(self, all_user_claims): - internal_response = self.oidc_backend._translate_response(all_user_claims, ISSUER) + internal_response = self.oidc_backend._translate_response( + all_user_claims, ISSUER + ) assert internal_response.subject_id == all_user_claims["sub"] self.assert_expected_attributes( self.oidc_backend.internal_attributes, @@ -224,11 +235,21 @@ def test_start_auth_redirects_to_provider_authorization_endpoint(self, context): login_url = auth_response.message parsed = urlparse(login_url) - assert login_url.startswith(_client.context.provider_info["authorization_endpoint"]) + assert login_url.startswith( + _client.context.provider_info["authorization_endpoint"] + ) auth_params = dict(parse_qsl(parsed.query)) - assert auth_params["scope"] == " ".join(_client.context.claims.get_usage("scope")) - assert auth_params["response_type"] == _client.context.claims.get_usage("response_types")[0] + assert auth_params["scope"] == " ".join( + _client.context.claims.get_usage("scope") + ) + assert ( + auth_params["response_type"] + == _client.context.claims.get_usage("response_types")[0] + ) assert auth_params["client_id"] == _client.client_id - assert auth_params["redirect_uri"] == _client.context.claims.get_usage("redirect_uris")[0] + assert ( + auth_params["redirect_uri"] + == _client.context.claims.get_usage("redirect_uris")[0] + ) assert "state" in auth_params assert "nonce" in auth_params diff --git a/tests/satosa/backends/test_oauth.py b/tests/satosa/backends/test_oauth.py index 22afc8ee7..432423324 100644 --- a/tests/satosa/backends/test_oauth.py +++ b/tests/satosa/backends/test_oauth.py @@ -1,10 +1,9 @@ import json from unittest.mock import Mock -from urllib.parse import urlparse, parse_qsl +from urllib.parse import parse_qsl, urlparse import pytest import responses - from saml2.saml import NAMEID_FORMAT_TRANSIENT from satosa.backends.oauth import FacebookBackend @@ -15,47 +14,54 @@ "name": "fb_name", "first_name": "fb_first_name", "last_name": "fb_last_name", - "picture": { - "data": { - "is_silhouette": False, - "url": "fb_picture" - } - }, + "picture": {"data": {"is_silhouette": False, "url": "fb_picture"}}, "email": "fb_email", "verified": True, "gender": "fb_gender", "timezone": 2, "locale": "sv_SE", - "updated_time": "2015-10-15T07:04:10+0000" + "updated_time": "2015-10-15T07:04:10+0000", } BASE_URL = "https://client.example.com" -AUTHZ_PAGE = 'facebook' +AUTHZ_PAGE = "facebook" CLIENT_ID = "facebook_client_id" FB_AUTH_ENDPOINT = "https://www.facebook.com/dialog/oauth" FB_CONFIG = { - 'server_info': { - 'authorization_endpoint': FB_AUTH_ENDPOINT, - 'token_endpoint': 'https://graph.facebook.com/v2.5/oauth/access_token' + "server_info": { + "authorization_endpoint": FB_AUTH_ENDPOINT, + "token_endpoint": "https://graph.facebook.com/v2.5/oauth/access_token", }, - 'client_secret': 'facebook_secret', - 'base_url': BASE_URL, - 'state_encryption_key': 'state_encryption_key', - 'encryption_key': 'encryption_key', - 'fields': ['id', 'name', 'first_name', 'last_name', 'middle_name', 'picture', 'email', - 'verified', 'gender', 'timezone', 'locale', 'updated_time'], - 'authz_page': AUTHZ_PAGE, - 'client_config': {'client_id': CLIENT_ID} + "client_secret": "facebook_secret", + "base_url": BASE_URL, + "state_encryption_key": "state_encryption_key", + "encryption_key": "encryption_key", + "fields": [ + "id", + "name", + "first_name", + "last_name", + "middle_name", + "picture", + "email", + "verified", + "gender", + "timezone", + "locale", + "updated_time", + ], + "authz_page": AUTHZ_PAGE, + "client_config": {"client_id": CLIENT_ID}, } FB_RESPONSE_CODE = "the_fb_code" INTERNAL_ATTRIBUTES = { - 'attributes': { - 'givenname': {'facebook': ['first_name']}, - 'mail': {'facebook': ['email']}, - 'edupersontargetedid': {'facebook': ['id']}, - 'name': {'facebook': ['name']}, - 'surname': {'facebook': ['last_name']}, - 'gender': {'facebook': ['gender']} + "attributes": { + "givenname": {"facebook": ["first_name"]}, + "mail": {"facebook": ["email"]}, + "edupersontargetedid": {"facebook": ["id"]}, + "name": {"facebook": ["name"]}, + "surname": {"facebook": ["last_name"]}, + "gender": {"facebook": ["gender"]}, } } @@ -65,26 +71,30 @@ class TestFacebookBackend(object): @pytest.fixture(autouse=True) def create_backend(self): - self.fb_backend = FacebookBackend(Mock(), INTERNAL_ATTRIBUTES, FB_CONFIG, "base_url", "facebook") + self.fb_backend = FacebookBackend( + Mock(), INTERNAL_ATTRIBUTES, FB_CONFIG, "base_url", "facebook" + ) @pytest.fixture def incoming_authn_response(self, context): - context.path = 'facebook/sso/redirect' + context.path = "facebook/sso/redirect" state_data = dict(state=mock_get_state.return_value) context.state[self.fb_backend.name] = state_data context.request = { "code": FB_RESPONSE_CODE, - "state": mock_get_state.return_value + "state": mock_get_state.return_value, } return context def setup_facebook_response(self): - responses.add(responses.GET, - "https://graph.facebook.com/v2.5/me", - body=json.dumps(FB_RESPONSE), - status=200, - content_type='application/json') + responses.add( + responses.GET, + "https://graph.facebook.com/v2.5/me", + body=json.dumps(FB_RESPONSE), + status=200, + content_type="application/json", + ) def assert_expected_attributes(self): expected_attributes = { @@ -107,13 +117,13 @@ def assert_token_request(self, request_args, state, **kwargs): def test_register_endpoints(self): url_map = self.fb_backend.register_endpoints() - expected_url_map = [('^facebook$', self.fb_backend._authn_response)] + expected_url_map = [("^facebook$", self.fb_backend._authn_response)] assert url_map == expected_url_map def test_start_auth(self, context): - context.path = 'facebook/sso/redirect' + context.path = "facebook/sso/redirect" internal_request = InternalData( - subject_type=NAMEID_FORMAT_TRANSIENT, requester='test_requester' + subject_type=NAMEID_FORMAT_TRANSIENT, requester="test_requester" ) resp = self.fb_backend.start_auth(context, internal_request, mock_get_state) @@ -123,7 +133,7 @@ def test_start_auth(self, context): "client_id": CLIENT_ID, "state": mock_get_state.return_value, "response_type": "code", - "redirect_uri": "%s/%s" % (BASE_URL, AUTHZ_PAGE) + "redirect_uri": "%s/%s" % (BASE_URL, AUTHZ_PAGE), } actual_params = dict(parse_qsl(urlparse(login_url).query)) assert actual_params == expected_params @@ -132,7 +142,9 @@ def test_start_auth(self, context): def test_authn_response(self, incoming_authn_response): self.setup_facebook_response() - mock_do_access_token_request = Mock(return_value={"access_token": "fb access token"}) + mock_do_access_token_request = Mock( + return_value={"access_token": "fb access token"} + ) self.fb_backend.consumer.do_access_token_request = mock_do_access_token_request self.fb_backend._authn_response(incoming_authn_response) @@ -143,24 +155,30 @@ def test_authn_response(self, incoming_authn_response): @responses.activate def test_entire_flow(self, context): """Tests start of authentication (incoming auth req) and receiving auth response.""" - responses.add(responses.POST, - "https://graph.facebook.com/v2.5/oauth/access_token", - body=json.dumps({"access_token": "qwerty", - "token_type": "bearer", - "expires_in": 9999999999999}), - status=200, - content_type='application/json') + responses.add( + responses.POST, + "https://graph.facebook.com/v2.5/oauth/access_token", + body=json.dumps( + { + "access_token": "qwerty", + "token_type": "bearer", + "expires_in": 9999999999999, + } + ), + status=200, + content_type="application/json", + ) self.setup_facebook_response() - context.path = 'facebook/sso/redirect' + context.path = "facebook/sso/redirect" internal_request = InternalData( - subject_type=NAMEID_FORMAT_TRANSIENT, requester='test_requester' + subject_type=NAMEID_FORMAT_TRANSIENT, requester="test_requester" ) self.fb_backend.start_auth(context, internal_request, mock_get_state) context.request = { "code": FB_RESPONSE_CODE, - "state": mock_get_state.return_value + "state": mock_get_state.return_value, } self.fb_backend._authn_response(context) self.assert_expected_attributes() diff --git a/tests/satosa/backends/test_openid_connect.py b/tests/satosa/backends/test_openid_connect.py index 34bac79fe..c28522d1c 100644 --- a/tests/satosa/backends/test_openid_connect.py +++ b/tests/satosa/backends/test_openid_connect.py @@ -2,7 +2,7 @@ import re import time from unittest.mock import Mock -from urllib.parse import urlparse, parse_qsl +from urllib.parse import parse_qsl, urlparse import oic import pytest @@ -12,7 +12,12 @@ from oic.oic.message import IdToken from oic.utils.authn.client import CLIENT_AUTHN_METHOD -from satosa.backends.openid_connect import OpenIDConnectBackend, _create_client, STATE_KEY, NONCE_KEY +from satosa.backends.openid_connect import ( + NONCE_KEY, + STATE_KEY, + OpenIDConnectBackend, + _create_client, +) from satosa.context import Context from satosa.internal import InternalData from satosa.response import Response @@ -25,7 +30,9 @@ class TestOpenIDConnectBackend(object): @pytest.fixture(autouse=True) def create_backend(self, internal_attributes, backend_config): - self.oidc_backend = OpenIDConnectBackend(Mock(), internal_attributes, backend_config, "base_url", "oidc") + self.oidc_backend = OpenIDConnectBackend( + Mock(), internal_attributes, backend_config, "base_url", "oidc" + ) @pytest.fixture def backend_config(self): @@ -39,12 +46,12 @@ def backend_config(self): "contacts": ["ops@example.com"], "redirect_uris": ["https://client.test.com/authz_cb"], "response_types": ["code"], - "subject_type": "pairwise" + "subject_type": "pairwise", }, "auth_req_params": { "response_type": "code id_token token", - "scope": "openid foo" - } + "scope": "openid foo", + }, }, "provider_metadata": { "issuer": ISSUER, @@ -52,8 +59,8 @@ def backend_config(self): "token_endpoint": ISSUER + "/token", "userinfo_endpoint": ISSUER + "/userinfo", "registration_endpoint": ISSUER + "/registration", - "jwks_uri": ISSUER + "/static/jwks" - } + "jwks_uri": ISSUER + "/static/jwks", + }, } @pytest.fixture @@ -63,7 +70,7 @@ def internal_attributes(self): "givenname": {"openid": ["given_name"]}, "mail": {"openid": ["email"]}, "edupersontargetedid": {"openid": ["sub"]}, - "surname": {"openid": ["family_name"]} + "surname": {"openid": ["family_name"]}, } } @@ -73,7 +80,7 @@ def userinfo(self): "given_name": "Test", "family_name": "Devsson", "email": "test_dev@example.com", - "sub": "username" + "sub": "username", } @pytest.fixture(scope="session") @@ -93,7 +100,8 @@ def setup_jwks_uri(self, jwks_uri, key): jwks_uri, body=json.dumps({"keys": [key.serialize()]}), status=200, - content_type="application/json") + content_type="application/json", + ) def setup_token_endpoint(self, token_endpoint_url, userinfo, signing_key): id_token_claims = { @@ -102,7 +110,7 @@ def setup_token_endpoint(self, token_endpoint_url, userinfo, signing_key): "aud": CLIENT_ID, "nonce": NONCE, "exp": time.time() + 3600, - "iat": time.time() + "iat": time.time(), } id_token = IdToken(**id_token_claims).to_jwt([signing_key], signing_key.alg) token_response = { @@ -110,23 +118,29 @@ def setup_token_endpoint(self, token_endpoint_url, userinfo, signing_key): "token_type": "Bearer", "refresh_token": "8xLOxBtZp8", "expires_in": 3600, - "id_token": id_token + "id_token": id_token, } - responses.add(responses.POST, - token_endpoint_url, - body=json.dumps(token_response), - status=200, - content_type="application/json") + responses.add( + responses.POST, + token_endpoint_url, + body=json.dumps(token_response), + status=200, + content_type="application/json", + ) def setup_userinfo_endpoint(self, userinfo_endpoint_url, userinfo): - responses.add(responses.GET, - userinfo_endpoint_url, - body=json.dumps(userinfo), - status=200, - content_type="application/json") + responses.add( + responses.GET, + userinfo_endpoint_url, + body=json.dumps(userinfo), + status=200, + content_type="application/json", + ) def get_redirect_uri_path(self, backend_config): - return urlparse(backend_config["client"]["client_metadata"]["redirect_uris"][0]).path.lstrip("/") + return urlparse( + backend_config["client"]["client_metadata"]["redirect_uris"][0] + ).path.lstrip("/") @pytest.fixture def incoming_authn_response(self, context, backend_config): @@ -134,13 +148,10 @@ def incoming_authn_response(self, context, backend_config): context.path = self.get_redirect_uri_path(backend_config) context.request = { "code": "F+R4uWbN46U+Bq9moQPC4lEvRd2De4o=", - "state": oidc_state + "state": oidc_state, } - state_data = { - STATE_KEY: oidc_state, - NONCE_KEY: NONCE - } + state_data = {STATE_KEY: oidc_state, NONCE_KEY: NONCE} context.state[self.oidc_backend.name] = state_data return context @@ -151,42 +162,78 @@ def test_register_endpoints(self, backend_config): assert re.search(regex, redirect_uri_path) assert callback == self.oidc_backend.response_endpoint - def test_translate_response_to_internal_response(self, internal_attributes, userinfo): + def test_translate_response_to_internal_response( + self, internal_attributes, userinfo + ): internal_response = self.oidc_backend._translate_response(userinfo, ISSUER) assert internal_response.subject_id == userinfo["sub"] - self.assert_expected_attributes(internal_attributes, userinfo, internal_response.attributes) + self.assert_expected_attributes( + internal_attributes, userinfo, internal_response.attributes + ) @responses.activate - def test_response_endpoint(self, backend_config, internal_attributes, userinfo, signing_key, incoming_authn_response): - self.setup_jwks_uri(backend_config["provider_metadata"]["jwks_uri"], signing_key) - self.setup_token_endpoint(backend_config["provider_metadata"]["token_endpoint"], userinfo, signing_key) - self.setup_userinfo_endpoint(backend_config["provider_metadata"]["userinfo_endpoint"], userinfo) + def test_response_endpoint( + self, + backend_config, + internal_attributes, + userinfo, + signing_key, + incoming_authn_response, + ): + self.setup_jwks_uri( + backend_config["provider_metadata"]["jwks_uri"], signing_key + ) + self.setup_token_endpoint( + backend_config["provider_metadata"]["token_endpoint"], userinfo, signing_key + ) + self.setup_userinfo_endpoint( + backend_config["provider_metadata"]["userinfo_endpoint"], userinfo + ) self.oidc_backend.response_endpoint(incoming_authn_response) args = self.oidc_backend.auth_callback_func.call_args[0] assert isinstance(args[0], Context) assert isinstance(args[1], InternalData) - self.assert_expected_attributes(internal_attributes, userinfo, args[1].attributes) + self.assert_expected_attributes( + internal_attributes, userinfo, args[1].attributes + ) - def test_start_auth_redirects_to_provider_authorization_endpoint(self, context, backend_config): + def test_start_auth_redirects_to_provider_authorization_endpoint( + self, context, backend_config + ): auth_response = self.oidc_backend.start_auth(context, None) assert isinstance(auth_response, Response) login_url = auth_response.message parsed = urlparse(login_url) - assert login_url.startswith(backend_config["provider_metadata"]["authorization_endpoint"]) + assert login_url.startswith( + backend_config["provider_metadata"]["authorization_endpoint"] + ) auth_params = dict(parse_qsl(parsed.query)) - assert auth_params["scope"] == backend_config["client"]["auth_req_params"]["scope"] - assert auth_params["response_type"] == backend_config["client"]["auth_req_params"]["response_type"] - assert auth_params["client_id"] == backend_config["client"]["client_metadata"]["client_id"] - assert auth_params["redirect_uri"] == backend_config["client"]["client_metadata"]["redirect_uris"][0] + assert ( + auth_params["scope"] == backend_config["client"]["auth_req_params"]["scope"] + ) + assert ( + auth_params["response_type"] + == backend_config["client"]["auth_req_params"]["response_type"] + ) + assert ( + auth_params["client_id"] + == backend_config["client"]["client_metadata"]["client_id"] + ) + assert ( + auth_params["redirect_uri"] + == backend_config["client"]["client_metadata"]["redirect_uris"][0] + ) assert "state" in auth_params assert "nonce" in auth_params @responses.activate def test_entire_flow(self, context, backend_config, internal_attributes, userinfo): - self.setup_userinfo_endpoint(backend_config["provider_metadata"]["userinfo_endpoint"], userinfo) + self.setup_userinfo_endpoint( + backend_config["provider_metadata"]["userinfo_endpoint"], userinfo + ) auth_response = self.oidc_backend.start_auth(context, None) auth_params = dict(parse_qsl(urlparse(auth_response.message).query)) @@ -198,7 +245,9 @@ def test_entire_flow(self, context, backend_config, internal_attributes, userinf } self.oidc_backend.response_endpoint(context) args = self.oidc_backend.auth_callback_func.call_args[0] - self.assert_expected_attributes(internal_attributes, userinfo, args[1].attributes) + self.assert_expected_attributes( + internal_attributes, userinfo, args[1].attributes + ) class TestCreateClient(object): @@ -208,7 +257,7 @@ def provider_metadata(self): "issuer": ISSUER, "authorization_endpoint": ISSUER + "/authorization", "token_endpoint": ISSUER + "/token", - "registration_endpoint": ISSUER + "/registration" + "registration_endpoint": ISSUER + "/registration", } @pytest.fixture @@ -217,9 +266,10 @@ def client_metadata(self): "client_id": "s6BhdRkqt3", "client_secret": "ZJYCqe3GGRvdrudKyZS0XhGv_Z45DuKhCUk0gBR1vZk", "application_type": "web", - "redirect_uris": - ["https://client.example.org/callback", - "https://client.example.org/callback2"], + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2", + ], "client_name": "SATOSA Test", "logo_uri": "https://client.example.org/logo.png", "subject_type": "pairwise", @@ -229,22 +279,34 @@ def client_metadata(self): } def assert_provider_metadata(self, provider_metadata, client): - assert client.authorization_endpoint == provider_metadata["authorization_endpoint"] + assert ( + client.authorization_endpoint == provider_metadata["authorization_endpoint"] + ) assert client.token_endpoint == provider_metadata["token_endpoint"] - assert client.registration_endpoint == provider_metadata["registration_endpoint"] - assert all(x in client.provider_info.to_dict().items() for x in provider_metadata.items()) + assert ( + client.registration_endpoint == provider_metadata["registration_endpoint"] + ) + assert all( + x in client.provider_info.to_dict().items() + for x in provider_metadata.items() + ) def assert_client_metadata(self, client_metadata, client): assert client.client_id == client_metadata["client_id"] assert client.client_secret == client_metadata["client_secret"] - assert all(x in client.registration_response.to_dict().items() for x in client_metadata.items()) + assert all( + x in client.registration_response.to_dict().items() + for x in client_metadata.items() + ) def test_init(self, provider_metadata, client_metadata): client = _create_client(provider_metadata, client_metadata) assert isinstance(client, oic.oic.Client) assert client.client_authn_method == CLIENT_AUTHN_METHOD - def test_supports_static_provider_discovery(self, provider_metadata, client_metadata): + def test_supports_static_provider_discovery( + self, provider_metadata, client_metadata + ): client = _create_client(provider_metadata, client_metadata) self.assert_provider_metadata(provider_metadata, client) @@ -255,24 +317,30 @@ def test_supports_dynamic_discovery(self, provider_metadata, client_metadata): ISSUER + "/.well-known/openid-configuration", body=json.dumps(provider_metadata), status=200, - content_type='application/json' + content_type="application/json", ) client = _create_client(dict(issuer=ISSUER), client_metadata) self.assert_provider_metadata(provider_metadata, client) - def test_supports_static_client_registration(self, provider_metadata, client_metadata): + def test_supports_static_client_registration( + self, provider_metadata, client_metadata + ): client = _create_client(provider_metadata, client_metadata) self.assert_client_metadata(client_metadata, client) - def test_supports_dynamic_client_registration(self, provider_metadata, client_metadata): + def test_supports_dynamic_client_registration( + self, provider_metadata, client_metadata + ): with responses.RequestsMock(assert_all_requests_are_fired=True) as rsps: rsps.add( responses.POST, provider_metadata["registration_endpoint"], body=json.dumps(client_metadata), status=200, - content_type='application/json' + content_type="application/json", + ) + client = _create_client( + provider_metadata, dict(redirect_uris=client_metadata["redirect_uris"]) ) - client = _create_client(provider_metadata, dict(redirect_uris=client_metadata["redirect_uris"])) self.assert_client_metadata(client_metadata, client) diff --git a/tests/satosa/backends/test_orcid.py b/tests/satosa/backends/test_orcid.py index 5120d4e89..386ee82c1 100644 --- a/tests/satosa/backends/test_orcid.py +++ b/tests/satosa/backends/test_orcid.py @@ -1,4 +1,7 @@ import json +from unittest.mock import Mock +from urllib.parse import parse_qsl, urljoin, urlparse + import pytest import responses @@ -6,14 +9,11 @@ from satosa.context import Context from satosa.internal import InternalData from satosa.response import Response -from unittest.mock import Mock -from urllib.parse import urljoin, urlparse, parse_qsl ORCID_PERSON_ID = "0000-0000-0000-0000" ORCID_PERSON_GIVEN_NAME = "orcid_given_name" ORCID_PERSON_FAMILY_NAME = "orcid_family_name" -ORCID_PERSON_NAME = "{} {}".format( - ORCID_PERSON_GIVEN_NAME, ORCID_PERSON_FAMILY_NAME) +ORCID_PERSON_NAME = "{} {}".format(ORCID_PERSON_GIVEN_NAME, ORCID_PERSON_FAMILY_NAME) ORCID_PERSON_EMAIL = "orcid_email" ORCID_PERSON_COUNTRY = "XX" @@ -28,13 +28,13 @@ def create_backend(self, internal_attributes, backend_config): internal_attributes, backend_config, backend_config["base_url"], - "orcid" + "orcid", ) @pytest.fixture def backend_config(self): return { - "authz_page": 'orcid/auth/callback', + "authz_page": "orcid/auth/callback", "base_url": "https://client.example.com", "client_config": {"client_id": "orcid_client_id"}, "client_secret": "orcid_secret", @@ -43,8 +43,8 @@ def backend_config(self): "server_info": { "authorization_endpoint": "https://orcid.org/oauth/authorize", "token_endpoint": "https://pub.orcid.org/oauth/token", - "user_info": "https://pub.orcid.org/v2.0/" - } + "user_info": "https://pub.orcid.org/v2.0/", + }, } @pytest.fixture @@ -70,18 +70,10 @@ def userinfo(self): }, "emails": { "email": [ - { - "email": ORCID_PERSON_EMAIL, - "verified": True, - "primary": True - } + {"email": ORCID_PERSON_EMAIL, "verified": True, "primary": True} ] }, - "addresses": { - "address": [ - {"country": {"value": ORCID_PERSON_COUNTRY}} - ] - } + "addresses": {"address": [{"country": {"value": ORCID_PERSON_COUNTRY}}]}, } @pytest.fixture @@ -91,14 +83,8 @@ def userinfo_private(self): "given-names": {"value": ORCID_PERSON_GIVEN_NAME}, "family-name": {"value": ORCID_PERSON_FAMILY_NAME}, }, - "emails": { - "email": [ - ] - }, - "addresses": { - "address": [ - ] - } + "emails": {"email": []}, + "addresses": {"address": []}, } def assert_expected_attributes(self, user_claims, actual_attributes): @@ -123,7 +109,7 @@ def setup_token_endpoint(self, token_endpoint_url): "token_type": "bearer", "expires_in": 9999999999999, "name": ORCID_PERSON_NAME, - "orcid": ORCID_PERSON_ID + "orcid": ORCID_PERSON_ID, } responses.add( @@ -131,17 +117,16 @@ def setup_token_endpoint(self, token_endpoint_url): token_endpoint_url, body=json.dumps(token_response), status=200, - content_type="application/json" + content_type="application/json", ) def setup_userinfo_endpoint(self, userinfo_endpoint_url, userinfo): responses.add( responses.GET, - urljoin(userinfo_endpoint_url, - '{}/person'.format(ORCID_PERSON_ID)), + urljoin(userinfo_endpoint_url, "{}/person".format(ORCID_PERSON_ID)), body=json.dumps(userinfo), status=200, - content_type="application/json" + content_type="application/json", ) @pytest.fixture @@ -151,36 +136,35 @@ def incoming_authn_response(self, context, backend_config): context.state[self.orcid_backend.name] = state_data context.request = { "code": "the_orcid_code", - "state": mock_get_state.return_value + "state": mock_get_state.return_value, } return context def test_start_auth(self, context, backend_config): - auth_response = self.orcid_backend.start_auth( - context, None, mock_get_state) + auth_response = self.orcid_backend.start_auth(context, None, mock_get_state) assert isinstance(auth_response, Response) login_url = auth_response.message parsed = urlparse(login_url) assert login_url.startswith( - backend_config["server_info"]["authorization_endpoint"]) + backend_config["server_info"]["authorization_endpoint"] + ) auth_params = dict(parse_qsl(parsed.query)) assert auth_params["scope"] == " ".join(backend_config["scope"]) assert auth_params["response_type"] == backend_config["response_type"] assert auth_params["client_id"] == backend_config["client_config"]["client_id"] assert auth_params["redirect_uri"] == "{}/{}".format( - backend_config["base_url"], - backend_config["authz_page"] + backend_config["base_url"], backend_config["authz_page"] ) assert auth_params["state"] == mock_get_state.return_value @responses.activate def test_authn_response(self, backend_config, userinfo, incoming_authn_response): - self.setup_token_endpoint( - backend_config["server_info"]["token_endpoint"]) + self.setup_token_endpoint(backend_config["server_info"]["token_endpoint"]) self.setup_userinfo_endpoint( - backend_config["server_info"]["user_info"], userinfo) + backend_config["server_info"]["user_info"], userinfo + ) self.orcid_backend._authn_response(incoming_authn_response) @@ -193,14 +177,11 @@ def test_authn_response(self, backend_config, userinfo, incoming_authn_response) @responses.activate def test_user_information(self, context, backend_config, userinfo): self.setup_userinfo_endpoint( - backend_config["server_info"]["user_info"], - userinfo + backend_config["server_info"]["user_info"], userinfo ) user_attributes = self.orcid_backend.user_information( - "orcid_access_token", - ORCID_PERSON_ID, - ORCID_PERSON_NAME + "orcid_access_token", ORCID_PERSON_ID, ORCID_PERSON_NAME ) assert user_attributes["address"] == ORCID_PERSON_COUNTRY @@ -214,14 +195,11 @@ def test_user_information(self, context, backend_config, userinfo): @responses.activate def test_user_information_private(self, context, backend_config, userinfo_private): self.setup_userinfo_endpoint( - backend_config["server_info"]["user_info"], - userinfo_private + backend_config["server_info"]["user_info"], userinfo_private ) user_attributes = self.orcid_backend.user_information( - "orcid_access_token", - ORCID_PERSON_ID, - ORCID_PERSON_NAME + "orcid_access_token", ORCID_PERSON_ID, ORCID_PERSON_NAME ) assert user_attributes["address"] == "" diff --git a/tests/satosa/backends/test_saml2.py b/tests/satosa/backends/test_saml2.py index e1cc96466..c4e6e5c54 100644 --- a/tests/satosa/backends/test_saml2.py +++ b/tests/satosa/backends/test_saml2.py @@ -7,36 +7,36 @@ from collections import Counter from datetime import datetime from unittest.mock import Mock, patch -from urllib.parse import urlparse, parse_qs, parse_qsl +from urllib.parse import parse_qs, parse_qsl, urlparse import pytest - import saml2 -from saml2 import BINDING_HTTP_REDIRECT, BINDING_HTTP_POST +from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT from saml2.authn_context import PASSWORD from saml2.config import IdPConfig, SPConfig from saml2.entity import Entity -from saml2.samlp import authn_request_from_string from saml2.s_utils import deflate_and_base64_encode +from saml2.samlp import authn_request_from_string from satosa.backends.saml2 import SAMLBackend from satosa.context import Context -from satosa.exception import SATOSAAuthenticationError -from satosa.exception import SATOSAMissingStateError +from satosa.exception import SATOSAAuthenticationError, SATOSAMissingStateError from satosa.internal import InternalData from tests.users import USERS -from tests.util import FakeIdP, create_metadata_from_config_dict, FakeSP +from tests.util import FakeIdP, FakeSP, create_metadata_from_config_dict -TEST_RESOURCE_BASE_PATH = os.path.join(os.path.dirname(__file__), "../../test_resources") +TEST_RESOURCE_BASE_PATH = os.path.join( + os.path.dirname(__file__), "../../test_resources" +) INTERNAL_ATTRIBUTES = { - 'attributes': { - 'displayname': {'saml': ['displayName']}, - 'givenname': {'saml': ['givenName']}, - 'mail': {'saml': ['email', 'emailAdress', 'mail']}, - 'edupersontargetedid': {'saml': ['eduPersonTargetedID']}, - 'name': {'saml': ['cn']}, - 'surname': {'saml': ['sn', 'surname']} + "attributes": { + "displayname": {"saml": ["displayName"]}, + "givenname": {"saml": ["givenName"]}, + "mail": {"saml": ["email", "emailAdress", "mail"]}, + "edupersontargetedid": {"saml": ["eduPersonTargetedID"]}, + "name": {"saml": ["cn"]}, + "surname": {"saml": ["sn", "surname"]}, } } @@ -48,27 +48,41 @@ def assert_redirect_to_discovery_server( ): assert redirect_response.status == "303 See Other" parsed = urlparse(redirect_response.message) - redirect_location = "{parsed.scheme}://{parsed.netloc}{parsed.path}".format(parsed=parsed) + redirect_location = "{parsed.scheme}://{parsed.netloc}{parsed.path}".format( + parsed=parsed + ) assert redirect_location == expected_discosrv_url request_params = dict(parse_qsl(parsed.query)) - assert request_params["return"] == sp_conf["service"]["sp"]["endpoints"]["discovery_response"][0][0] + assert ( + request_params["return"] + == sp_conf["service"]["sp"]["endpoints"]["discovery_response"][0][0] + ) assert request_params["entityID"] == sp_conf["entityid"] def assert_redirect_to_idp(redirect_response, idp_conf): assert redirect_response.status == "303 See Other" parsed = urlparse(redirect_response.message) - redirect_location = "{parsed.scheme}://{parsed.netloc}{parsed.path}".format(parsed=parsed) - assert redirect_location == idp_conf["service"]["idp"]["endpoints"]["single_sign_on_service"][0][0] + redirect_location = "{parsed.scheme}://{parsed.netloc}{parsed.path}".format( + parsed=parsed + ) + assert ( + redirect_location + == idp_conf["service"]["idp"]["endpoints"]["single_sign_on_service"][0][0] + ) assert "SAMLRequest" in parse_qs(parsed.query) def assert_authn_response(internal_resp): assert internal_resp.auth_info.auth_class_ref == PASSWORD - expected_data = {'surname': ['Testsson 1'], 'mail': ['test@example.com'], - 'displayname': ['Test Testsson'], 'givenname': ['Test 1'], - 'edupersontargetedid': ['one!for!all']} + expected_data = { + "surname": ["Testsson 1"], + "mail": ["test@example.com"], + "displayname": ["Test Testsson"], + "givenname": ["Test 1"], + "edupersontargetedid": ["one!for!all"], + } assert expected_data == internal_resp.attributes @@ -88,10 +102,13 @@ class TestSAMLBackend: @pytest.fixture(autouse=True) def create_backend(self, sp_conf, idp_conf): setup_test_config(sp_conf, idp_conf) - self.samlbackend = SAMLBackend(Mock(), INTERNAL_ATTRIBUTES, {"sp_config": sp_conf, - "disco_srv": DISCOSRV_URL}, - "base_url", - "samlbackend") + self.samlbackend = SAMLBackend( + Mock(), + INTERNAL_ATTRIBUTES, + {"sp_config": sp_conf, "disco_srv": DISCOSRV_URL}, + "base_url", + "samlbackend", + ) def test_register_endpoints(self, sp_conf): """ @@ -102,20 +119,23 @@ def get_path_from_url(url): return urlparse(url).path.lstrip("/") url_map = self.samlbackend.register_endpoints() - all_sp_endpoints = [get_path_from_url(v[0][0]) for v in sp_conf["service"]["sp"]["endpoints"].values()] + all_sp_endpoints = [ + get_path_from_url(v[0][0]) + for v in sp_conf["service"]["sp"]["endpoints"].values() + ] compiled_regex = [re.compile(regex) for regex, _ in url_map] for endp in all_sp_endpoints: assert any(p.match(endp) for p in compiled_regex) - def test_start_auth_defaults_to_redirecting_to_discovery_server(self, context, sp_conf): + def test_start_auth_defaults_to_redirecting_to_discovery_server( + self, context, sp_conf + ): resp = self.samlbackend.start_auth(context, InternalData()) assert_redirect_to_discovery_server(resp, sp_conf, DISCOSRV_URL) def test_discovery_server_set_in_context(self, context, sp_conf): - discosrv_url = 'https://my.org/saml_discovery_service' - context.decorate( - SAMLBackend.KEY_SAML_DISCOVERY_SERVICE_URL, discosrv_url - ) + discosrv_url = "https://my.org/saml_discovery_service" + context.decorate(SAMLBackend.KEY_SAML_DISCOVERY_SERVICE_URL, discosrv_url) resp = self.samlbackend.start_auth(context, InternalData()) assert_redirect_to_discovery_server(resp, sp_conf, discosrv_url) @@ -150,7 +170,8 @@ def test_full_flow(self, context, idp_conf, sp_conf): req_params["RelayState"], BINDING_HTTP_REDIRECT, "testuser1", - response_binding=response_binding) + response_binding=response_binding, + ) response_context = Context() response_context.request = fake_idp_resp response_context.state = request_context.state @@ -161,18 +182,25 @@ def test_full_flow(self, context, idp_conf, sp_conf): assert context.state[test_state_key] == "my_state" assert_authn_response(internal_resp) - def test_start_auth_redirects_directly_to_mirrored_idp( - self, context, idp_conf): + def test_start_auth_redirects_directly_to_mirrored_idp(self, context, idp_conf): entityid = idp_conf["entityid"] context.decorate(Context.KEY_TARGET_ENTITYID, entityid) resp = self.samlbackend.start_auth(context, InternalData()) assert_redirect_to_idp(resp, idp_conf) - def test_redirect_to_idp_if_only_one_idp_in_metadata(self, context, sp_conf, idp_conf): + def test_redirect_to_idp_if_only_one_idp_in_metadata( + self, context, sp_conf, idp_conf + ): sp_conf["metadata"]["inline"] = [create_metadata_from_config_dict(idp_conf)] # instantiate new backend, without any discovery service configured - samlbackend = SAMLBackend(None, INTERNAL_ATTRIBUTES, {"sp_config": sp_conf}, "base_url", "saml_backend") + samlbackend = SAMLBackend( + None, + INTERNAL_ATTRIBUTES, + {"sp_config": sp_conf}, + "base_url", + "saml_backend", + ) resp = samlbackend.start_auth(context, InternalData()) assert_redirect_to_idp(resp, idp_conf) @@ -181,7 +209,10 @@ def test_authn_request(self, context, idp_conf): resp = self.samlbackend.authn_request(context, idp_conf["entityid"]) assert_redirect_to_idp(resp, idp_conf) req_params = dict(parse_qsl(urlparse(resp.message).query)) - assert context.state[self.samlbackend.name]["relay_state"] == req_params["RelayState"] + assert ( + context.state[self.samlbackend.name]["relay_state"] + == req_params["RelayState"] + ) @pytest.mark.parametrize("hostname", ["example.com:8443", "example.net"]) @pytest.mark.parametrize( @@ -247,7 +278,9 @@ def test_authn_response(self, context, idp_conf, sp_conf): idp_conf, sp_conf, response_binding ) context.request = auth_resp - context.state[self.samlbackend.name] = {"relay_state": request_params["RelayState"]} + context.state[self.samlbackend.name] = { + "relay_state": request_params["RelayState"] + } self.samlbackend.authn_response(context, response_binding) context, internal_resp = self.samlbackend.auth_callback_func.call_args[0] @@ -300,8 +333,8 @@ def test_no_relay_state_raises_error(self, context, idp_conf, sp_conf): self.samlbackend.authn_response(context, response_binding) @pytest.mark.skipif( - saml2.__version__ < '4.6.1', - reason="Optional NameID needs pysaml2 v4.6.1 or higher" + saml2.__version__ < "4.6.1", + reason="Optional NameID needs pysaml2 v4.6.1 or higher", ) def test_authn_response_no_name_id(self, context, idp_conf, sp_conf): response_binding = BINDING_HTTP_REDIRECT @@ -322,13 +355,17 @@ def test_authn_response_no_name_id(self, context, idp_conf, sp_conf): assert_authn_response(internal_resp) def test_authn_response_with_encrypted_assertion(self, sp_conf, context): - with open(os.path.join( - TEST_RESOURCE_BASE_PATH, - "idp_metadata_for_encrypted_signed_auth_response.xml" - )) as idp_metadata_file: + with open( + os.path.join( + TEST_RESOURCE_BASE_PATH, + "idp_metadata_for_encrypted_signed_auth_response.xml", + ) + ) as idp_metadata_file: sp_conf["metadata"]["inline"] = [idp_metadata_file.read()] - sp_conf["entityid"] = "https://federation-dev-1.scienceforum.sc/Saml2/proxy_saml2_backend.xml" + sp_conf[ + "entityid" + ] = "https://federation-dev-1.scienceforum.sc/Saml2/proxy_saml2_backend.xml" samlbackend = SAMLBackend( Mock(), INTERNAL_ATTRIBUTES, @@ -339,13 +376,18 @@ def test_authn_response_with_encrypted_assertion(self, sp_conf, context): response_binding = BINDING_HTTP_REDIRECT relay_state = "test relay state" - with open(os.path.join( - TEST_RESOURCE_BASE_PATH, - "auth_response_with_encrypted_signed_assertion.xml" - )) as auth_response_file: + with open( + os.path.join( + TEST_RESOURCE_BASE_PATH, + "auth_response_with_encrypted_signed_assertion.xml", + ) + ) as auth_response_file: auth_response = auth_response_file.read() - context.request = {"SAMLResponse": deflate_and_base64_encode(auth_response), "RelayState": relay_state} + context.request = { + "SAMLResponse": deflate_and_base64_encode(auth_response), + "RelayState": relay_state, + } context.state[self.samlbackend.name] = {"relay_state": relay_state} with open( @@ -354,8 +396,9 @@ def test_authn_response_with_encrypted_assertion(self, sp_conf, context): samlbackend.encryption_keys = [encryption_key_file.read()] assertion_issued_at = 1479315212 - with patch('saml2.validate.time_util.shift_time') as mock_shift_time, \ - patch('saml2.validate.time_util.utc_now') as mock_utc_now: + with patch("saml2.validate.time_util.shift_time") as mock_shift_time, patch( + "saml2.validate.time_util.utc_now" + ) as mock_utc_now: mock_utc_now.return_value = assertion_issued_at + 1 mock_shift_time.side_effect = [ datetime.utcfromtimestamp(assertion_issued_at + 1), @@ -364,21 +407,35 @@ def test_authn_response_with_encrypted_assertion(self, sp_conf, context): samlbackend.authn_response(context, response_binding) context, internal_resp = samlbackend.auth_callback_func.call_args[0] - assert Counter(internal_resp.attributes.keys()) == Counter({"mail", "givenname", "displayname", "surname"}) + assert Counter(internal_resp.attributes.keys()) == Counter( + {"mail", "givenname", "displayname", "surname"} + ) def test_backend_reads_encryption_key_from_key_file(self, sp_conf): - sp_conf["key_file"] = os.path.join(TEST_RESOURCE_BASE_PATH, "encryption_key.pem") - samlbackend = SAMLBackend(Mock(), INTERNAL_ATTRIBUTES, {"sp_config": sp_conf, - "disco_srv": DISCOSRV_URL}, - "base_url", "samlbackend") + sp_conf["key_file"] = os.path.join( + TEST_RESOURCE_BASE_PATH, "encryption_key.pem" + ) + samlbackend = SAMLBackend( + Mock(), + INTERNAL_ATTRIBUTES, + {"sp_config": sp_conf, "disco_srv": DISCOSRV_URL}, + "base_url", + "samlbackend", + ) assert samlbackend.encryption_keys def test_backend_reads_encryption_key_from_encryption_keypair(self, sp_conf): del sp_conf["key_file"] - sp_conf["encryption_keypairs"] = [{"key_file": os.path.join(TEST_RESOURCE_BASE_PATH, "encryption_key.pem")}] - samlbackend = SAMLBackend(Mock(), INTERNAL_ATTRIBUTES, {"sp_config": sp_conf, - "disco_srv": DISCOSRV_URL}, - "base_url", "samlbackend") + sp_conf["encryption_keypairs"] = [ + {"key_file": os.path.join(TEST_RESOURCE_BASE_PATH, "encryption_key.pem")} + ] + samlbackend = SAMLBackend( + Mock(), + INTERNAL_ATTRIBUTES, + {"sp_config": sp_conf, "disco_srv": DISCOSRV_URL}, + "base_url", + "samlbackend", + ) assert samlbackend.encryption_keys def test_metadata_endpoint(self, context, sp_conf): @@ -390,19 +447,33 @@ def test_metadata_endpoint(self, context, sp_conf): def test_get_metadata_desc(self, sp_conf, idp_conf): sp_conf["metadata"]["inline"] = [create_metadata_from_config_dict(idp_conf)] # instantiate new backend, with a single backing IdP - samlbackend = SAMLBackend(None, INTERNAL_ATTRIBUTES, {"sp_config": sp_conf}, "base_url", "saml_backend") + samlbackend = SAMLBackend( + None, + INTERNAL_ATTRIBUTES, + {"sp_config": sp_conf}, + "base_url", + "saml_backend", + ) entity_descriptions = samlbackend.get_metadata_desc() assert len(entity_descriptions) == 1 idp_desc = entity_descriptions[0].to_dict() - assert idp_desc["entityid"] == urlsafe_b64encode(idp_conf["entityid"].encode("utf-8")).decode("utf-8") + assert idp_desc["entityid"] == urlsafe_b64encode( + idp_conf["entityid"].encode("utf-8") + ).decode("utf-8") assert idp_desc["contact_person"] == idp_conf["contact_person"] - assert idp_desc["organization"]["name"][0] == tuple(idp_conf["organization"]["name"][0]) - assert idp_desc["organization"]["display_name"][0] == tuple(idp_conf["organization"]["display_name"][0]) - assert idp_desc["organization"]["url"][0] == tuple(idp_conf["organization"]["url"][0]) + assert idp_desc["organization"]["name"][0] == tuple( + idp_conf["organization"]["name"][0] + ) + assert idp_desc["organization"]["display_name"][0] == tuple( + idp_conf["organization"]["display_name"][0] + ) + assert idp_desc["organization"]["url"][0] == tuple( + idp_conf["organization"]["url"][0] + ) expected_ui_info = idp_conf["service"]["idp"]["ui_info"] ui_info = idp_desc["service"]["idp"]["ui_info"] @@ -412,24 +483,43 @@ def test_get_metadata_desc(self, sp_conf, idp_conf): def test_get_metadata_desc_with_logo_without_lang(self, sp_conf, idp_conf): # add logo without 'lang' - idp_conf["service"]["idp"]["ui_info"]["logo"] = [{"text": "https://idp.example.com/static/logo.png", - "width": "120", "height": "60"}] + idp_conf["service"]["idp"]["ui_info"]["logo"] = [ + { + "text": "https://idp.example.com/static/logo.png", + "width": "120", + "height": "60", + } + ] sp_conf["metadata"]["inline"] = [create_metadata_from_config_dict(idp_conf)] # instantiate new backend, with a single backing IdP - samlbackend = SAMLBackend(None, INTERNAL_ATTRIBUTES, {"sp_config": sp_conf}, "base_url", "saml_backend") + samlbackend = SAMLBackend( + None, + INTERNAL_ATTRIBUTES, + {"sp_config": sp_conf}, + "base_url", + "saml_backend", + ) entity_descriptions = samlbackend.get_metadata_desc() assert len(entity_descriptions) == 1 idp_desc = entity_descriptions[0].to_dict() - assert idp_desc["entityid"] == urlsafe_b64encode(idp_conf["entityid"].encode("utf-8")).decode("utf-8") + assert idp_desc["entityid"] == urlsafe_b64encode( + idp_conf["entityid"].encode("utf-8") + ).decode("utf-8") assert idp_desc["contact_person"] == idp_conf["contact_person"] - assert idp_desc["organization"]["name"][0] == tuple(idp_conf["organization"]["name"][0]) - assert idp_desc["organization"]["display_name"][0] == tuple(idp_conf["organization"]["display_name"][0]) - assert idp_desc["organization"]["url"][0] == tuple(idp_conf["organization"]["url"][0]) + assert idp_desc["organization"]["name"][0] == tuple( + idp_conf["organization"]["name"][0] + ) + assert idp_desc["organization"]["display_name"][0] == tuple( + idp_conf["organization"]["display_name"][0] + ) + assert idp_desc["organization"]["url"][0] == tuple( + idp_conf["organization"]["url"][0] + ) expected_ui_info = idp_conf["service"]["idp"]["ui_info"] ui_info = idp_desc["service"]["idp"]["ui_info"] @@ -445,8 +535,16 @@ def test_default_redirect_to_discovery_service_if_using_mdq( # one IdP in the metadata, but MDQ also configured so should always redirect to the discovery service sp_conf["metadata"]["inline"] = [create_metadata_from_config_dict(idp_conf)] sp_conf["metadata"]["mdq"] = ["https://mdq.example.com"] - samlbackend = SAMLBackend(None, INTERNAL_ATTRIBUTES, {"sp_config": sp_conf, "disco_srv": DISCOSRV_URL,}, - "base_url", "saml_backend") + samlbackend = SAMLBackend( + None, + INTERNAL_ATTRIBUTES, + { + "sp_config": sp_conf, + "disco_srv": DISCOSRV_URL, + }, + "base_url", + "saml_backend", + ) resp = samlbackend.start_auth(context, InternalData()) assert_redirect_to_discovery_server(resp, sp_conf, DISCOSRV_URL) diff --git a/tests/satosa/frontends/test_openid_connect.py b/tests/satosa/frontends/test_openid_connect.py index f769b2c66..f623801d1 100644 --- a/tests/satosa/frontends/test_openid_connect.py +++ b/tests/satosa/frontends/test_openid_connect.py @@ -6,28 +6,35 @@ from base64 import urlsafe_b64encode from collections import Counter from unittest.mock import Mock -from urllib.parse import urlparse, parse_qsl +from urllib.parse import parse_qsl, urlparse import pytest -from oic.oic.message import AuthorizationResponse, AuthorizationRequest, IdToken, ClaimsRequest, \ - Claims, AuthorizationErrorResponse, RegistrationResponse, RegistrationRequest, \ - ClientRegistrationErrorResponse, ProviderConfigurationResponse, AccessTokenRequest, AccessTokenResponse, \ - TokenErrorResponse, OpenIDSchema -from oic.oic.provider import TokenEndpoint, UserinfoEndpoint, RegistrationEndpoint +from oic.oic.message import ( + AccessTokenRequest, + AccessTokenResponse, + AuthorizationErrorResponse, + AuthorizationRequest, + AuthorizationResponse, + Claims, + ClaimsRequest, + ClientRegistrationErrorResponse, + IdToken, + OpenIDSchema, + ProviderConfigurationResponse, + RegistrationRequest, + RegistrationResponse, + TokenErrorResponse, +) +from oic.oic.provider import RegistrationEndpoint, TokenEndpoint, UserinfoEndpoint from saml2.authn_context import PASSWORD from satosa.attribute_mapping import AttributeMapper from satosa.exception import SATOSAAuthenticationError from satosa.frontends.openid_connect import OpenIDConnectFrontend -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData -from tests.users import USERS -from tests.users import OIDC_USERS +from satosa.internal import AuthenticationInformation, InternalData +from tests.users import OIDC_USERS, USERS - -INTERNAL_ATTRIBUTES = { - "attributes": {"mail": {"saml": ["email"], "openid": ["email"]}} -} +INTERNAL_ATTRIBUTES = {"attributes": {"mail": {"saml": ["email"], "openid": ["email"]}}} BASE_URL = "https://op.example.com" CLIENT_ID = "client1" CLIENT_SECRET = "client_secret" @@ -45,6 +52,7 @@ "eduperson": ["eduperson_scoped_affiliation", "eduperson_principal_name"] } + class TestOpenIDConnectFrontend(object): @pytest.fixture def frontend_config(self, signing_key_path): @@ -52,8 +60,8 @@ def frontend_config(self, signing_key_path): "signing_key_path": signing_key_path, "provider": { "response_types_supported": ["code", "id_token", "code id_token token"], - "scopes_supported": ["openid", "email"] - } + "scopes_supported": ["openid", "email"], + }, } return config @@ -80,7 +88,7 @@ def frontend_config_with_extra_id_token_claims(self, signing_key_path): "scopes_supported": ["openid", "email"], "extra_id_token_claims": { CLIENT_ID: ["email"], - } + }, }, } @@ -88,8 +96,13 @@ def frontend_config_with_extra_id_token_claims(self, signing_key_path): def create_frontend(self, frontend_config): # will use in-memory storage - instance = OpenIDConnectFrontend(lambda ctx, req: None, INTERNAL_ATTRIBUTES, - frontend_config, BASE_URL, "oidc_frontend") + instance = OpenIDConnectFrontend( + lambda ctx, req: None, + INTERNAL_ATTRIBUTES, + frontend_config, + BASE_URL, + "oidc_frontend", + ) instance.register_endpoints(["foo_backend"]) return instance @@ -121,9 +134,15 @@ def authn_req(self): nonce = "nonce" redirect_uri = "https://client.example.com" claims_req = ClaimsRequest(id_token=Claims(email=None)) - req = AuthorizationRequest(client_id=CLIENT_ID, state=state, scope="openid", - response_type="id_token", redirect_uri=redirect_uri, - nonce=nonce, claims=claims_req) + req = AuthorizationRequest( + client_id=CLIENT_ID, + state=state, + scope="openid", + response_type="id_token", + redirect_uri=redirect_uri, + nonce=nonce, + claims=claims_req, + ) return req @pytest.fixture @@ -145,9 +164,12 @@ def authn_req_with_extra_scopes(self): def insert_client_in_client_db(self, frontend, redirect_uri, extra_metadata={}): frontend.provider.clients = { - CLIENT_ID: {"response_types": ["code", "id_token"], - "redirect_uris": [redirect_uri], - "client_secret": CLIENT_SECRET}} + CLIENT_ID: { + "response_types": ["code", "id_token"], + "redirect_uris": [redirect_uri], + "client_secret": CLIENT_SECRET, + } + } frontend.provider.clients[CLIENT_ID].update(extra_metadata) def insert_user_in_user_db(self, frontend, user_id): @@ -159,7 +181,9 @@ def insert_user_in_user_db(self, frontend, user_id): ) def create_access_token(self, frontend, user_id, auth_req): - sub = frontend.provider.authz_state.get_subject_identifier('pairwise', user_id, 'client1.example.com') + sub = frontend.provider.authz_state.get_subject_identifier( + "pairwise", user_id, "client1.example.com" + ) auth_req = AuthorizationRequest().from_dict(auth_req) access_token = frontend.provider.authz_state.create_access_token(auth_req, sub) return access_token.value @@ -198,7 +222,9 @@ def test_handle_authn_request(self, context, frontend, authn_req): mock_callback = Mock() frontend.auth_req_callback_func = mock_callback client_name = "test client" - self.insert_client_in_client_db(frontend, authn_req["redirect_uri"], {"client_name": client_name}) + self.insert_client_in_client_db( + frontend, authn_req["redirect_uri"], {"client_name": client_name} + ) context.request = dict(parse_qsl(authn_req.to_urlencoded())) frontend.handle_authn_request(context) @@ -207,7 +233,7 @@ def test_handle_authn_request(self, context, frontend, authn_req): context, internal_req = mock_callback.call_args[0] assert internal_req.requester == authn_req["client_id"] assert internal_req.requester_name == [{"lang": "en", "text": client_name}] - assert internal_req.subject_type == 'pairwise' + assert internal_req.subject_type == "pairwise" assert internal_req.attributes == ["mail"] def test_handle_authn_request_with_extra_scopes( @@ -233,17 +259,32 @@ def test_handle_authn_request_with_extra_scopes( ] def test_get_approved_attributes(self, frontend): - claims_req = ClaimsRequest(id_token=Claims(email=None), userinfo=Claims(userinfo_claim=None)) + claims_req = ClaimsRequest( + id_token=Claims(email=None), userinfo=Claims(userinfo_claim=None) + ) req = AuthorizationRequest(scope="openid profile", claims=claims_req) - provider_supported_claims = ["email", "name", "given_name", "family_name", "userinfo_claim", "extra_claim"] + provider_supported_claims = [ + "email", + "name", + "given_name", + "family_name", + "userinfo_claim", + "extra_claim", + ] result = frontend._get_approved_attributes(provider_supported_claims, req) - assert Counter(result) == Counter(["email", "name", "given_name", "family_name", "userinfo_claim"]) + assert Counter(result) == Counter( + ["email", "name", "given_name", "family_name", "userinfo_claim"] + ) def test_handle_backend_error(self, context, frontend): redirect_uri = "https://client.example.com" - areq = AuthorizationRequest(client_id=CLIENT_ID, scope="openid", response_type="id_token", - redirect_uri=redirect_uri) + areq = AuthorizationRequest( + client_id=CLIENT_ID, + scope="openid", + response_type="id_token", + redirect_uri=redirect_uri, + ) context.state[frontend.name] = {"oidc_request": areq.to_urlencoded()} # fake an error @@ -252,32 +293,39 @@ def test_handle_backend_error(self, context, frontend): resp = frontend.handle_backend_error(error) assert resp.message.startswith(redirect_uri) - error_response = AuthorizationErrorResponse().deserialize(urlparse(resp.message).fragment) + error_response = AuthorizationErrorResponse().deserialize( + urlparse(resp.message).fragment + ) error_response["error"] = "access_denied" error_response["error_description"] == message def test_register_client(self, context, frontend): redirect_uri = "https://client.example.com" - registration_request = RegistrationRequest(redirect_uris=[redirect_uri], - response_types=["id_token"]) + registration_request = RegistrationRequest( + redirect_uris=[redirect_uri], response_types=["id_token"] + ) context.request = registration_request.to_dict() registration_response = frontend.client_registration(context) assert registration_response.status == "201 Created" - reg_resp = RegistrationResponse().deserialize(registration_response.message, "json") + reg_resp = RegistrationResponse().deserialize( + registration_response.message, "json" + ) assert "client_id" in reg_resp assert reg_resp["redirect_uris"] == [redirect_uri] assert reg_resp["response_types"] == ["id_token"] def test_register_client_with_wrong_response_type(self, context, frontend): redirect_uri = "https://client.example.com" - registration_request = RegistrationRequest(redirect_uris=[redirect_uri], - response_types=["id_token token"]) + registration_request = RegistrationRequest( + redirect_uris=[redirect_uri], response_types=["id_token token"] + ) context.request = registration_request.to_dict() registration_response = frontend.client_registration(context) assert registration_response.status == "400 Bad Request" error_response = ClientRegistrationErrorResponse().deserialize( - registration_response.message, "json") + registration_response.message, "json" + ) assert error_response["error"] == "invalid_request" assert "response_type" in error_response["error_description"] @@ -285,7 +333,9 @@ def test_provider_configuration_endpoint(self, context, frontend): expected_capabilities = { "response_types_supported": ["code", "id_token", "code id_token token"], "jwks_uri": "{}/{}/jwks".format(BASE_URL, frontend.name), - "authorization_endpoint": "{}/foo_backend/{}/authorization".format(BASE_URL, frontend.name), + "authorization_endpoint": "{}/foo_backend/{}/authorization".format( + BASE_URL, frontend.name + ), "token_endpoint": "{}/{}/token".format(BASE_URL, frontend.name), "userinfo_endpoint": "{}/{}/userinfo".format(BASE_URL, frontend.name), "id_token_signing_alg_values_supported": ["RS256"], @@ -300,11 +350,13 @@ def test_provider_configuration_endpoint(self, context, frontend): "issuer": BASE_URL, "require_request_uri_registration": False, "token_endpoint_auth_methods_supported": ["client_secret_basic"], - "version": "3.0" + "version": "3.0", } http_response = frontend.provider_config(context) - provider_config = ProviderConfigurationResponse().deserialize(http_response.message, "json") + provider_config = ProviderConfigurationResponse().deserialize( + http_response.message, "json" + ) provider_config_dict = provider_config.to_dict() scopes_supported = provider_config_dict.pop("scopes_supported") @@ -369,44 +421,70 @@ def test_jwks(self, context, frontend): jwks = json.loads(http_response.message) assert jwks == {"keys": [frontend.signing_key.serialize()]} - def test_register_endpoints_token_and_userinfo_endpoint_is_published_if_necessary(self, frontend): + def test_register_endpoints_token_and_userinfo_endpoint_is_published_if_necessary( + self, frontend + ): urls = frontend.register_endpoints(["test"]) - assert ("^{}/{}".format(frontend.name, TokenEndpoint.url), frontend.token_endpoint) in urls - assert ("^{}/{}".format(frontend.name, UserinfoEndpoint.url), frontend.userinfo_endpoint) in urls + assert ( + "^{}/{}".format(frontend.name, TokenEndpoint.url), + frontend.token_endpoint, + ) in urls + assert ( + "^{}/{}".format(frontend.name, UserinfoEndpoint.url), + frontend.userinfo_endpoint, + ) in urls def test_register_endpoints_token_and_userinfo_endpoint_is_not_published_if_only_implicit_flow( - self, frontend_config, context): - frontend_config["provider"]["response_types_supported"] = ["id_token", "id_token token"] + self, frontend_config, context + ): + frontend_config["provider"]["response_types_supported"] = [ + "id_token", + "id_token token", + ] frontend = self.create_frontend(frontend_config) urls = frontend.register_endpoints(["test"]) - assert ("^{}/{}".format("test", TokenEndpoint.url), frontend.token_endpoint) not in urls - assert ("^{}/{}".format("test", UserinfoEndpoint.url), frontend.userinfo_endpoint) not in urls + assert ( + "^{}/{}".format("test", TokenEndpoint.url), + frontend.token_endpoint, + ) not in urls + assert ( + "^{}/{}".format("test", UserinfoEndpoint.url), + frontend.userinfo_endpoint, + ) not in urls http_response = frontend.provider_config(context) - provider_config = ProviderConfigurationResponse().deserialize(http_response.message, "json") + provider_config = ProviderConfigurationResponse().deserialize( + http_response.message, "json" + ) assert "token_endpoint" not in provider_config - @pytest.mark.parametrize("client_registration_enabled", [ - True, - False - ]) + @pytest.mark.parametrize("client_registration_enabled", [True, False]) def test_register_endpoints_dynamic_client_registration_is_configurable( - self, frontend_config, client_registration_enabled): - frontend_config["provider"]["client_registration_supported"] = client_registration_enabled + self, frontend_config, client_registration_enabled + ): + frontend_config["provider"][ + "client_registration_supported" + ] = client_registration_enabled frontend = self.create_frontend(frontend_config) urls = frontend.register_endpoints(["test"]) - assert (("^{}/{}".format(frontend.name, RegistrationEndpoint.url), - frontend.client_registration) in urls) == client_registration_enabled - provider_info = ProviderConfigurationResponse().deserialize(frontend.provider_config(None).message, "json") + assert ( + ( + "^{}/{}".format(frontend.name, RegistrationEndpoint.url), + frontend.client_registration, + ) + in urls + ) == client_registration_enabled + provider_info = ProviderConfigurationResponse().deserialize( + frontend.provider_config(None).message, "json" + ) assert ("registration_endpoint" in provider_info) == client_registration_enabled - @pytest.mark.parametrize("sub_mirror_public", [ - True, - False - ]) - def test_mirrored_subject(self, context, frontend_config, authn_req, sub_mirror_public): + @pytest.mark.parametrize("sub_mirror_public", [True, False]) + def test_mirrored_subject( + self, context, frontend_config, authn_req, sub_mirror_public + ): frontend_config["sub_mirror_public"] = sub_mirror_public frontend_config["provider"]["subject_types_supported"] = ["public"] frontend = self.create_frontend(frontend_config) @@ -433,7 +511,9 @@ def test_token_endpoint(self, context, frontend_config, authn_req): authn_req["response_type"] = "code" authn_resp = frontend.provider.authorize(authn_req, user_id) - context.request = AccessTokenRequest(redirect_uri=authn_req["redirect_uri"], code=authn_resp["code"]).to_dict() + context.request = AccessTokenRequest( + redirect_uri=authn_req["redirect_uri"], code=authn_resp["code"] + ).to_dict() credentials = "{}:{}".format(CLIENT_ID, CLIENT_SECRET) basic_auth = urlsafe_b64encode(credentials.encode("utf-8")).decode("utf-8") context.request_authorization = "Basic {}".format(basic_auth) @@ -444,7 +524,9 @@ def test_token_endpoint(self, context, frontend_config, authn_req): assert parsed["expires_in"] == token_lifetime assert parsed["id_token"] - def test_token_endpoint_with_extra_claims(self, context, frontend_config_with_extra_id_token_claims, authn_req): + def test_token_endpoint_with_extra_claims( + self, context, frontend_config_with_extra_id_token_claims, authn_req + ): frontend = self.create_frontend(frontend_config_with_extra_id_token_claims) user_id = "test_user" @@ -453,7 +535,9 @@ def test_token_endpoint_with_extra_claims(self, context, frontend_config_with_ex authn_req["response_type"] = "code" authn_resp = frontend.provider.authorize(authn_req, user_id) - context.request = AccessTokenRequest(redirect_uri=authn_req["redirect_uri"], code=authn_resp["code"]).to_dict() + context.request = AccessTokenRequest( + redirect_uri=authn_req["redirect_uri"], code=authn_resp["code"] + ).to_dict() credentials = "{}:{}".format(CLIENT_ID, CLIENT_SECRET) basic_auth = urlsafe_b64encode(credentials.encode("utf-8")).decode("utf-8") context.request_authorization = "Basic {}".format(basic_auth) @@ -465,10 +549,17 @@ def test_token_endpoint_with_extra_claims(self, context, frontend_config_with_ex id_token = IdToken().from_jwt(parsed["id_token"], key=[frontend.signing_key]) assert id_token["email"] == "test@example.com" - def test_token_endpoint_issues_refresh_tokens_if_configured(self, context, frontend_config, authn_req): + def test_token_endpoint_issues_refresh_tokens_if_configured( + self, context, frontend_config, authn_req + ): frontend_config["provider"]["refresh_token_lifetime"] = 60 * 60 * 24 * 365 - frontend = OpenIDConnectFrontend(lambda ctx, req: None, INTERNAL_ATTRIBUTES, - frontend_config, BASE_URL, "oidc_frontend") + frontend = OpenIDConnectFrontend( + lambda ctx, req: None, + INTERNAL_ATTRIBUTES, + frontend_config, + BASE_URL, + "oidc_frontend", + ) frontend.register_endpoints(["test_backend"]) user_id = "test_user" @@ -477,7 +568,9 @@ def test_token_endpoint_issues_refresh_tokens_if_configured(self, context, front authn_req["response_type"] = "code" authn_resp = frontend.provider.authorize(authn_req, user_id) - context.request = AccessTokenRequest(redirect_uri=authn_req["redirect_uri"], code=authn_resp["code"]).to_dict() + context.request = AccessTokenRequest( + redirect_uri=authn_req["redirect_uri"], code=authn_resp["code"] + ).to_dict() credentials = "{}:{}".format(CLIENT_ID, CLIENT_SECRET) basic_auth = urlsafe_b64encode(credentials.encode("utf-8")).decode("utf-8") context.request_authorization = "Basic {}".format(basic_auth) @@ -486,8 +579,12 @@ def test_token_endpoint_issues_refresh_tokens_if_configured(self, context, front parsed = AccessTokenResponse().deserialize(response.message, "json") assert parsed["refresh_token"] - def test_token_endpoint_with_invalid_client_authentication(self, context, frontend, authn_req): - context.request = AccessTokenRequest(redirect_uri=authn_req["redirect_uri"], code="code").to_dict() + def test_token_endpoint_with_invalid_client_authentication( + self, context, frontend, authn_req + ): + context.request = AccessTokenRequest( + redirect_uri=authn_req["redirect_uri"], code="code" + ).to_dict() credentials = "{}:{}".format("unknown", "unknown") basic_auth = urlsafe_b64encode(credentials.encode("utf-8")).decode("utf-8") context.request_authorization = "Basic {}".format(basic_auth) @@ -499,7 +596,9 @@ def test_token_endpoint_with_invalid_client_authentication(self, context, fronte def test_token_endpoint_with_invalid_code(self, context, frontend, authn_req): self.insert_client_in_client_db(frontend, authn_req["redirect_uri"]) - context.request = AccessTokenRequest(redirect_uri=authn_req["redirect_uri"], code="invalid").to_dict() + context.request = AccessTokenRequest( + redirect_uri=authn_req["redirect_uri"], code="invalid" + ).to_dict() credentials = "{}:{}".format(CLIENT_ID, CLIENT_SECRET) basic_auth = urlsafe_b64encode(credentials.encode("utf-8")).decode("utf-8") context.request_authorization = "Basic {}".format(basic_auth) @@ -560,10 +659,14 @@ def test_full_flow(self, context, frontend_with_extra_scopes): _ = ProviderConfigurationResponse().deserialize(http_response.message, "json") # client registration - registration_request = RegistrationRequest(redirect_uris=[redirect_uri], response_types=[response_type]) + registration_request = RegistrationRequest( + redirect_uris=[redirect_uri], response_types=[response_type] + ) context.request = registration_request.to_dict() http_response = frontend_with_extra_scopes.client_registration(context) - registration_response = RegistrationResponse().deserialize(http_response.message, "json") + registration_response = RegistrationResponse().deserialize( + http_response.message, "json" + ) # authentication request authn_req = AuthorizationRequest( @@ -585,14 +688,20 @@ def test_full_flow(self, context, frontend_with_extra_scopes): http_response = frontend_with_extra_scopes.handle_authn_response( context, internal_response ) - authn_resp = AuthorizationResponse().deserialize(urlparse(http_response.message).fragment, "urlencoded") + authn_resp = AuthorizationResponse().deserialize( + urlparse(http_response.message).fragment, "urlencoded" + ) assert "code" in authn_resp assert "access_token" in authn_resp assert "id_token" in authn_resp # token request - context.request = AccessTokenRequest(redirect_uri=authn_req["redirect_uri"], code=authn_resp["code"]).to_dict() - credentials = "{}:{}".format(registration_response["client_id"], registration_response["client_secret"]) + context.request = AccessTokenRequest( + redirect_uri=authn_req["redirect_uri"], code=authn_resp["code"] + ).to_dict() + credentials = "{}:{}".format( + registration_response["client_id"], registration_response["client_secret"] + ) basic_auth = urlsafe_b64encode(credentials.encode("utf-8")).decode("utf-8") context.request_authorization = "Basic {}".format(basic_auth) diff --git a/tests/satosa/frontends/test_saml2.py b/tests/satosa/frontends/test_saml2.py index 978489429..25fc2621d 100644 --- a/tests/satosa/frontends/test_saml2.py +++ b/tests/satosa/frontends/test_saml2.py @@ -5,29 +5,40 @@ import itertools import re from collections import Counter -from urllib.parse import urlparse, parse_qs +from urllib.parse import parse_qs, urlparse import pytest -from saml2 import BINDING_HTTP_REDIRECT, BINDING_HTTP_POST +from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT from saml2.authn_context import PASSWORD from saml2.config import SPConfig -from saml2.entity_category import refeds, swamid, edugain +from saml2.entity_category import edugain, refeds, swamid from saml2.entity_category.edugain import COCO from saml2.entity_category.refeds import RESEARCH_AND_SCHOLARSHIP -from saml2.entity_category.swamid import SFS_1993_1153, RESEARCH_AND_EDUCATION, EU, HEI, NREN -from saml2.saml import NAMEID_FORMAT_TRANSIENT -from saml2.saml import NAMEID_FORMAT_PERSISTENT -from saml2.saml import NAMEID_FORMAT_EMAILADDRESS -from saml2.saml import NAMEID_FORMAT_UNSPECIFIED -from saml2.saml import NameID, Subject +from saml2.entity_category.swamid import ( + EU, + HEI, + NREN, + RESEARCH_AND_EDUCATION, + SFS_1993_1153, +) +from saml2.saml import ( + NAMEID_FORMAT_EMAILADDRESS, + NAMEID_FORMAT_PERSISTENT, + NAMEID_FORMAT_TRANSIENT, + NAMEID_FORMAT_UNSPECIFIED, + NameID, + Subject, +) from saml2.samlp import NameIDPolicy from satosa.attribute_mapping import AttributeMapper -from satosa.frontends.saml2 import SAMLFrontend, SAMLMirrorFrontend -from satosa.frontends.saml2 import SAMLVirtualCoFrontend -from satosa.frontends.saml2 import subject_type_to_saml_nameid_format -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.frontends.saml2 import ( + SAMLFrontend, + SAMLMirrorFrontend, + SAMLVirtualCoFrontend, + subject_type_to_saml_nameid_format, +) +from satosa.internal import AuthenticationInformation, InternalData from satosa.state import State from tests.users import USERS from tests.util import FakeSP, create_metadata_from_config_dict @@ -39,37 +50,58 @@ "mail": {"saml": ["email", "emailAdress", "mail"]}, "edupersontargetedid": {"saml": ["eduPersonTargetedID"]}, "name": {"saml": ["cn"]}, - "surname": {"saml": ["sn", "surname"]} + "surname": {"saml": ["sn", "surname"]}, } } -ENDPOINTS = {"single_sign_on_service": {BINDING_HTTP_REDIRECT: "sso/redirect", - BINDING_HTTP_POST: "sso/post"}} +ENDPOINTS = { + "single_sign_on_service": { + BINDING_HTTP_REDIRECT: "sso/redirect", + BINDING_HTTP_POST: "sso/post", + } +} BASE_URL = "https://satosa-idp.example.com" class TestSAMLFrontend: @pytest.fixture def internal_response(self, idp_conf): - auth_info = AuthenticationInformation(PASSWORD, "2015-09-30T12:21:37Z", idp_conf["entityid"]) + auth_info = AuthenticationInformation( + PASSWORD, "2015-09-30T12:21:37Z", idp_conf["entityid"] + ) internal_response = InternalData(auth_info=auth_info) - internal_response.attributes = AttributeMapper(INTERNAL_ATTRIBUTES).to_internal("saml", USERS["testuser1"]) + internal_response.attributes = AttributeMapper(INTERNAL_ATTRIBUTES).to_internal( + "saml", USERS["testuser1"] + ) return internal_response def construct_base_url_from_entity_id(self, entity_id): return "{parsed.scheme}://{parsed.netloc}".format(parsed=urlparse(entity_id)) - def setup_for_authn_req(self, context, idp_conf, sp_conf, nameid_format=None, relay_state="relay_state", - internal_attributes=INTERNAL_ATTRIBUTES, extra_config={}, - subject=None): + def setup_for_authn_req( + self, + context, + idp_conf, + sp_conf, + nameid_format=None, + relay_state="relay_state", + internal_attributes=INTERNAL_ATTRIBUTES, + extra_config={}, + subject=None, + ): config = {"idp_config": idp_conf, "endpoints": ENDPOINTS} config.update(extra_config) sp_metadata_str = create_metadata_from_config_dict(sp_conf) idp_conf["metadata"]["inline"] = [sp_metadata_str] base_url = self.construct_base_url_from_entity_id(idp_conf["entityid"]) - samlfrontend = SAMLFrontend(lambda ctx, internal_req: (ctx, internal_req), - internal_attributes, config, base_url, "saml_frontend") + samlfrontend = SAMLFrontend( + lambda ctx, internal_req: (ctx, internal_req), + internal_attributes, + config, + base_url, + "saml_frontend", + ) samlfrontend.register_endpoints(["saml"]) idp_metadata_str = create_metadata_from_config_dict(samlfrontend.idp_config) @@ -93,14 +125,18 @@ def setup_for_authn_req(self, context, idp_conf, sp_conf, nameid_format=None, re return samlfrontend - def get_auth_response(self, samlfrontend, context, internal_response, sp_conf, idp_metadata_str): + def get_auth_response( + self, samlfrontend, context, internal_response, sp_conf, idp_metadata_str + ): sp_config = SPConfig().load(sp_conf) resp_args = { "name_id_policy": NameIDPolicy(format=NAMEID_FORMAT_TRANSIENT), "in_response_to": None, - "destination": sp_config.endpoint("assertion_consumer_service", binding=BINDING_HTTP_REDIRECT)[0], + "destination": sp_config.endpoint( + "assertion_consumer_service", binding=BINDING_HTTP_REDIRECT + )[0], "sp_entity_id": sp_conf["entityid"], - "binding": BINDING_HTTP_REDIRECT + "binding": BINDING_HTTP_REDIRECT, } request_state = samlfrontend._create_state_data(context, resp_args, "") context.state[samlfrontend.name] = request_state @@ -110,16 +146,27 @@ def get_auth_response(self, samlfrontend, context, internal_response, sp_conf, i sp_conf["metadata"]["inline"].append(idp_metadata_str) fakesp = FakeSP(sp_config) resp_dict = parse_qs(urlparse(resp.message).query) - return fakesp.parse_authn_request_response(resp_dict["SAMLResponse"][0], BINDING_HTTP_REDIRECT) + return fakesp.parse_authn_request_response( + resp_dict["SAMLResponse"][0], BINDING_HTTP_REDIRECT + ) - @pytest.mark.parametrize("conf", [ - None, - {"idp_config_notok": {}, "endpoints": {}}, - {"idp_config": {}, "endpoints_notok": {}} - ]) + @pytest.mark.parametrize( + "conf", + [ + None, + {"idp_config_notok": {}, "endpoints": {}}, + {"idp_config": {}, "endpoints_notok": {}}, + ], + ) def test_config_error_handling(self, conf): with pytest.raises(ValueError): - SAMLFrontend(lambda ctx, req: None, INTERNAL_ATTRIBUTES, conf, "base_url", "saml_frontend") + SAMLFrontend( + lambda ctx, req: None, + INTERNAL_ATTRIBUTES, + conf, + "base_url", + "saml_frontend", + ) def test_register_endpoints(self, idp_conf): """ @@ -132,66 +179,100 @@ def get_path_from_url(url): config = {"idp_config": idp_conf, "endpoints": ENDPOINTS} base_url = self.construct_base_url_from_entity_id(idp_conf["entityid"]) - samlfrontend = SAMLFrontend(lambda context, internal_req: (context, internal_req), - INTERNAL_ATTRIBUTES, config, base_url, "saml_frontend") + samlfrontend = SAMLFrontend( + lambda context, internal_req: (context, internal_req), + INTERNAL_ATTRIBUTES, + config, + base_url, + "saml_frontend", + ) providers = ["foo", "bar"] url_map = samlfrontend.register_endpoints(providers) - all_idp_endpoints = [get_path_from_url(v[0][0]) for v in idp_conf["service"]["idp"]["endpoints"].values()] + all_idp_endpoints = [ + get_path_from_url(v[0][0]) + for v in idp_conf["service"]["idp"]["endpoints"].values() + ] compiled_regex = [re.compile(regex) for regex, _ in url_map] for endp in all_idp_endpoints: assert any(p.match(endp) for p in compiled_regex) def test_handle_authn_request(self, context, idp_conf, sp_conf, internal_response): samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf) - _, internal_req = samlfrontend.handle_authn_request(context, BINDING_HTTP_REDIRECT) + _, internal_req = samlfrontend.handle_authn_request( + context, BINDING_HTTP_REDIRECT + ) assert internal_req.requester == sp_conf["entityid"] resp = samlfrontend.handle_authn_response(context, internal_response) resp_dict = parse_qs(urlparse(resp.message).query) fakesp = FakeSP(SPConfig().load(sp_conf)) - resp = fakesp.parse_authn_request_response(resp_dict["SAMLResponse"][0], - BINDING_HTTP_REDIRECT) + resp = fakesp.parse_authn_request_response( + resp_dict["SAMLResponse"][0], BINDING_HTTP_REDIRECT + ) for key in resp.ava: assert USERS["testuser1"][key] == resp.ava[key] - def test_create_authn_request_with_subject(self, context, idp_conf, sp_conf, internal_response): - name_id_value = 'somenameid' + def test_create_authn_request_with_subject( + self, context, idp_conf, sp_conf, internal_response + ): + name_id_value = "somenameid" name_id = NameID(format=NAMEID_FORMAT_UNSPECIFIED, text=name_id_value) subject = Subject(name_id=name_id) samlfrontend = self.setup_for_authn_req( context, idp_conf, sp_conf, subject=subject ) - _, internal_req = samlfrontend.handle_authn_request(context, BINDING_HTTP_REDIRECT) + _, internal_req = samlfrontend.handle_authn_request( + context, BINDING_HTTP_REDIRECT + ) assert internal_req.subject_id == name_id_value # XXX TODO how should type be handled? # assert internal_req.subject_type == NAMEID_FORMAT_UNSPECIFIED def test_handle_authn_request_without_name_id_policy_default_to_name_id_format_from_metadata( - self, context, idp_conf, sp_conf): - samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf, nameid_format="") - _, internal_req = samlfrontend.handle_authn_request(context, BINDING_HTTP_REDIRECT) - assert internal_req.subject_type == sp_conf["service"]["sp"]["name_id_format"][0] + self, context, idp_conf, sp_conf + ): + samlfrontend = self.setup_for_authn_req( + context, idp_conf, sp_conf, nameid_format="" + ) + _, internal_req = samlfrontend.handle_authn_request( + context, BINDING_HTTP_REDIRECT + ) + assert ( + internal_req.subject_type == sp_conf["service"]["sp"]["name_id_format"][0] + ) def test_handle_authn_request_without_name_id_policy_and_metadata_without_name_id_format( - self, context, idp_conf, sp_conf): + self, context, idp_conf, sp_conf + ): del sp_conf["service"]["sp"]["name_id_format"] - samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf, nameid_format="") - _, internal_req = samlfrontend.handle_authn_request(context, BINDING_HTTP_REDIRECT) + samlfrontend = self.setup_for_authn_req( + context, idp_conf, sp_conf, nameid_format="" + ) + _, internal_req = samlfrontend.handle_authn_request( + context, BINDING_HTTP_REDIRECT + ) assert internal_req.subject_type == NAMEID_FORMAT_TRANSIENT - def test_handle_authn_response_without_relay_state(self, context, idp_conf, sp_conf, internal_response): - samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf, relay_state=None) - _, internal_req = samlfrontend.handle_authn_request(context, BINDING_HTTP_REDIRECT) + def test_handle_authn_response_without_relay_state( + self, context, idp_conf, sp_conf, internal_response + ): + samlfrontend = self.setup_for_authn_req( + context, idp_conf, sp_conf, relay_state=None + ) + _, internal_req = samlfrontend.handle_authn_request( + context, BINDING_HTTP_REDIRECT + ) assert internal_req.requester == sp_conf["entityid"] resp = samlfrontend.handle_authn_response(context, internal_response) resp_dict = parse_qs(urlparse(resp.message).query) fakesp = FakeSP(SPConfig().load(sp_conf)) - resp = fakesp.parse_authn_request_response(resp_dict["SAMLResponse"][0], - BINDING_HTTP_REDIRECT) + resp = fakesp.parse_authn_request_response( + resp_dict["SAMLResponse"][0], BINDING_HTTP_REDIRECT + ) for key in resp.ava: assert USERS["testuser1"][key] == resp.ava[key] @@ -199,11 +280,14 @@ def test_handle_authn_response_without_relay_state(self, context, idp_conf, sp_c assert samlfrontend.name not in context.state def test_handle_authn_response_without_name_id( - self, context, idp_conf, sp_conf, internal_response): + self, context, idp_conf, sp_conf, internal_response + ): samlfrontend = self.setup_for_authn_req( - context, idp_conf, sp_conf, relay_state=None) + context, idp_conf, sp_conf, relay_state=None + ) _, internal_req = samlfrontend.handle_authn_request( - context, BINDING_HTTP_REDIRECT) + context, BINDING_HTTP_REDIRECT + ) # Make sure we are testing the equivalent of a with no # in the . @@ -215,12 +299,15 @@ def test_handle_authn_response_without_name_id( fakesp = FakeSP(SPConfig().load(sp_conf)) resp = fakesp.parse_authn_request_response( - resp_dict["SAMLResponse"][0], BINDING_HTTP_REDIRECT) + resp_dict["SAMLResponse"][0], BINDING_HTTP_REDIRECT + ) # The must not have an empty TextContent. assert resp.name_id.text is not None - def test_get_filter_attributes_with_sp_requested_attributes_without_friendlyname(self, idp_conf): + def test_get_filter_attributes_with_sp_requested_attributes_without_friendlyname( + self, idp_conf + ): sp_metadata_str = """ @@ -242,12 +329,24 @@ def test_get_filter_attributes_with_sp_requested_attributes_without_friendlyname base_url = self.construct_base_url_from_entity_id(idp_conf["entityid"]) conf = {"idp_config": idp_conf, "endpoints": ENDPOINTS} - internal_attributes = {"attributes": {attr_name.lower(): {"saml": [attr_name]} for attr_name in - ["eduPersonTargetedID", "eduPersonPrincipalName", - "eduPersonAffiliation", "mail", "displayName", "sn", - "givenName"]}} # no op mapping for saml attribute names - - samlfrontend = SAMLFrontend(None, internal_attributes, conf, base_url, "saml_frontend") + internal_attributes = { + "attributes": { + attr_name.lower(): {"saml": [attr_name]} + for attr_name in [ + "eduPersonTargetedID", + "eduPersonPrincipalName", + "eduPersonAffiliation", + "mail", + "displayName", + "sn", + "givenName", + ] + } + } # no op mapping for saml attribute names + + samlfrontend = SAMLFrontend( + None, internal_attributes, conf, base_url, "saml_frontend" + ) samlfrontend.register_endpoints(["testprovider"]) internal_req = InternalData( @@ -255,34 +354,61 @@ def test_get_filter_attributes_with_sp_requested_attributes_without_friendlyname requester="http://sp.example.com", requester_name="Example SP", ) - filtered_attributes = samlfrontend._get_approved_attributes(samlfrontend.idp, - samlfrontend.idp.config.getattr( - "policy", "idp"), - internal_req.requester, None) + filtered_attributes = samlfrontend._get_approved_attributes( + samlfrontend.idp, + samlfrontend.idp.config.getattr("policy", "idp"), + internal_req.requester, + None, + ) - assert set(filtered_attributes) == set(["edupersontargetedid", "edupersonprincipalname", - "edupersonaffiliation", "mail", "displayname", "sn", "givenname"]) + assert set(filtered_attributes) == set( + [ + "edupersontargetedid", + "edupersonprincipalname", + "edupersonaffiliation", + "mail", + "displayname", + "sn", + "givenname", + ] + ) - def test_acr_mapping_in_authn_response(self, context, idp_conf, sp_conf, internal_response): + def test_acr_mapping_in_authn_response( + self, context, idp_conf, sp_conf, internal_response + ): eidas_loa_low = "http://eidas.europa.eu/LoA/low" loa = {"": eidas_loa_low} - samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf, extra_config={"acr_mapping": loa}) + samlfrontend = self.setup_for_authn_req( + context, idp_conf, sp_conf, extra_config={"acr_mapping": loa} + ) idp_metadata_str = create_metadata_from_config_dict(samlfrontend.idp_config) - resp = self.get_auth_response(samlfrontend, context, internal_response, sp_conf, idp_metadata_str) + resp = self.get_auth_response( + samlfrontend, context, internal_response, sp_conf, idp_metadata_str + ) assert len(resp.assertion.authn_statement) == 1 - authn_context_class_ref = resp.assertion.authn_statement[0].authn_context.authn_context_class_ref + authn_context_class_ref = resp.assertion.authn_statement[ + 0 + ].authn_context.authn_context_class_ref assert authn_context_class_ref.text == eidas_loa_low - def test_acr_mapping_per_idp_in_authn_response(self, context, idp_conf, sp_conf, internal_response): + def test_acr_mapping_per_idp_in_authn_response( + self, context, idp_conf, sp_conf, internal_response + ): expected_loa = "LoA1" loa = {"": "http://eidas.europa.eu/LoA/low", idp_conf["entityid"]: expected_loa} - samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf, extra_config={"acr_mapping": loa}) + samlfrontend = self.setup_for_authn_req( + context, idp_conf, sp_conf, extra_config={"acr_mapping": loa} + ) idp_metadata_str = create_metadata_from_config_dict(samlfrontend.idp_config) - resp = self.get_auth_response(samlfrontend, context, internal_response, sp_conf, idp_metadata_str) + resp = self.get_auth_response( + samlfrontend, context, internal_response, sp_conf, idp_metadata_str + ) assert len(resp.assertion.authn_statement) == 1 - authn_context_class_ref = resp.assertion.authn_statement[0].authn_context.authn_context_class_ref + authn_context_class_ref = resp.assertion.authn_statement[ + 0 + ].authn_context.authn_context_class_ref assert authn_context_class_ref.text == expected_loa @pytest.mark.parametrize( @@ -290,12 +416,32 @@ def test_acr_mapping_per_idp_in_authn_response(self, context, idp_conf, sp_conf, [ ([""], "swamid", swamid.RELEASE[""]), ([COCO], "edugain", edugain.RELEASE[""] + edugain.RELEASE[COCO]), - ([RESEARCH_AND_SCHOLARSHIP], "refeds", refeds.RELEASE[""] + refeds.RELEASE[RESEARCH_AND_SCHOLARSHIP]), - ([RESEARCH_AND_EDUCATION, EU], "swamid", swamid.RELEASE[""] + swamid.RELEASE[(RESEARCH_AND_EDUCATION, EU)]), - ([RESEARCH_AND_EDUCATION, HEI], "swamid", swamid.RELEASE[""] + swamid.RELEASE[(RESEARCH_AND_EDUCATION, HEI)]), - ([RESEARCH_AND_EDUCATION, NREN], "swamid", swamid.RELEASE[""] + swamid.RELEASE[(RESEARCH_AND_EDUCATION, NREN)]), - ([SFS_1993_1153], "swamid", swamid.RELEASE[""] + swamid.RELEASE[SFS_1993_1153]), - ] + ( + [RESEARCH_AND_SCHOLARSHIP], + "refeds", + refeds.RELEASE[""] + refeds.RELEASE[RESEARCH_AND_SCHOLARSHIP], + ), + ( + [RESEARCH_AND_EDUCATION, EU], + "swamid", + swamid.RELEASE[""] + swamid.RELEASE[(RESEARCH_AND_EDUCATION, EU)], + ), + ( + [RESEARCH_AND_EDUCATION, HEI], + "swamid", + swamid.RELEASE[""] + swamid.RELEASE[(RESEARCH_AND_EDUCATION, HEI)], + ), + ( + [RESEARCH_AND_EDUCATION, NREN], + "swamid", + swamid.RELEASE[""] + swamid.RELEASE[(RESEARCH_AND_EDUCATION, NREN)], + ), + ( + [SFS_1993_1153], + "swamid", + swamid.RELEASE[""] + swamid.RELEASE[SFS_1993_1153], + ), + ], ) def test_respect_sp_entity_categories( self, @@ -305,10 +451,12 @@ def test_respect_sp_entity_categories( expected_attributes, idp_conf, sp_conf, - internal_response + internal_response, ): idp_metadata_str = create_metadata_from_config_dict(idp_conf) - idp_conf["service"]["idp"]["policy"]["default"]["entity_categories"] = [entity_category_module] + idp_conf["service"]["idp"]["policy"]["default"]["entity_categories"] = [ + entity_category_module + ] if all(entity_category): # don't insert empty entity category sp_conf["entity_category"] = entity_category if entity_category == [COCO]: @@ -328,36 +476,66 @@ def test_respect_sp_entity_categories( ) attribute_mapping = {} for expected_attribute in expected_attributes_in_all_entity_categories: - attribute_mapping[expected_attribute.lower()] = {"saml": [expected_attribute]} + attribute_mapping[expected_attribute.lower()] = { + "saml": [expected_attribute] + } internal_attributes = dict(attributes=attribute_mapping) - samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf, internal_attributes=internal_attributes) + samlfrontend = self.setup_for_authn_req( + context, idp_conf, sp_conf, internal_attributes=internal_attributes + ) - user_attributes = {k: "foo" for k in expected_attributes_in_all_entity_categories} - internal_response.attributes = AttributeMapper(internal_attributes).to_internal("saml", user_attributes) + user_attributes = { + k: "foo" for k in expected_attributes_in_all_entity_categories + } + internal_response.attributes = AttributeMapper(internal_attributes).to_internal( + "saml", user_attributes + ) internal_response.requester = sp_conf["entityid"] - resp = self.get_auth_response(samlfrontend, context, internal_response, sp_conf, idp_metadata_str) + resp = self.get_auth_response( + samlfrontend, context, internal_response, sp_conf, idp_metadata_str + ) assert Counter(resp.ava.keys()) == Counter(expected_attributes) - def test_sp_metadata_including_uiinfo_display_name(self, context, idp_conf, sp_conf): + def test_sp_metadata_including_uiinfo_display_name( + self, context, idp_conf, sp_conf + ): sp_conf["service"]["sp"]["ui_info"] = dict(display_name="Test SP") samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf) - display_names = samlfrontend._get_sp_display_name(samlfrontend.idp, sp_conf["entityid"]) + display_names = samlfrontend._get_sp_display_name( + samlfrontend.idp, sp_conf["entityid"] + ) assert display_names[0]["text"] == "Test SP" - def test_sp_metadata_including_uiinfo_without_display_name(self, context, idp_conf, sp_conf): - sp_conf["service"]["sp"]["ui_info"] = dict(information_url="http://info.example.com") + def test_sp_metadata_including_uiinfo_without_display_name( + self, context, idp_conf, sp_conf + ): + sp_conf["service"]["sp"]["ui_info"] = dict( + information_url="http://info.example.com" + ) samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf) - assert samlfrontend._get_sp_display_name(samlfrontend.idp, sp_conf["entityid"]) is None + assert ( + samlfrontend._get_sp_display_name(samlfrontend.idp, sp_conf["entityid"]) + is None + ) def test_sp_metadata_without_uiinfo(self, context, idp_conf, sp_conf): samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf) - assert samlfrontend._get_sp_display_name(samlfrontend.idp, sp_conf["entityid"]) is None + assert ( + samlfrontend._get_sp_display_name(samlfrontend.idp, sp_conf["entityid"]) + is None + ) def test_metadata_endpoint(self, context, idp_conf): conf = {"idp_config": idp_conf, "endpoints": ENDPOINTS} - samlfrontend = SAMLFrontend(lambda ctx, req: None, INTERNAL_ATTRIBUTES, conf, "base_url", "saml_frontend") + samlfrontend = SAMLFrontend( + lambda ctx, req: None, + INTERNAL_ATTRIBUTES, + conf, + "base_url", + "saml_frontend", + ) samlfrontend.register_endpoints(["todo"]) resp = samlfrontend._metadata_endpoint(context) headers = dict(resp.headers) @@ -368,27 +546,48 @@ def test_custom_attribute_release_with_less_attributes_than_entity_category( self, context, idp_conf, sp_conf, internal_response ): idp_metadata_str = create_metadata_from_config_dict(idp_conf) - idp_conf["service"]["idp"]["policy"]["default"]["entity_categories"] = ["swamid"] + idp_conf["service"]["idp"]["policy"]["default"]["entity_categories"] = [ + "swamid" + ] sp_conf["entity_category"] = [SFS_1993_1153] expected_attributes = swamid.RELEASE[SFS_1993_1153] attribute_mapping = {} for expected_attribute in expected_attributes: - attribute_mapping[expected_attribute.lower()] = {"saml": [expected_attribute]} + attribute_mapping[expected_attribute.lower()] = { + "saml": [expected_attribute] + } internal_attributes = dict(attributes=attribute_mapping) user_attributes = {k: "foo" for k in expected_attributes} - internal_response.attributes = AttributeMapper(internal_attributes).to_internal("saml", user_attributes) + internal_response.attributes = AttributeMapper(internal_attributes).to_internal( + "saml", user_attributes + ) - custom_attributes = {idp_conf["entityid"]: {sp_conf["entityid"]: {"exclude": ["norEduPersonNIN"]}}} - samlfrontend = self.setup_for_authn_req(context, idp_conf, sp_conf, internal_attributes=internal_attributes, - extra_config=dict(custom_attribute_release=custom_attributes)) + custom_attributes = { + idp_conf["entityid"]: { + sp_conf["entityid"]: {"exclude": ["norEduPersonNIN"]} + } + } + samlfrontend = self.setup_for_authn_req( + context, + idp_conf, + sp_conf, + internal_attributes=internal_attributes, + extra_config=dict(custom_attribute_release=custom_attributes), + ) internal_response.requester = sp_conf["entityid"] - resp = self.get_auth_response(samlfrontend, context, internal_response, sp_conf, idp_metadata_str) + resp = self.get_auth_response( + samlfrontend, context, internal_response, sp_conf, idp_metadata_str + ) assert len(resp.ava.keys()) == ( len(expected_attributes) - - len(custom_attributes[internal_response.auth_info.issuer][internal_response.requester]["exclude"]) + - len( + custom_attributes[internal_response.auth_info.issuer][ + internal_response.requester + ]["exclude"] + ) ) @@ -399,12 +598,19 @@ class TestSAMLMirrorFrontend: @pytest.fixture(autouse=True) def create_frontend(self, idp_conf): conf = {"idp_config": idp_conf, "endpoints": ENDPOINTS} - self.frontend = SAMLMirrorFrontend(lambda ctx, req: None, INTERNAL_ATTRIBUTES, conf, BASE_URL, - "saml_mirror_frontend") + self.frontend = SAMLMirrorFrontend( + lambda ctx, req: None, + INTERNAL_ATTRIBUTES, + conf, + BASE_URL, + "saml_mirror_frontend", + ) self.frontend.register_endpoints([self.BACKEND]) def assert_dynamic_endpoints(self, sso_endpoints): - endpoint_base_url = "{}/{}/{}".format(BASE_URL, self.BACKEND, self.TARGET_ENTITY_ID) + endpoint_base_url = "{}/{}/{}".format( + BASE_URL, self.BACKEND, self.TARGET_ENTITY_ID + ) expected_endpoints = [] for binding, endpoint in ENDPOINTS["single_sign_on_service"].items(): endp = "{}/{}".format(endpoint_base_url, endpoint) @@ -413,20 +619,28 @@ def assert_dynamic_endpoints(self, sso_endpoints): assert all(sso in sso_endpoints for sso in expected_endpoints) def test_load_endpoints_to_config(self): - idp_config = self.frontend._load_endpoints_to_config(self.BACKEND, self.TARGET_ENTITY_ID) - self.assert_dynamic_endpoints(idp_config["service"]["idp"]["endpoints"]["single_sign_on_service"]) + idp_config = self.frontend._load_endpoints_to_config( + self.BACKEND, self.TARGET_ENTITY_ID + ) + self.assert_dynamic_endpoints( + idp_config["service"]["idp"]["endpoints"]["single_sign_on_service"] + ) def test_load_idp_dynamic_endpoints(self, context): context.path = "{}/{}/sso/redirect".format(self.BACKEND, self.TARGET_ENTITY_ID) context.target_backend = self.BACKEND idp = self.frontend._load_idp_dynamic_endpoints(context) - self.assert_dynamic_endpoints(idp.config._idp_endpoints["single_sign_on_service"]) + self.assert_dynamic_endpoints( + idp.config._idp_endpoints["single_sign_on_service"] + ) def test_load_idp_dynamic_entity_id(self, idp_conf): state = State() state[self.frontend.name] = {"target_entity_id": self.TARGET_ENTITY_ID} idp = self.frontend._load_idp_dynamic_entity_id(state) - assert idp.config.entityid == "{}/{}".format(idp_conf["entityid"], self.TARGET_ENTITY_ID) + assert idp.config.entityid == "{}/{}".format( + idp_conf["entityid"], self.TARGET_ENTITY_ID + ) class TestSAMLVirtualCoFrontend(TestSAMLFrontend): @@ -465,7 +679,7 @@ def frontend(self, idp_conf, sp_conf): collab_org = { "encodeable_name": self.CO, "co_static_saml_attributes": self.CO_STATIC_SAML_ATTRIBUTES, - "co_attribute_scope": self.CO_SCOPE + "co_attribute_scope": self.CO_SCOPE, } # Use the dynamically updated idp_conf fixture, the configured @@ -484,16 +698,19 @@ def frontend(self, idp_conf, sp_conf): internal_attributes["attributes"][self.CO_O] = {"saml": ["o"]} internal_attributes["attributes"][self.CO_C] = {"saml": ["c"]} internal_attributes["attributes"][self.CO_CO] = {"saml": ["co"]} - internal_attributes["attributes"][self.CO_NOREDUORGACRONYM] = ( - {"saml": ["norEduOrgAcronym"]}) + internal_attributes["attributes"][self.CO_NOREDUORGACRONYM] = { + "saml": ["norEduOrgAcronym"] + } # Create, register the endpoints, and then return the frontend # instance. - frontend = SAMLVirtualCoFrontend(lambda ctx, req: None, - internal_attributes, - conf, - BASE_URL, - "saml_virtual_co_frontend") + frontend = SAMLVirtualCoFrontend( + lambda ctx, req: None, + internal_attributes, + conf, + BASE_URL, + "saml_virtual_co_frontend", + ) frontend.register_endpoints([self.BACKEND]) return frontend @@ -518,7 +735,7 @@ def test_create_state_data(self, frontend, context, idp_conf): state = frontend._create_state_data(context, {}, "") assert state[frontend.KEY_CO_NAME] == self.CO - expected_entityid = "{}/{}".format(idp_conf['entityid'], self.CO) + expected_entityid = "{}/{}".format(idp_conf["entityid"], self.CO) assert state[frontend.KEY_CO_ENTITY_ID] == expected_entityid assert state[frontend.KEY_CO_ATTRIBUTE_SCOPE] == self.CO_SCOPE @@ -532,7 +749,7 @@ def test_get_co_name(self, frontend, context): assert co_name == self.CO def test_create_co_virtual_idp(self, frontend, context, idp_conf): - expected_entityid = "{}/{}".format(idp_conf['entityid'], self.CO) + expected_entityid = "{}/{}".format(idp_conf["entityid"], self.CO) endpoint_base_url = "{}/{}/{}".format(BASE_URL, self.BACKEND, self.CO) expected_endpoints = [] @@ -547,41 +764,53 @@ def test_create_co_virtual_idp(self, frontend, context, idp_conf): assert all(sso in sso_endpoints for sso in expected_endpoints) def test_create_co_virtual_idp_with_entity_id_templates(self, frontend, context): - frontend.idp_config['entityid'] = "{}/Saml2IDP/proxy.xml".format(BASE_URL) + frontend.idp_config["entityid"] = "{}/Saml2IDP/proxy.xml".format(BASE_URL) expected_entity_id = "{}/Saml2IDP/proxy.xml/{}".format(BASE_URL, self.CO) idp_server = frontend._create_co_virtual_idp(context) assert idp_server.config.entityid == expected_entity_id - frontend.idp_config['entityid'] = "{}//idp/".format(BASE_URL) - expected_entity_id = "{}/{}/idp/{}".format(BASE_URL, context.target_backend, self.CO) + frontend.idp_config["entityid"] = "{}//idp/".format( + BASE_URL + ) + expected_entity_id = "{}/{}/idp/{}".format( + BASE_URL, context.target_backend, self.CO + ) idp_server = frontend._create_co_virtual_idp(context) assert idp_server.config.entityid == expected_entity_id def test_register_endpoints(self, frontend, context): idp_server = frontend._create_co_virtual_idp(context) url_map = frontend.register_endpoints([self.BACKEND]) - all_idp_endpoints = [urlparse(endpoint[0]).path[1:] for - endpoint in - idp_server.config._idp_endpoints[self.KEY_SSO]] + all_idp_endpoints = [ + urlparse(endpoint[0]).path[1:] + for endpoint in idp_server.config._idp_endpoints[self.KEY_SSO] + ] compiled_regex = [re.compile(regex) for regex, _ in url_map] for endpoint in all_idp_endpoints: assert any(pat.match(endpoint) for pat in compiled_regex) - def test_register_endpoints_throws_error_in_case_duplicate_entity_ids(self, frontend): + def test_register_endpoints_throws_error_in_case_duplicate_entity_ids( + self, frontend + ): with pytest.raises(ValueError): frontend.register_endpoints([self.BACKEND, self.BACKEND_1]) def test_register_endpoints_with_metadata_endpoints(self, frontend, context): - frontend.idp_config['entityid'] = "{}//idp/".format(BASE_URL) - frontend.config['entityid_endpoint'] = True + frontend.idp_config["entityid"] = "{}//idp/".format( + BASE_URL + ) + frontend.config["entityid_endpoint"] = True idp_server_1 = frontend._create_co_virtual_idp(context) context_2 = self._make_context(context, self.BACKEND_1, self.CO) idp_server_2 = frontend._create_co_virtual_idp(context_2) url_map = frontend.register_endpoints([self.BACKEND, self.BACKEND_1]) - expected_idp_endpoints = [urlparse(endpoint[0]).path[1:] for server in [idp_server_1, idp_server_2] - for endpoint in server.config._idp_endpoints[self.KEY_SSO]] + expected_idp_endpoints = [ + urlparse(endpoint[0]).path[1:] + for server in [idp_server_1, idp_server_2] + for endpoint in server.config._idp_endpoints[self.KEY_SSO] + ] for server in [idp_server_1, idp_server_2]: expected_idp_endpoints.append(urlparse(server.config.entityid).path[1:]) @@ -590,8 +819,9 @@ def test_register_endpoints_with_metadata_endpoints(self, frontend, context): for endpoint in expected_idp_endpoints: assert any(pat.match(endpoint) for pat in compiled_regex) - def test_co_static_attributes(self, frontend, context, internal_response, - idp_conf, sp_conf): + def test_co_static_attributes( + self, frontend, context, internal_response, idp_conf, sp_conf + ): # Use the frontend and context fixtures to dynamically create the # proxy IdP server that would be created during a flow. idp_server = frontend._create_co_virtual_idp(context) @@ -633,11 +863,10 @@ def test_co_static_attributes(self, frontend, context, internal_response, "name_id_policy": NameIDPolicy(format=NAMEID_FORMAT_TRANSIENT), "in_response_to": None, "destination": sp_config.endpoint( - "assertion_consumer_service", - binding=BINDING_HTTP_REDIRECT + "assertion_consumer_service", binding=BINDING_HTTP_REDIRECT )[0], "sp_entity_id": sp_conf["entityid"], - "binding": BINDING_HTTP_REDIRECT + "binding": BINDING_HTTP_REDIRECT, } request_state = frontend._create_state_data(context, resp_args, "") context.state[frontend.name] = request_state @@ -652,8 +881,7 @@ def test_co_static_attributes(self, frontend, context, internal_response, class TestSubjectTypeToSamlNameIdFormat: def test_should_default_to_persistent(self): assert ( - subject_type_to_saml_nameid_format("unmatched") - == NAMEID_FORMAT_PERSISTENT + subject_type_to_saml_nameid_format("unmatched") == NAMEID_FORMAT_PERSISTENT ) def test_should_map_persistent(self): @@ -681,11 +909,7 @@ def test_should_map_unspecified(self): ) def test_should_map_public(self): - assert ( - subject_type_to_saml_nameid_format("public") == NAMEID_FORMAT_PERSISTENT - ) + assert subject_type_to_saml_nameid_format("public") == NAMEID_FORMAT_PERSISTENT def test_should_map_pairwise(self): - assert ( - subject_type_to_saml_nameid_format("pairwise") == NAMEID_FORMAT_TRANSIENT - ) + assert subject_type_to_saml_nameid_format("pairwise") == NAMEID_FORMAT_TRANSIENT diff --git a/tests/satosa/metadata_creation/test_description.py b/tests/satosa/metadata_creation/test_description.py index 818d01a03..522b52d10 100644 --- a/tests/satosa/metadata_creation/test_description.py +++ b/tests/satosa/metadata_creation/test_description.py @@ -1,6 +1,11 @@ import pytest -from satosa.metadata_creation.description import ContactPersonDesc, UIInfoDesc, OrganizationDesc, MetadataDescription +from satosa.metadata_creation.description import ( + ContactPersonDesc, + MetadataDescription, + OrganizationDesc, + UIInfoDesc, +) class TestContactPersonDesc(object): @@ -32,10 +37,14 @@ def test_to_dict(self): ui_info = serialized["service"]["idp"]["ui_info"] assert ui_info["description"] == [{"text": "test", "lang": "en"}] assert ui_info["display_name"] == [{"text": "my company", "lang": "en"}] - assert ui_info["logo"] == [{"text": "logo.jpg", "width": 80, "height": 80, "lang": "en"}] + assert ui_info["logo"] == [ + {"text": "logo.jpg", "width": 80, "height": 80, "lang": "en"} + ] assert ui_info["keywords"] == [{"text": ["kw1", "kw2"], "lang": "en"}] assert ui_info["information_url"] == [{"text": "https://test", "lang": "en"}] - assert ui_info["privacy_statement_url"] == [{"text": "https://test", "lang": "en"}] + assert ui_info["privacy_statement_url"] == [ + {"text": "https://test", "lang": "en"} + ] def test_to_dict_for_logo_without_lang(self): desc = UIInfoDesc() diff --git a/tests/satosa/metadata_creation/test_saml_metadata.py b/tests/satosa/metadata_creation/test_saml_metadata.py index 77e8ac1d7..65fe24d32 100644 --- a/tests/satosa/metadata_creation/test_saml_metadata.py +++ b/tests/satosa/metadata_creation/test_saml_metadata.py @@ -2,57 +2,93 @@ from base64 import urlsafe_b64encode import pytest -from saml2.config import SPConfig, Config +from saml2.config import Config, SPConfig from saml2.mdstore import InMemoryMetaData from saml2.metadata import entity_descriptor from saml2.sigver import security_context from saml2.time_util import in_a_while -from satosa.metadata_creation.saml_metadata import create_entity_descriptors, create_signed_entities_descriptor, \ - create_signed_entity_descriptor +from satosa.metadata_creation.saml_metadata import ( + create_entity_descriptors, + create_signed_entities_descriptor, + create_signed_entity_descriptor, +) from satosa.satosa_config import SATOSAConfig from tests.conftest import BASE_URL from tests.util import create_metadata_from_config_dict class TestCreateEntityDescriptors: - def assert_single_sign_on_endpoints_for_saml_frontend(self, entity_descriptor, saml_frontend_config, backend_names): + def assert_single_sign_on_endpoints_for_saml_frontend( + self, entity_descriptor, saml_frontend_config, backend_names + ): metadata = InMemoryMetaData(None, str(entity_descriptor)) metadata.load() - sso = metadata.service(saml_frontend_config["config"]["idp_config"]["entityid"], "idpsso_descriptor", - "single_sign_on_service") + sso = metadata.service( + saml_frontend_config["config"]["idp_config"]["entityid"], + "idpsso_descriptor", + "single_sign_on_service", + ) for backend_name in backend_names: - for binding, path in saml_frontend_config["config"]["endpoints"]["single_sign_on_service"].items(): - sso_urls_for_binding = [endpoint["location"] for endpoint in sso[binding]] + for binding, path in saml_frontend_config["config"]["endpoints"][ + "single_sign_on_service" + ].items(): + sso_urls_for_binding = [ + endpoint["location"] for endpoint in sso[binding] + ] expected_url = "{}/{}/{}".format(BASE_URL, backend_name, path) assert expected_url in sso_urls_for_binding - def assert_single_sign_on_endpoints_for_saml_mirror_frontend(self, entity_descriptors, encoded_target_entity_id, - saml_mirror_frontend_config, backend_names): - expected_entity_id = saml_mirror_frontend_config["config"]["idp_config"][ - "entityid"] + "/" + encoded_target_entity_id + def assert_single_sign_on_endpoints_for_saml_mirror_frontend( + self, + entity_descriptors, + encoded_target_entity_id, + saml_mirror_frontend_config, + backend_names, + ): + expected_entity_id = ( + saml_mirror_frontend_config["config"]["idp_config"]["entityid"] + + "/" + + encoded_target_entity_id + ) metadata = InMemoryMetaData(None, None) for ed in entity_descriptors: metadata.parse(str(ed)) - sso = metadata.service(expected_entity_id, "idpsso_descriptor", "single_sign_on_service") + sso = metadata.service( + expected_entity_id, "idpsso_descriptor", "single_sign_on_service" + ) for backend_name in backend_names: - for binding, path in saml_mirror_frontend_config["config"]["endpoints"]["single_sign_on_service"].items(): - sso_urls_for_binding = [endpoint["location"] for endpoint in sso[binding]] - expected_url = "{}/{}/{}/{}".format(BASE_URL, backend_name, encoded_target_entity_id, path) + for binding, path in saml_mirror_frontend_config["config"]["endpoints"][ + "single_sign_on_service" + ].items(): + sso_urls_for_binding = [ + endpoint["location"] for endpoint in sso[binding] + ] + expected_url = "{}/{}/{}/{}".format( + BASE_URL, backend_name, encoded_target_entity_id, path + ) assert expected_url in sso_urls_for_binding - def assert_assertion_consumer_service_endpoints_for_saml_backend(self, entity_descriptor, saml_backend_config): + def assert_assertion_consumer_service_endpoints_for_saml_backend( + self, entity_descriptor, saml_backend_config + ): metadata = InMemoryMetaData(None, str(entity_descriptor)) metadata.load() - acs = metadata.service(saml_backend_config["config"]["sp_config"]["entityid"], "spsso_descriptor", - "assertion_consumer_service") - for url, binding in saml_backend_config["config"]["sp_config"]["service"]["sp"]["endpoints"][ - "assertion_consumer_service"]: + acs = metadata.service( + saml_backend_config["config"]["sp_config"]["entityid"], + "spsso_descriptor", + "assertion_consumer_service", + ) + for url, binding in saml_backend_config["config"]["sp_config"]["service"]["sp"][ + "endpoints" + ]["assertion_consumer_service"]: assert acs[binding][0]["location"] == url - def test_saml_frontend_with_saml_backend(self, satosa_config_dict, saml_frontend_config, saml_backend_config): + def test_saml_frontend_with_saml_backend( + self, satosa_config_dict, saml_frontend_config, saml_backend_config + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] satosa_config = SATOSAConfig(satosa_config_dict) @@ -61,14 +97,17 @@ def test_saml_frontend_with_saml_backend(self, satosa_config_dict, saml_frontend assert len(frontend_metadata) == 1 assert len(frontend_metadata[saml_frontend_config["name"]]) == 1 entity_descriptor = frontend_metadata[saml_frontend_config["name"]][0] - self.assert_single_sign_on_endpoints_for_saml_frontend(entity_descriptor, saml_frontend_config, - [saml_backend_config["name"]]) + self.assert_single_sign_on_endpoints_for_saml_frontend( + entity_descriptor, saml_frontend_config, [saml_backend_config["name"]] + ) assert len(backend_metadata) == 1 self.assert_assertion_consumer_service_endpoints_for_saml_backend( - backend_metadata[saml_backend_config["name"]][0], - saml_backend_config) + backend_metadata[saml_backend_config["name"]][0], saml_backend_config + ) - def test_saml_frontend_with_oidc_backend(self, satosa_config_dict, saml_frontend_config, oidc_backend_config): + def test_saml_frontend_with_oidc_backend( + self, satosa_config_dict, saml_frontend_config, oidc_backend_config + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [oidc_backend_config] satosa_config = SATOSAConfig(satosa_config_dict) @@ -77,39 +116,57 @@ def test_saml_frontend_with_oidc_backend(self, satosa_config_dict, saml_frontend assert len(frontend_metadata) == 1 assert len(frontend_metadata[saml_frontend_config["name"]]) == 1 entity_descriptor = frontend_metadata[saml_frontend_config["name"]][0] - self.assert_single_sign_on_endpoints_for_saml_frontend(entity_descriptor, saml_frontend_config, - [oidc_backend_config["name"]]) + self.assert_single_sign_on_endpoints_for_saml_frontend( + entity_descriptor, saml_frontend_config, [oidc_backend_config["name"]] + ) # OIDC backend does not produce any SAML metadata assert not backend_metadata - def test_saml_frontend_with_multiple_backends(self, satosa_config_dict, saml_frontend_config, saml_backend_config, - oidc_backend_config): + def test_saml_frontend_with_multiple_backends( + self, + satosa_config_dict, + saml_frontend_config, + saml_backend_config, + oidc_backend_config, + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_frontend_config] - satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config, oidc_backend_config] + satosa_config_dict["BACKEND_MODULES"] = [ + saml_backend_config, + oidc_backend_config, + ] satosa_config = SATOSAConfig(satosa_config_dict) frontend_metadata, backend_metadata = create_entity_descriptors(satosa_config) assert len(frontend_metadata) == 1 assert len(frontend_metadata[saml_frontend_config["name"]]) == 1 entity_descriptor = frontend_metadata[saml_frontend_config["name"]][0] - self.assert_single_sign_on_endpoints_for_saml_frontend(entity_descriptor, saml_frontend_config, - [saml_backend_config["name"], - oidc_backend_config["name"]]) + self.assert_single_sign_on_endpoints_for_saml_frontend( + entity_descriptor, + saml_frontend_config, + [saml_backend_config["name"], oidc_backend_config["name"]], + ) # only the SAML backend produces SAML metadata assert len(backend_metadata) == 1 self.assert_assertion_consumer_service_endpoints_for_saml_backend( - backend_metadata[saml_backend_config["name"]][0], - saml_backend_config) - - def test_saml_mirror_frontend_with_saml_backend_with_multiple_target_providers(self, satosa_config_dict, idp_conf, - saml_mirror_frontend_config, - saml_backend_config): + backend_metadata[saml_backend_config["name"]][0], saml_backend_config + ) + + def test_saml_mirror_frontend_with_saml_backend_with_multiple_target_providers( + self, + satosa_config_dict, + idp_conf, + saml_mirror_frontend_config, + saml_backend_config, + ): idp_conf2 = copy.deepcopy(idp_conf) idp_conf2["entityid"] = "https://idp2.example.com" satosa_config_dict["FRONTEND_MODULES"] = [saml_mirror_frontend_config] - saml_backend_config["config"]["sp_config"]["metadata"] = {"inline": [create_metadata_from_config_dict(idp_conf), - create_metadata_from_config_dict( - idp_conf2)]} + saml_backend_config["config"]["sp_config"]["metadata"] = { + "inline": [ + create_metadata_from_config_dict(idp_conf), + create_metadata_from_config_dict(idp_conf2), + ] + } satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] satosa_config = SATOSAConfig(satosa_config_dict) frontend_metadata, backend_metadata = create_entity_descriptors(satosa_config) @@ -119,17 +176,23 @@ def test_saml_mirror_frontend_with_saml_backend_with_multiple_target_providers(s entity_descriptors = frontend_metadata[saml_mirror_frontend_config["name"]] for target_entity_id in [idp_conf["entityid"], idp_conf2["entityid"]]: - encoded_target_entity_id = urlsafe_b64encode(target_entity_id.encode("utf-8")).decode("utf-8") - self.assert_single_sign_on_endpoints_for_saml_mirror_frontend(entity_descriptors, encoded_target_entity_id, - saml_mirror_frontend_config, - [saml_backend_config["name"]]) + encoded_target_entity_id = urlsafe_b64encode( + target_entity_id.encode("utf-8") + ).decode("utf-8") + self.assert_single_sign_on_endpoints_for_saml_mirror_frontend( + entity_descriptors, + encoded_target_entity_id, + saml_mirror_frontend_config, + [saml_backend_config["name"]], + ) assert len(backend_metadata) == 1 self.assert_assertion_consumer_service_endpoints_for_saml_backend( - backend_metadata[saml_backend_config["name"]][0], - saml_backend_config) + backend_metadata[saml_backend_config["name"]][0], saml_backend_config + ) - def test_saml_mirror_frontend_with_oidc_backend(self, satosa_config_dict, saml_mirror_frontend_config, - oidc_backend_config): + def test_saml_mirror_frontend_with_oidc_backend( + self, satosa_config_dict, saml_mirror_frontend_config, oidc_backend_config + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_mirror_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [oidc_backend_config] satosa_config = SATOSAConfig(satosa_config_dict) @@ -139,45 +202,76 @@ def test_saml_mirror_frontend_with_oidc_backend(self, satosa_config_dict, saml_m assert len(frontend_metadata[saml_mirror_frontend_config["name"]]) == 1 entity_descriptors = frontend_metadata[saml_mirror_frontend_config["name"]] target_entity_id = oidc_backend_config["config"]["provider_metadata"]["issuer"] - encoded_target_entity_id = urlsafe_b64encode(target_entity_id.encode("utf-8")).decode("utf-8") - self.assert_single_sign_on_endpoints_for_saml_mirror_frontend(entity_descriptors, encoded_target_entity_id, - saml_mirror_frontend_config, - [oidc_backend_config["name"]]) + encoded_target_entity_id = urlsafe_b64encode( + target_entity_id.encode("utf-8") + ).decode("utf-8") + self.assert_single_sign_on_endpoints_for_saml_mirror_frontend( + entity_descriptors, + encoded_target_entity_id, + saml_mirror_frontend_config, + [oidc_backend_config["name"]], + ) # OIDC backend does not produce any SAML metadata assert not backend_metadata - def test_saml_mirror_frontend_with_multiple_backends(self, satosa_config_dict, idp_conf, - saml_mirror_frontend_config, - saml_backend_config, oidc_backend_config): + def test_saml_mirror_frontend_with_multiple_backends( + self, + satosa_config_dict, + idp_conf, + saml_mirror_frontend_config, + saml_backend_config, + oidc_backend_config, + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_mirror_frontend_config] saml_backend_config["config"]["sp_config"]["metadata"] = { - "inline": [create_metadata_from_config_dict(idp_conf)]} - satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config, oidc_backend_config] + "inline": [create_metadata_from_config_dict(idp_conf)] + } + satosa_config_dict["BACKEND_MODULES"] = [ + saml_backend_config, + oidc_backend_config, + ] satosa_config = SATOSAConfig(satosa_config_dict) frontend_metadata, backend_metadata = create_entity_descriptors(satosa_config) assert len(frontend_metadata) == 1 assert len(frontend_metadata[saml_mirror_frontend_config["name"]]) == 2 - params = zip([idp_conf["entityid"], oidc_backend_config["config"]["provider_metadata"]["issuer"]], - [saml_backend_config["name"], oidc_backend_config["name"]]) + params = zip( + [ + idp_conf["entityid"], + oidc_backend_config["config"]["provider_metadata"]["issuer"], + ], + [saml_backend_config["name"], oidc_backend_config["name"]], + ) entity_descriptors = frontend_metadata[saml_mirror_frontend_config["name"]] for target_entity_id, backend_name in params: - encoded_target_entity_id = urlsafe_b64encode(target_entity_id.encode("utf-8")).decode("utf-8") - self.assert_single_sign_on_endpoints_for_saml_mirror_frontend(entity_descriptors, encoded_target_entity_id, - saml_mirror_frontend_config, - [backend_name]) + encoded_target_entity_id = urlsafe_b64encode( + target_entity_id.encode("utf-8") + ).decode("utf-8") + self.assert_single_sign_on_endpoints_for_saml_mirror_frontend( + entity_descriptors, + encoded_target_entity_id, + saml_mirror_frontend_config, + [backend_name], + ) # only the SAML backend produces SAML metadata assert len(backend_metadata) self.assert_assertion_consumer_service_endpoints_for_saml_backend( - backend_metadata[saml_backend_config["name"]][0], - saml_backend_config) - - def test_two_saml_frontends(self, satosa_config_dict, saml_frontend_config, saml_mirror_frontend_config, - oidc_backend_config): - - satosa_config_dict["FRONTEND_MODULES"] = [saml_frontend_config, saml_mirror_frontend_config] + backend_metadata[saml_backend_config["name"]][0], saml_backend_config + ) + + def test_two_saml_frontends( + self, + satosa_config_dict, + saml_frontend_config, + saml_mirror_frontend_config, + oidc_backend_config, + ): + satosa_config_dict["FRONTEND_MODULES"] = [ + saml_frontend_config, + saml_mirror_frontend_config, + ] satosa_config_dict["BACKEND_MODULES"] = [oidc_backend_config] satosa_config = SATOSAConfig(satosa_config_dict) frontend_metadata, backend_metadata = create_entity_descriptors(satosa_config) @@ -187,27 +281,37 @@ def test_two_saml_frontends(self, satosa_config_dict, saml_frontend_config, saml saml_entities = frontend_metadata[saml_frontend_config["name"]] assert len(saml_entities) == 1 entity_descriptor = saml_entities[0] - self.assert_single_sign_on_endpoints_for_saml_frontend(entity_descriptor, saml_frontend_config, - [oidc_backend_config["name"]]) + self.assert_single_sign_on_endpoints_for_saml_frontend( + entity_descriptor, saml_frontend_config, [oidc_backend_config["name"]] + ) mirrored_saml_entities = frontend_metadata[saml_mirror_frontend_config["name"]] assert len(mirrored_saml_entities) == 1 target_entity_id = oidc_backend_config["config"]["provider_metadata"]["issuer"] - encoded_target_entity_id = urlsafe_b64encode(target_entity_id.encode("utf-8")).decode("utf-8") - self.assert_single_sign_on_endpoints_for_saml_mirror_frontend(mirrored_saml_entities, encoded_target_entity_id, - saml_mirror_frontend_config, - [oidc_backend_config["name"]]) + encoded_target_entity_id = urlsafe_b64encode( + target_entity_id.encode("utf-8") + ).decode("utf-8") + self.assert_single_sign_on_endpoints_for_saml_mirror_frontend( + mirrored_saml_entities, + encoded_target_entity_id, + saml_mirror_frontend_config, + [oidc_backend_config["name"]], + ) # OIDC backend does not produce any SAML metadata assert not backend_metadata - def test_create_mirrored_metadata_does_not_contain_target_contact_info(self, satosa_config_dict, idp_conf, - saml_mirror_frontend_config, - saml_backend_config): - + def test_create_mirrored_metadata_does_not_contain_target_contact_info( + self, + satosa_config_dict, + idp_conf, + saml_mirror_frontend_config, + saml_backend_config, + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_mirror_frontend_config] saml_backend_config["config"]["sp_config"]["metadata"] = { - "inline": [create_metadata_from_config_dict(idp_conf)]} + "inline": [create_metadata_from_config_dict(idp_conf)] + } satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] satosa_config = SATOSAConfig(satosa_config_dict) frontend_metadata, backend_metadata = create_entity_descriptors(satosa_config) @@ -219,18 +323,40 @@ def test_create_mirrored_metadata_does_not_contain_target_contact_info(self, sat entity_info = list(metadata.values())[0] expected_entity_info = saml_mirror_frontend_config["config"]["idp_config"] - assert len(entity_info["contact_person"]) == len(expected_entity_info["contact_person"]) + assert len(entity_info["contact_person"]) == len( + expected_entity_info["contact_person"] + ) for i, contact in enumerate(expected_entity_info["contact_person"]): - assert entity_info["contact_person"][i]["contact_type"] == contact["contact_type"] - assert entity_info["contact_person"][i]["email_address"][0]["text"] == contact["email_address"][0] - assert entity_info["contact_person"][i]["given_name"]["text"] == contact["given_name"] - assert entity_info["contact_person"][i]["sur_name"]["text"] == contact["sur_name"] + assert ( + entity_info["contact_person"][i]["contact_type"] + == contact["contact_type"] + ) + assert ( + entity_info["contact_person"][i]["email_address"][0]["text"] + == contact["email_address"][0] + ) + assert ( + entity_info["contact_person"][i]["given_name"]["text"] + == contact["given_name"] + ) + assert ( + entity_info["contact_person"][i]["sur_name"]["text"] + == contact["sur_name"] + ) expected_org_info = expected_entity_info["organization"] - assert entity_info["organization"]["organization_display_name"][0]["text"] == \ - expected_org_info["display_name"][0][0] - assert entity_info["organization"]["organization_name"][0]["text"] == expected_org_info["name"][0][0] - assert entity_info["organization"]["organization_url"][0]["text"] == expected_org_info["url"][0][0] + assert ( + entity_info["organization"]["organization_display_name"][0]["text"] + == expected_org_info["display_name"][0][0] + ) + assert ( + entity_info["organization"]["organization_name"][0]["text"] + == expected_org_info["name"][0][0] + ) + assert ( + entity_info["organization"]["organization_url"][0]["text"] + == expected_org_info["url"][0][0] + ) class TestCreateSignedEntitiesDescriptor: @@ -251,8 +377,12 @@ def signature_security_context(self, cert_and_key): conf.key_file = cert_and_key[1] return security_context(conf) - def test_signed_metadata(self, entity_desc, signature_security_context, verification_security_context): - signed_metadata = create_signed_entities_descriptor([entity_desc, entity_desc], signature_security_context) + def test_signed_metadata( + self, entity_desc, signature_security_context, verification_security_context + ): + signed_metadata = create_signed_entities_descriptor( + [entity_desc, entity_desc], signature_security_context + ) md = InMemoryMetaData(None, security=verification_security_context) md.parse(signed_metadata) @@ -263,8 +393,9 @@ def test_signed_metadata(self, entity_desc, signature_security_context, verifica def test_valid_for(self, entity_desc, signature_security_context): valid_for = 4 # metadata valid for 4 hours expected_validity = in_a_while(hours=valid_for) - signed_metadata = create_signed_entities_descriptor([entity_desc], signature_security_context, - valid_for=valid_for) + signed_metadata = create_signed_entities_descriptor( + [entity_desc], signature_security_context, valid_for=valid_for + ) md = InMemoryMetaData(None) md.parse(signed_metadata) @@ -289,8 +420,12 @@ def signature_security_context(self, cert_and_key): conf.key_file = cert_and_key[1] return security_context(conf) - def test_signed_metadata(self, entity_desc, signature_security_context, verification_security_context): - signed_metadata = create_signed_entity_descriptor(entity_desc, signature_security_context) + def test_signed_metadata( + self, entity_desc, signature_security_context, verification_security_context + ): + signed_metadata = create_signed_entity_descriptor( + entity_desc, signature_security_context + ) md = InMemoryMetaData(None, security=verification_security_context) md.parse(signed_metadata) @@ -301,8 +436,9 @@ def test_signed_metadata(self, entity_desc, signature_security_context, verifica def test_valid_for(self, entity_desc, signature_security_context): valid_for = 4 # metadata valid for 4 hours expected_validity = in_a_while(hours=valid_for) - signed_metadata = create_signed_entity_descriptor(entity_desc, signature_security_context, - valid_for=valid_for) + signed_metadata = create_signed_entity_descriptor( + entity_desc, signature_security_context, valid_for=valid_for + ) md = InMemoryMetaData(None) md.parse(signed_metadata) diff --git a/tests/satosa/micro_services/test_account_linking.py b/tests/satosa/micro_services/test_account_linking.py index 859f3517d..e7b437b60 100644 --- a/tests/satosa/micro_services/test_account_linking.py +++ b/tests/satosa/micro_services/test_account_linking.py @@ -3,21 +3,18 @@ import pytest import requests - import responses -from responses import matchers - -from jwkest.jwk import rsa_load, RSAKey +from jwkest.jwk import RSAKey, rsa_load from jwkest.jws import JWS +from responses import matchers from satosa.exception import SATOSAAuthenticationError -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.internal import AuthenticationInformation, InternalData from satosa.micro_services.account_linking import AccountLinking from satosa.response import Redirect -class TestAccountLinking(): +class TestAccountLinking: @pytest.fixture def internal_response(self): auth_info = AuthenticationInformation("auth_class_ref", "timestamp", "issuer") @@ -35,19 +32,27 @@ def account_linking_config(self, signing_key_path): @pytest.fixture(autouse=True) def create_account_linking(self, account_linking_config): - self.account_linking = AccountLinking(account_linking_config, name="AccountLinking", - base_url="https://satosa.example.com") + self.account_linking = AccountLinking( + account_linking_config, + name="AccountLinking", + base_url="https://satosa.example.com", + ) self.account_linking.next = lambda ctx, data: data @responses.activate - def test_existing_account_linking_with_known_known_uuid(self, account_linking_config, internal_response, context): + def test_existing_account_linking_with_known_known_uuid( + self, account_linking_config, internal_response, context + ): uuid = "uuid" data = { "idp": internal_response.auth_info.issuer, "id": internal_response.subject_id, - "redirect_endpoint": self.account_linking.base_url + "/account_linking/handle_account_linking" + "redirect_endpoint": self.account_linking.base_url + + "/account_linking/handle_account_linking", } - key = RSAKey(key=rsa_load(account_linking_config["sign_key"]), use="sig", alg="RS256") + key = RSAKey( + key=rsa_load(account_linking_config["sign_key"]), use="sig", alg="RS256" + ) jws = JWS(json.dumps(data), alg=key.alg).sign_compact([key]) url = "%s/get_id" % account_linking_config["api_url"] params = {"jwt": jws} @@ -71,7 +76,7 @@ def test_full_flow(self, account_linking_config, internal_response, context): "%s/get_id" % account_linking_config["api_url"], status=404, body=ticket, - content_type="text/html" + content_type="text/html", ) result = self.account_linking.process(context, internal_response) assert isinstance(result, Redirect) @@ -80,9 +85,12 @@ def test_full_flow(self, account_linking_config, internal_response, context): data = { "idp": internal_response.auth_info.issuer, "id": internal_response.subject_id, - "redirect_endpoint": self.account_linking.base_url + "/account_linking/handle_account_linking" + "redirect_endpoint": self.account_linking.base_url + + "/account_linking/handle_account_linking", } - key = RSAKey(key=rsa_load(account_linking_config["sign_key"]), use="sig", alg="RS256") + key = RSAKey( + key=rsa_load(account_linking_config["sign_key"]), use="sig", alg="RS256" + ) jws = JWS(json.dumps(data), alg=key.alg).sign_compact([key]) uuid = "uuid" with responses.RequestsMock() as rsps: @@ -101,14 +109,16 @@ def test_full_flow(self, account_linking_config, internal_response, context): assert internal_response.subject_id == uuid @responses.activate - def test_account_linking_failed(self, account_linking_config, internal_response, context): + def test_account_linking_failed( + self, account_linking_config, internal_response, context + ): ticket = "ticket" responses.add( responses.GET, "%s/get_id" % account_linking_config["api_url"], status=404, body=ticket, - content_type="text/html" + content_type="text/html", ) issuer_user_id = internal_response.subject_id result = self.account_linking.process(context, internal_response) @@ -117,26 +127,33 @@ def test_account_linking_failed(self, account_linking_config, internal_response, # account linking endpoint still does not return an id internal_response = self.account_linking._handle_al_response(context) - #Verify that we kept the subject_id the issuer sent us + # Verify that we kept the subject_id the issuer sent us assert internal_response.subject_id == issuer_user_id @responses.activate - def test_manage_al_handle_failed_connection(self, account_linking_config, internal_response, context): + def test_manage_al_handle_failed_connection( + self, account_linking_config, internal_response, context + ): exception = requests.ConnectionError("No connection") - responses.add(responses.GET, "%s/get_id" % account_linking_config["api_url"], - body=exception) + responses.add( + responses.GET, + "%s/get_id" % account_linking_config["api_url"], + body=exception, + ) with pytest.raises(SATOSAAuthenticationError): self.account_linking.process(context, internal_response) - @pytest.mark.parametrize("http_status", [ - 400, 401, 500 - ]) + @pytest.mark.parametrize("http_status", [400, 401, 500]) @responses.activate - def test_manage_al_handle_bad_response_status(self, http_status, account_linking_config, internal_response, - context): - responses.add(responses.GET, "%s/get_id" % account_linking_config["api_url"], - status=http_status) + def test_manage_al_handle_bad_response_status( + self, http_status, account_linking_config, internal_response, context + ): + responses.add( + responses.GET, + "%s/get_id" % account_linking_config["api_url"], + status=http_status, + ) with pytest.raises(SATOSAAuthenticationError): self.account_linking.process(context, internal_response) diff --git a/tests/satosa/micro_services/test_attribute_authorization.py b/tests/satosa/micro_services/test_attribute_authorization.py index 6fb277d15..de9f1e747 100644 --- a/tests/satosa/micro_services/test_attribute_authorization.py +++ b/tests/satosa/micro_services/test_attribute_authorization.py @@ -1,9 +1,10 @@ import pytest -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData -from satosa.micro_services.attribute_authorization import AttributeAuthorization -from satosa.exception import SATOSAAuthenticationError + from satosa.context import Context +from satosa.exception import SATOSAAuthenticationError +from satosa.internal import AuthenticationInformation, InternalData +from satosa.micro_services.attribute_authorization import AttributeAuthorization + class TestAttributeAuthorization: def create_authz_service( @@ -27,9 +28,7 @@ def create_authz_service( return authz_service def test_authz_allow_success(self): - attribute_allow = { - "": { "default": {"a0": ['.+@.+']} } - } + attribute_allow = {"": {"default": {"a0": [".+@.+"]}}} attribute_deny = {} authz_service = self.create_authz_service(attribute_allow, attribute_deny) resp = InternalData(auth_info=AuthenticationInformation()) @@ -37,16 +36,14 @@ def test_authz_allow_success(self): "a0": ["test@example.com"], } try: - ctx = Context() - ctx.state = dict() - authz_service.process(ctx, resp) + ctx = Context() + ctx.state = dict() + authz_service.process(ctx, resp) except SATOSAAuthenticationError: - assert False + assert False def test_authz_allow_fail(self): - attribute_allow = { - "": { "default": {"a0": ['foo1','foo2']} } - } + attribute_allow = {"": {"default": {"a0": ["foo1", "foo2"]}}} attribute_deny = {} authz_service = self.create_authz_service(attribute_allow, attribute_deny) resp = InternalData(auth_info=AuthenticationInformation()) @@ -54,45 +51,40 @@ def test_authz_allow_fail(self): "a0": ["bar"], } with pytest.raises(SATOSAAuthenticationError): - ctx = Context() - ctx.state = dict() - authz_service.process(ctx, resp) + ctx = Context() + ctx.state = dict() + authz_service.process(ctx, resp) def test_authz_allow_missing(self): - attribute_allow = { - "": { "default": {"a0": ['foo1','foo2']} } - } + attribute_allow = {"": {"default": {"a0": ["foo1", "foo2"]}}} attribute_deny = {} - authz_service = self.create_authz_service(attribute_allow, attribute_deny, force_attributes_presence_on_allow=True) + authz_service = self.create_authz_service( + attribute_allow, attribute_deny, force_attributes_presence_on_allow=True + ) resp = InternalData(auth_info=AuthenticationInformation()) - resp.attributes = { - } + resp.attributes = {} with pytest.raises(SATOSAAuthenticationError): - ctx = Context() - ctx.state = dict() - authz_service.process(ctx, resp) + ctx = Context() + ctx.state = dict() + authz_service.process(ctx, resp) def test_authz_allow_second(self): - attribute_allow = { - "": { "default": {"a0": ['foo1','foo2']} } - } + attribute_allow = {"": {"default": {"a0": ["foo1", "foo2"]}}} attribute_deny = {} authz_service = self.create_authz_service(attribute_allow, attribute_deny) resp = InternalData(auth_info=AuthenticationInformation()) resp.attributes = { - "a0": ["foo2","kaka"], + "a0": ["foo2", "kaka"], } try: - ctx = Context() - ctx.state = dict() - authz_service.process(ctx, resp) + ctx = Context() + ctx.state = dict() + authz_service.process(ctx, resp) except SATOSAAuthenticationError: - assert False + assert False def test_authz_deny_success(self): - attribute_deny = { - "": { "default": {"a0": ['foo1','foo2']} } - } + attribute_deny = {"": {"default": {"a0": ["foo1", "foo2"]}}} attribute_allow = {} authz_service = self.create_authz_service(attribute_allow, attribute_deny) resp = InternalData(auth_info=AuthenticationInformation()) @@ -100,14 +92,12 @@ def test_authz_deny_success(self): "a0": ["foo2"], } with pytest.raises(SATOSAAuthenticationError): - ctx = Context() - ctx.state = dict() - authz_service.process(ctx, resp) + ctx = Context() + ctx.state = dict() + authz_service.process(ctx, resp) def test_authz_deny_fail(self): - attribute_deny = { - "": { "default": {"a0": ['foo1','foo2']} } - } + attribute_deny = {"": {"default": {"a0": ["foo1", "foo2"]}}} attribute_allow = {} authz_service = self.create_authz_service(attribute_allow, attribute_deny) resp = InternalData(auth_info=AuthenticationInformation()) @@ -115,8 +105,8 @@ def test_authz_deny_fail(self): "a0": ["foo3"], } try: - ctx = Context() - ctx.state = dict() - authz_service.process(ctx, resp) + ctx = Context() + ctx.state = dict() + authz_service.process(ctx, resp) except SATOSAAuthenticationError: - assert False + assert False diff --git a/tests/satosa/micro_services/test_attribute_generation.py b/tests/satosa/micro_services/test_attribute_generation.py index 67f669417..c4d8153e4 100644 --- a/tests/satosa/micro_services/test_attribute_generation.py +++ b/tests/satosa/micro_services/test_attribute_generation.py @@ -1,20 +1,20 @@ -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData -from satosa.micro_services.attribute_generation import AddSyntheticAttributes from satosa.context import Context +from satosa.internal import AuthenticationInformation, InternalData +from satosa.micro_services.attribute_generation import AddSyntheticAttributes + class TestAddSyntheticAttributes: def create_syn_service(self, synthetic_attributes): - authz_service = AddSyntheticAttributes(config=dict(synthetic_attributes=synthetic_attributes), - name="test_gen", - base_url="https://satosa.example.com") + authz_service = AddSyntheticAttributes( + config=dict(synthetic_attributes=synthetic_attributes), + name="test_gen", + base_url="https://satosa.example.com", + ) authz_service.next = lambda ctx, data: data return authz_service def test_generate_static(self): - synthetic_attributes = { - "": { "default": {"a0": "value1;value2" }} - } + synthetic_attributes = {"": {"default": {"a0": "value1;value2"}}} authz_service = self.create_syn_service(synthetic_attributes) resp = InternalData(auth_info=AuthenticationInformation()) resp.attributes = { @@ -23,49 +23,47 @@ def test_generate_static(self): ctx = Context() ctx.state = dict() authz_service.process(ctx, resp) - assert("value1" in resp.attributes['a0']) - assert("value2" in resp.attributes['a0']) - assert("test@example.com" in resp.attributes['a1']) + assert "value1" in resp.attributes["a0"] + assert "value2" in resp.attributes["a0"] + assert "test@example.com" in resp.attributes["a1"] def test_generate_mustache1(self): - synthetic_attributes = { - "": { "default": {"a0": "{{kaka}}#{{eppn.scope}}" }} - } + synthetic_attributes = {"": {"default": {"a0": "{{kaka}}#{{eppn.scope}}"}}} authz_service = self.create_syn_service(synthetic_attributes) resp = InternalData(auth_info=AuthenticationInformation()) resp.attributes = { "kaka": ["kaka1"], - "eppn": ["a@example.com","b@example.com"] + "eppn": ["a@example.com", "b@example.com"], } ctx = Context() ctx.state = dict() authz_service.process(ctx, resp) - assert("kaka1#example.com" in resp.attributes['a0']) - assert("kaka1" in resp.attributes['kaka']) - assert("a@example.com" in resp.attributes['eppn']) - assert("b@example.com" in resp.attributes['eppn']) + assert "kaka1#example.com" in resp.attributes["a0"] + assert "kaka1" in resp.attributes["kaka"] + assert "a@example.com" in resp.attributes["eppn"] + assert "b@example.com" in resp.attributes["eppn"] def test_generate_mustache2(self): synthetic_attributes = { - "": { "default": {"a0": "{{kaka.first}}#{{eppn.scope}}" }} + "": {"default": {"a0": "{{kaka.first}}#{{eppn.scope}}"}} } authz_service = self.create_syn_service(synthetic_attributes) resp = InternalData(auth_info=AuthenticationInformation()) resp.attributes = { - "kaka": ["kaka1","kaka2"], - "eppn": ["a@example.com","b@example.com"] + "kaka": ["kaka1", "kaka2"], + "eppn": ["a@example.com", "b@example.com"], } ctx = Context() ctx.state = dict() authz_service.process(ctx, resp) - assert("kaka1#example.com" in resp.attributes['a0']) - assert("kaka1" in resp.attributes['kaka']) - assert("a@example.com" in resp.attributes['eppn']) - assert("b@example.com" in resp.attributes['eppn']) + assert "kaka1#example.com" in resp.attributes["a0"] + assert "kaka1" in resp.attributes["kaka"] + assert "a@example.com" in resp.attributes["eppn"] + assert "b@example.com" in resp.attributes["eppn"] def test_generate_mustache_empty_attribute(self): synthetic_attributes = { - "": {"default": {"a0": "{{kaka.first}}#{{eppn.scope}}"}} + "": {"default": {"a0": "{{kaka.first}}#{{eppn.scope}}"}} } authz_service = self.create_syn_service(synthetic_attributes) resp = InternalData(auth_info=AuthenticationInformation()) @@ -76,6 +74,6 @@ def test_generate_mustache_empty_attribute(self): ctx = Context() ctx.state = dict() authz_service.process(ctx, resp) - assert("kaka1#" in resp.attributes['a0']) - assert("kaka1" in resp.attributes['kaka']) - assert("kaka2" in resp.attributes['kaka']) + assert "kaka1#" in resp.attributes["a0"] + assert "kaka1" in resp.attributes["kaka"] + assert "kaka2" in resp.attributes["kaka"] diff --git a/tests/satosa/micro_services/test_attribute_modifications.py b/tests/satosa/micro_services/test_attribute_modifications.py index 41ce8a7c0..b1576ea73 100644 --- a/tests/satosa/micro_services/test_attribute_modifications.py +++ b/tests/satosa/micro_services/test_attribute_modifications.py @@ -1,43 +1,37 @@ import pytest -from tests.util import FakeIdP, create_metadata_from_config_dict, FakeSP -from saml2.mdstore import MetadataStore from saml2.config import Config +from saml2.mdstore import MetadataStore + from satosa.context import Context from satosa.exception import SATOSAError -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.internal import AuthenticationInformation, InternalData from satosa.micro_services.attribute_modifications import FilterAttributeValues +from tests.util import FakeIdP, FakeSP, create_metadata_from_config_dict class TestFilterAttributeValues: def create_filter_service(self, attribute_filters): - filter_service = FilterAttributeValues(config=dict(attribute_filters=attribute_filters), name="test_filter", - base_url="https://satosa.example.com") + filter_service = FilterAttributeValues( + config=dict(attribute_filters=attribute_filters), + name="test_filter", + base_url="https://satosa.example.com", + ) filter_service.next = lambda ctx, data: data return filter_service def create_idp_metadata_conf_with_shibmd_scopes(self, idp_entityid, shibmd_scopes): - idp_conf = { - "entityid": idp_entityid, - "service": { - "idp":{} - } - } + idp_conf = {"entityid": idp_entityid, "service": {"idp": {}}} if shibmd_scopes is not None: idp_conf["service"]["idp"]["scope"] = shibmd_scopes - metadata_conf = { - "inline": [create_metadata_from_config_dict(idp_conf)] - } + metadata_conf = {"inline": [create_metadata_from_config_dict(idp_conf)]} return metadata_conf def test_filter_all_attributes_from_all_target_providers_for_all_requesters(self): attribute_filters = { "": { # all providers - "": { # all requesters - "": "foo:bar" # all attributes - } + "": {"": "foo:bar"} # all requesters # all attributes } } filter_service = self.create_filter_service(attribute_filters) @@ -46,19 +40,17 @@ def test_filter_all_attributes_from_all_target_providers_for_all_requesters(self resp.attributes = { "a1": ["abc:xyz"], "a2": ["foo:bar", "1:foo:bar:2"], - "a3": ["a:foo:bar:b"] + "a3": ["a:foo:bar:b"], } filtered = filter_service.process(None, resp) - assert filtered.attributes == {"a1": [], "a2": ["foo:bar", "1:foo:bar:2"], "a3": ["a:foo:bar:b"]} + assert filtered.attributes == { + "a1": [], + "a2": ["foo:bar", "1:foo:bar:2"], + "a3": ["a:foo:bar:b"], + } def test_filter_one_attribute_from_all_target_providers_for_all_requesters(self): - attribute_filters = { - "": { - "": { - "a2": "^foo:bar$" - } - } - } + attribute_filters = {"": {"": {"a2": "^foo:bar$"}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) @@ -71,12 +63,7 @@ def test_filter_one_attribute_from_all_target_providers_for_all_requesters(self) def test_filter_one_attribute_from_all_target_providers_for_one_requester(self): requester = "test_requester" - attribute_filters = { - "": { - requester: - {"a1": "foo:bar"} - } - } + attribute_filters = {"": {requester: {"a1": "foo:bar"}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(auth_info=AuthenticationInformation()) @@ -88,12 +75,7 @@ def test_filter_one_attribute_from_all_target_providers_for_one_requester(self): assert filtered.attributes == {"a1": ["1:foo:bar:2"]} def test_filter_attribute_not_in_response(self): - attribute_filters = { - "": { - "": - {"a0": "foo:bar"} - } - } + attribute_filters = {"": {"": {"a0": "foo:bar"}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(auth_info=AuthenticationInformation()) @@ -105,12 +87,7 @@ def test_filter_attribute_not_in_response(self): def test_filter_one_attribute_for_one_target_provider(self): target_provider = "test_provider" - attribute_filters = { - target_provider: { - "": - {"a1": "foo:bar"} - } - } + attribute_filters = {target_provider: {"": {"a1": "foo:bar"}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(auth_info=AuthenticationInformation(issuer=target_provider)) @@ -123,12 +100,7 @@ def test_filter_one_attribute_for_one_target_provider(self): def test_filter_one_attribute_for_one_target_provider_for_one_requester(self): target_provider = "test_provider" requester = "test_requester" - attribute_filters = { - target_provider: { - requester: - {"a1": "foo:bar"} - } - } + attribute_filters = {target_provider: {requester: {"a1": "foo:bar"}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(auth_info=AuthenticationInformation(issuer=target_provider)) @@ -139,16 +111,10 @@ def test_filter_one_attribute_for_one_target_provider_for_one_requester(self): filtered = filter_service.process(None, resp) assert filtered.attributes == {"a1": ["1:foo:bar:2"]} - def test_filter_one_attribute_from_all_target_providers_for_all_requesters_in_extended_notation(self): - attribute_filters = { - "": { - "": { - "a2": { - "regexp": "^foo:bar$" - } - } - } - } + def test_filter_one_attribute_from_all_target_providers_for_all_requesters_in_extended_notation( + self, + ): + attribute_filters = {"": {"": {"a2": {"regexp": "^foo:bar$"}}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) @@ -160,15 +126,7 @@ def test_filter_one_attribute_from_all_target_providers_for_all_requesters_in_ex assert filtered.attributes == {"a1": ["abc:xyz"], "a2": ["foo:bar"]} def test_invalid_filter_type(self): - attribute_filters = { - "": { - "": { - "a2": { - "invalid_filter": None - } - } - } - } + attribute_filters = {"": {"": {"a2": {"invalid_filter": None}}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) @@ -180,15 +138,7 @@ def test_invalid_filter_type(self): filtered = filter_service.process(None, resp) def test_shibmdscope_match_value_filter_with_no_md_store_in_context(self): - attribute_filters = { - "": { - "": { - "a2": { - "shibmdscope_match_value": None - } - } - } - } + attribute_filters = {"": {"": {"a2": {"shibmdscope_match_value": None}}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) @@ -201,15 +151,7 @@ def test_shibmdscope_match_value_filter_with_no_md_store_in_context(self): assert filtered.attributes == {"a1": ["abc:xyz"], "a2": []} def test_shibmdscope_match_value_filter_with_empty_md_store_in_context(self): - attribute_filters = { - "": { - "": { - "a2": { - "shibmdscope_match_value": None - } - } - } - } + attribute_filters = {"": {"": {"a2": {"shibmdscope_match_value": None}}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) @@ -224,15 +166,7 @@ def test_shibmdscope_match_value_filter_with_empty_md_store_in_context(self): assert filtered.attributes == {"a1": ["abc:xyz"], "a2": []} def test_shibmdscope_match_value_filter_with_idp_md_with_no_scope(self): - attribute_filters = { - "": { - "": { - "a2": { - "shibmdscope_match_value": None - } - } - } - } + attribute_filters = {"": {"": {"a2": {"shibmdscope_match_value": None}}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) @@ -241,11 +175,13 @@ def test_shibmdscope_match_value_filter_with_idp_md_with_no_scope(self): "a2": ["foo.bar", "1.foo.bar.2"], } - idp_entityid = 'https://idp.example.org/' + idp_entityid = "https://idp.example.org/" resp.auth_info.issuer = idp_entityid mdstore = MetadataStore(None, Config()) - mdstore.imp(self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, None)) + mdstore.imp( + self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, None) + ) ctx = Context() ctx.decorate(Context.KEY_METADATA_STORE, mdstore) @@ -253,15 +189,7 @@ def test_shibmdscope_match_value_filter_with_idp_md_with_no_scope(self): assert filtered.attributes == {"a1": ["abc:xyz"], "a2": []} def test_shibmdscope_match_value_filter_with_idp_md_with_single_scope(self): - attribute_filters = { - "": { - "": { - "a2": { - "shibmdscope_match_value": None - } - } - } - } + attribute_filters = {"": {"": {"a2": {"shibmdscope_match_value": None}}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) @@ -270,11 +198,13 @@ def test_shibmdscope_match_value_filter_with_idp_md_with_single_scope(self): "a2": ["foo.bar", "1.foo.bar.2"], } - idp_entityid = 'https://idp.example.org/' + idp_entityid = "https://idp.example.org/" resp.auth_info.issuer = idp_entityid mdstore = MetadataStore(None, Config()) - mdstore.imp(self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, ["foo.bar"])) + mdstore.imp( + self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, ["foo.bar"]) + ) ctx = Context() ctx.decorate(Context.KEY_METADATA_STORE, mdstore) @@ -282,15 +212,7 @@ def test_shibmdscope_match_value_filter_with_idp_md_with_single_scope(self): assert filtered.attributes == {"a1": ["abc:xyz"], "a2": ["foo.bar"]} def test_shibmdscope_match_value_filter_with_idp_md_with_single_regexp_scope(self): - attribute_filters = { - "": { - "": { - "a2": { - "shibmdscope_match_value": None - } - } - } - } + attribute_filters = {"": {"": {"a2": {"shibmdscope_match_value": None}}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) @@ -299,13 +221,19 @@ def test_shibmdscope_match_value_filter_with_idp_md_with_single_regexp_scope(sel "a2": ["test.foo.bar", "1.foo.bar.2"], } - idp_entityid = 'https://idp.example.org/' + idp_entityid = "https://idp.example.org/" resp.auth_info.issuer = idp_entityid mdstore = MetadataStore(None, Config()) - mdstore.imp(self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, [r"[^.]*\.foo\.bar$"])) + mdstore.imp( + self.create_idp_metadata_conf_with_shibmd_scopes( + idp_entityid, [r"[^.]*\.foo\.bar$"] + ) + ) # mark scope as regexp (cannot be done via pysaml2 YAML config) - mdstore[idp_entityid]['idpsso_descriptor'][0]['extensions']['extension_elements'][0]['regexp'] = 'true' + mdstore[idp_entityid]["idpsso_descriptor"][0]["extensions"][ + "extension_elements" + ][0]["regexp"] = "true" ctx = Context() ctx.decorate(Context.KEY_METADATA_STORE, mdstore) @@ -313,15 +241,7 @@ def test_shibmdscope_match_value_filter_with_idp_md_with_single_regexp_scope(sel assert filtered.attributes == {"a1": ["abc:xyz"], "a2": ["test.foo.bar"]} def test_shibmdscope_match_value_filter_with_idp_md_with_multiple_scopes(self): - attribute_filters = { - "": { - "": { - "a2": { - "shibmdscope_match_value": None - } - } - } - } + attribute_filters = {"": {"": {"a2": {"shibmdscope_match_value": None}}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) @@ -330,11 +250,15 @@ def test_shibmdscope_match_value_filter_with_idp_md_with_multiple_scopes(self): "a2": ["foo.bar", "1.foo.bar.2", "foo.baz", "foo.baz.com"], } - idp_entityid = 'https://idp.example.org/' + idp_entityid = "https://idp.example.org/" resp.auth_info.issuer = idp_entityid mdstore = MetadataStore(None, Config()) - mdstore.imp(self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, ["foo.bar", "foo.baz"])) + mdstore.imp( + self.create_idp_metadata_conf_with_shibmd_scopes( + idp_entityid, ["foo.bar", "foo.baz"] + ) + ) ctx = Context() ctx.decorate(Context.KEY_METADATA_STORE, mdstore) @@ -342,28 +266,28 @@ def test_shibmdscope_match_value_filter_with_idp_md_with_multiple_scopes(self): assert filtered.attributes == {"a1": ["abc:xyz"], "a2": ["foo.bar", "foo.baz"]} def test_shibmdscope_match_scope_filter_with_single_scope(self): - attribute_filters = { - "": { - "": { - "a2": { - "shibmdscope_match_scope": None - } - } - } - } + attribute_filters = {"": {"": {"a2": {"shibmdscope_match_scope": None}}}} filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) resp.attributes = { "a1": ["abc:xyz"], - "a2": ["foo.bar", "value@foo.bar", "1.foo.bar.2", "value@foo.bar.2", "value@extra@foo.bar"], + "a2": [ + "foo.bar", + "value@foo.bar", + "1.foo.bar.2", + "value@foo.bar.2", + "value@extra@foo.bar", + ], } - idp_entityid = 'https://idp.example.org/' + idp_entityid = "https://idp.example.org/" resp.auth_info.issuer = idp_entityid mdstore = MetadataStore(None, Config()) - mdstore.imp(self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, ["foo.bar"])) + mdstore.imp( + self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, ["foo.bar"]) + ) ctx = Context() ctx.decorate(Context.KEY_METADATA_STORE, mdstore) @@ -372,28 +296,30 @@ def test_shibmdscope_match_scope_filter_with_single_scope(self): def test_multiple_filters_for_single_attribute(self): attribute_filters = { - "": { - "": { - "a2": { - "regexp": "^value1@", - "shibmdscope_match_scope": None - } - } - } + "": {"": {"a2": {"regexp": "^value1@", "shibmdscope_match_scope": None}}} } filter_service = self.create_filter_service(attribute_filters) resp = InternalData(AuthenticationInformation()) resp.attributes = { "a1": ["abc:xyz"], - "a2": ["foo.bar", "value1@foo.bar", "value2@foo.bar", "1.foo.bar.2", "value@foo.bar.2", "value@extra@foo.bar"], + "a2": [ + "foo.bar", + "value1@foo.bar", + "value2@foo.bar", + "1.foo.bar.2", + "value@foo.bar.2", + "value@extra@foo.bar", + ], } - idp_entityid = 'https://idp.example.org/' + idp_entityid = "https://idp.example.org/" resp.auth_info.issuer = idp_entityid mdstore = MetadataStore(None, Config()) - mdstore.imp(self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, ["foo.bar"])) + mdstore.imp( + self.create_idp_metadata_conf_with_shibmd_scopes(idp_entityid, ["foo.bar"]) + ) ctx = Context() ctx.decorate(Context.KEY_METADATA_STORE, mdstore) diff --git a/tests/satosa/micro_services/test_attribute_policy.py b/tests/satosa/micro_services/test_attribute_policy.py index f68483025..0ff756a44 100644 --- a/tests/satosa/micro_services/test_attribute_policy.py +++ b/tests/satosa/micro_services/test_attribute_policy.py @@ -8,7 +8,7 @@ def create_attribute_policy_service(self, attribute_policies): attribute_policy_service = AttributePolicy( config=attribute_policies, name="test_attribute_policy", - base_url="https://satosa.example.com" + base_url="https://satosa.example.com", ) attribute_policy_service.next = lambda ctx, data: data return attribute_policy_service @@ -18,9 +18,7 @@ def test_attribute_policy(self): attribute_policies = { "attribute_policy": { "requester_everything_allowed": {}, - "requester_nothing_allowed": { - "allowed": {} - }, + "requester_nothing_allowed": {"allowed": {}}, "requester_subset_allowed": { "allowed": { "attr1", @@ -29,11 +27,7 @@ def test_attribute_policy(self): }, }, } - attributes = { - "attr1": ["foo"], - "attr2": ["foo", "bar"], - "attr3": ["foo"] - } + attributes = {"attr1": ["foo"], "attr2": ["foo", "bar"], "attr3": ["foo"]} results = { "requester_everything_allowed": attributes.keys(), "requester_nothing_allowed": set(), @@ -41,7 +35,8 @@ def test_attribute_policy(self): } for requester, result in results.items(): attribute_policy_service = self.create_attribute_policy_service( - attribute_policies) + attribute_policies + ) ctx = Context() ctx.state = dict() @@ -51,8 +46,8 @@ def test_attribute_policy(self): resp.attributes = { "attr1": ["foo"], "attr2": ["foo", "bar"], - "attr3": ["foo"] + "attr3": ["foo"], } filtered = attribute_policy_service.process(ctx, resp) - assert(filtered.attributes.keys() == result) + assert filtered.attributes.keys() == result diff --git a/tests/satosa/micro_services/test_consent.py b/tests/satosa/micro_services/test_consent.py index a8eaed965..7eed56569 100644 --- a/tests/satosa/micro_services/test_consent.py +++ b/tests/satosa/micro_services/test_consent.py @@ -8,19 +8,21 @@ import responses from jwkest.jwk import RSAKey, rsa_load from jwkest.jws import JWS - from saml2.saml import NAMEID_FORMAT_PERSISTENT from satosa.context import Context -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.internal import AuthenticationInformation, InternalData from satosa.micro_services import consent from satosa.micro_services.consent import Consent, UnexpectedResponseError from satosa.response import Redirect FILTER = ["displayName", "co"] CONSENT_SERVICE_URL = "https://consent.example.com" -ATTRIBUTES = {"displayName": ["Test"], "co": ["example"], "sn": ["should be removed by consent filter"]} +ATTRIBUTES = { + "displayName": ["Test"], + "co": ["example"], + "sn": ["should be removed by consent filter"], +} USER_ID_ATTR = "user_id" @@ -36,9 +38,12 @@ def consent_config(self, signing_key_path): @pytest.fixture(autouse=True) def create_module(self, consent_config): - self.consent_module = Consent(consent_config, - internal_attributes={"attributes": {}, "user_id_to_attr": USER_ID_ATTR}, - name="Consent", base_url="https://satosa.example.com") + self.consent_module = Consent( + consent_config, + internal_attributes={"attributes": {}, "user_id_to_attr": USER_ID_ATTR}, + name="Consent", + base_url="https://satosa.example.com", + ) self.consent_module.next = lambda ctx, data: (ctx, data) @pytest.fixture @@ -72,7 +77,9 @@ def assert_redirect(self, redirect_resp, expected_ticket): path = urlparse(redirect_resp.message).path assert path == "/consent/" + expected_ticket - def assert_registration_req(self, request, internal_response, sign_key_path, base_url, requester_name): + def assert_registration_req( + self, request, internal_response, sign_key_path, base_url, requester_name + ): split_path = request.path_url.lstrip("/").split("/") assert len(split_path) == 2 jwks = split_path[1] @@ -92,102 +99,168 @@ def assert_registration_req(self, request, internal_response, sign_key_path, bas @responses.activate def test_verify_consent_false_on_http_400(self, consent_config): consent_id = "1234" - responses.add(responses.GET, - "{}/verify/{}".format(consent_config["api_url"], consent_id), - status=400) + responses.add( + responses.GET, + "{}/verify/{}".format(consent_config["api_url"], consent_id), + status=400, + ) assert not self.consent_module._verify_consent(consent_id) @responses.activate def test_verify_consent(self, consent_config): consent_id = "1234" - responses.add(responses.GET, - "{}/verify/{}".format(consent_config["api_url"], consent_id), - status=200, body=json.dumps(FILTER)) + responses.add( + responses.GET, + "{}/verify/{}".format(consent_config["api_url"], consent_id), + status=200, + body=json.dumps(FILTER), + ) assert self.consent_module._verify_consent(consent_id) == FILTER - @pytest.mark.parametrize('status', [ - 401, 404, 418, 500 - ]) + @pytest.mark.parametrize("status", [401, 404, 418, 500]) @responses.activate - def test_consent_registration_raises_on_unexpected_status_code(self, status, consent_config): - responses.add(responses.GET, re.compile(r"{}/creq/.*".format(consent_config["api_url"])), - status=status) + def test_consent_registration_raises_on_unexpected_status_code( + self, status, consent_config + ): + responses.add( + responses.GET, + re.compile(r"{}/creq/.*".format(consent_config["api_url"])), + status=status, + ) with pytest.raises(UnexpectedResponseError): self.consent_module._consent_registration({}) @responses.activate def test_consent_registration(self, consent_config): - responses.add(responses.GET, re.compile(r"{}/creq/.*".format(consent_config["api_url"])), - status=200, body="ticket") + responses.add( + responses.GET, + re.compile(r"{}/creq/.*".format(consent_config["api_url"])), + status=200, + body="ticket", + ) assert self.consent_module._consent_registration({}) == "ticket" @responses.activate - def test_consent_handles_connection_error(self, context, internal_response, internal_request, - consent_verify_endpoint_regex): - responses.add(responses.GET, - consent_verify_endpoint_regex, - body=requests.ConnectionError("No connection")) + def test_consent_handles_connection_error( + self, + context, + internal_response, + internal_request, + consent_verify_endpoint_regex, + ): + responses.add( + responses.GET, + consent_verify_endpoint_regex, + body=requests.ConnectionError("No connection"), + ) context.state[consent.STATE_KEY] = {"filter": []} with responses.RequestsMock(assert_all_requests_are_fired=True) as rsps: - rsps.add(responses.GET, - consent_verify_endpoint_regex, - body=requests.ConnectionError("No connection")) - context, internal_response = self.consent_module.process(context, internal_response) + rsps.add( + responses.GET, + consent_verify_endpoint_regex, + body=requests.ConnectionError("No connection"), + ) + context, internal_response = self.consent_module.process( + context, internal_response + ) assert context assert not internal_response.attributes @responses.activate - def test_consent_prev_given(self, context, internal_response, internal_request, - consent_verify_endpoint_regex): - responses.add(responses.GET, consent_verify_endpoint_regex, status=200, - body=json.dumps(FILTER)) + def test_consent_prev_given( + self, + context, + internal_response, + internal_request, + consent_verify_endpoint_regex, + ): + responses.add( + responses.GET, + consent_verify_endpoint_regex, + status=200, + body=json.dumps(FILTER), + ) context.state[consent.STATE_KEY] = {"filter": internal_request.attributes} - context, internal_response = self.consent_module.process(context, internal_response) + context, internal_response = self.consent_module.process( + context, internal_response + ) assert context assert "displayName" in internal_response.attributes - def test_consent_full_flow(self, context, consent_config, internal_response, internal_request, - consent_verify_endpoint_regex, consent_registration_endpoint_regex): + def test_consent_full_flow( + self, + context, + consent_config, + internal_response, + internal_request, + consent_verify_endpoint_regex, + consent_registration_endpoint_regex, + ): expected_ticket = "my_ticket" requester_name = internal_response.requester_name - context.state[consent.STATE_KEY] = {"filter": internal_request.attributes, - "requester_name": requester_name} + context.state[consent.STATE_KEY] = { + "filter": internal_request.attributes, + "requester_name": requester_name, + } with responses.RequestsMock() as rsps: rsps.add(responses.GET, consent_verify_endpoint_regex, status=401) - rsps.add(responses.GET, consent_registration_endpoint_regex, status=200, - body=expected_ticket) + rsps.add( + responses.GET, + consent_registration_endpoint_regex, + status=200, + body=expected_ticket, + ) resp = self.consent_module.process(context, internal_response) self.assert_redirect(resp, expected_ticket) - self.assert_registration_req(rsps.calls[1].request, - internal_response, - consent_config["sign_key"], - self.consent_module.base_url, - requester_name) + self.assert_registration_req( + rsps.calls[1].request, + internal_response, + consent_config["sign_key"], + self.consent_module.base_url, + requester_name, + ) with responses.RequestsMock() as rsps: # Now consent has been given, consent service returns 200 OK - rsps.add(responses.GET, consent_verify_endpoint_regex, status=200, - body=json.dumps(FILTER)) + rsps.add( + responses.GET, + consent_verify_endpoint_regex, + status=200, + body=json.dumps(FILTER), + ) - context, internal_response = self.consent_module._handle_consent_response(context) + context, internal_response = self.consent_module._handle_consent_response( + context + ) assert internal_response.attributes["displayName"] == ["Test"] assert internal_response.attributes["co"] == ["example"] assert "sn" not in internal_response.attributes # 'sn' should be filtered @responses.activate - def test_consent_not_given(self, context, consent_config, internal_response, internal_request, - consent_verify_endpoint_regex, consent_registration_endpoint_regex): + def test_consent_not_given( + self, + context, + consent_config, + internal_response, + internal_request, + consent_verify_endpoint_regex, + consent_registration_endpoint_regex, + ): expected_ticket = "my_ticket" responses.add(responses.GET, consent_verify_endpoint_regex, status=401) - responses.add(responses.GET, consent_registration_endpoint_regex, status=200, - body=expected_ticket) + responses.add( + responses.GET, + consent_registration_endpoint_regex, + status=200, + body=expected_ticket, + ) requester_name = internal_response.requester_name context.state[consent.STATE_KEY] = {} @@ -195,39 +268,54 @@ def test_consent_not_given(self, context, consent_config, internal_response, int resp = self.consent_module.process(context, internal_response) self.assert_redirect(resp, expected_ticket) - self.assert_registration_req(responses.calls[1].request, - internal_response, - consent_config["sign_key"], - self.consent_module.base_url, - requester_name) + self.assert_registration_req( + responses.calls[1].request, + internal_response, + consent_config["sign_key"], + self.consent_module.base_url, + requester_name, + ) new_context = Context() new_context.state = context.state # Verify endpoint of consent service still gives 401 (no consent given) - context, internal_response = self.consent_module._handle_consent_response(context) + context, internal_response = self.consent_module._handle_consent_response( + context + ) assert not internal_response.attributes def test_get_consent_id(self): attributes = {"foo": ["bar", "123"], "abc": ["xyz", "456"]} id = self.consent_module._get_consent_id("test-requester", "user1", attributes) - assert id == "ZTRhMTJmNWQ2Yjk2YWE0YzgyMzU4NTlmNjM3YjlhNmQ4ZjZiODMzOTQ0ZjNiMTVmODEwMDhmMDg5N2JlMDg0Y2ZkZGFkOTkzMDZiNDZiNjMxNzBkYzExOTcxN2RkMzJjMmY5NzRhZDA2NjYxMTg0NjkyYzdjN2IxNTRiZDkwNmM=" + assert ( + id + == "ZTRhMTJmNWQ2Yjk2YWE0YzgyMzU4NTlmNjM3YjlhNmQ4ZjZiODMzOTQ0ZjNiMTVmODEwMDhmMDg5N2JlMDg0Y2ZkZGFkOTkzMDZiNDZiNjMxNzBkYzExOTcxN2RkMzJjMmY5NzRhZDA2NjYxMTg0NjkyYzdjN2IxNTRiZDkwNmM=" + ) def test_filter_attributes(self): filtered_attributes = self.consent_module._filter_attributes(ATTRIBUTES, FILTER) assert Counter(filtered_attributes.keys()) == Counter(FILTER) @responses.activate - def test_manage_consent_without_filter_passes_through_all_attributes(self, context, internal_response, - consent_verify_endpoint_regex): + def test_manage_consent_without_filter_passes_through_all_attributes( + self, context, internal_response, consent_verify_endpoint_regex + ): # fake previous consent - responses.add(responses.GET, consent_verify_endpoint_regex, status=200, - body=json.dumps(list(internal_response.attributes.keys()))) + responses.add( + responses.GET, + consent_verify_endpoint_regex, + status=200, + body=json.dumps(list(internal_response.attributes.keys())), + ) - context.state[consent.STATE_KEY] = {"filter": []} # No filter + context.state[consent.STATE_KEY] = {"filter": []} # No filter self.consent_module.process(context, internal_response) consent_hash = urlparse(responses.calls[0].request.url).path.split("/")[2] - expected_hash = self.consent_module._get_consent_id(internal_response.requester, internal_response.subject_id, - internal_response.attributes) + expected_hash = self.consent_module._get_consent_id( + internal_response.requester, + internal_response.subject_id, + internal_response.attributes, + ) assert consent_hash == expected_hash diff --git a/tests/satosa/micro_services/test_custom_routing.py b/tests/satosa/micro_services/test_custom_routing.py index 1be124877..3c82a9825 100644 --- a/tests/satosa/micro_services/test_custom_routing.py +++ b/tests/satosa/micro_services/test_custom_routing.py @@ -4,13 +4,14 @@ import pytest from satosa.context import Context -from satosa.state import State -from satosa.exception import SATOSAError, SATOSAConfigurationError +from satosa.exception import SATOSAConfigurationError, SATOSAError from satosa.internal import InternalData -from satosa.micro_services.custom_routing import DecideIfRequesterIsAllowed -from satosa.micro_services.custom_routing import DecideBackendByTargetIssuer -from satosa.micro_services.custom_routing import DecideBackendByRequester - +from satosa.micro_services.custom_routing import ( + DecideBackendByRequester, + DecideBackendByTargetIssuer, + DecideIfRequesterIsAllowed, +) +from satosa.state import State TARGET_ENTITY = "entity1" @@ -25,8 +26,11 @@ def target_context(context): class TestDecideIfRequesterIsAllowed: def create_decide_service(self, rules): - decide_service = DecideIfRequesterIsAllowed(config=dict(rules=rules), name="test_decide_service", - base_url="https://satosa.example.com") + decide_service = DecideIfRequesterIsAllowed( + config=dict(rules=rules), + name="test_decide_service", + base_url="https://satosa.example.com", + ) decide_service.next = lambda ctx, data: data return decide_service @@ -45,10 +49,7 @@ def test_allow_one_requester(self, target_context): with pytest.raises(SATOSAError): decide_service.process(target_context, req) - @pytest.mark.parametrize("requester", [ - "test_requester", - "somebody else" - ]) + @pytest.mark.parametrize("requester", ["test_requester", "somebody else"]) def test_allow_all_requesters(self, target_context, requester): rules = { TARGET_ENTITY: { @@ -72,10 +73,7 @@ def test_deny_one_requester(self, target_context): with pytest.raises(SATOSAError): assert decide_service.process(target_context, req) - @pytest.mark.parametrize("requester", [ - "test_requester", - "somebody else" - ]) + @pytest.mark.parametrize("requester", ["test_requester", "somebody else"]) def test_deny_all_requesters(self, target_context, requester): rules = { TARGET_ENTITY: { @@ -124,10 +122,7 @@ def test_deny_takes_precedence_over_allow_all(self, target_context): req = InternalData(requester="somebody else") decide_service.process(target_context, req) - @pytest.mark.parametrize("requester", [ - "*", - "test_requester" - ]) + @pytest.mark.parametrize("requester", ["*", "test_requester"]) def test_deny_all_and_allow_all_should_raise_exception(self, requester): rules = { TARGET_ENTITY: { @@ -138,12 +133,10 @@ def test_deny_all_and_allow_all_should_raise_exception(self, requester): with pytest.raises(SATOSAConfigurationError): self.create_decide_service(rules) - def test_defaults_to_allow_all_requesters_for_target_entity_without_specific_rules(self, target_context): - rules = { - "some other entity": { - "allow": ["foobar"] - } - } + def test_defaults_to_allow_all_requesters_for_target_entity_without_specific_rules( + self, target_context + ): + rules = {"some other entity": {"allow": ["foobar"]}} decide_service = self.create_decide_service(rules) req = InternalData(requester="test_requester") @@ -169,16 +162,16 @@ def setUp(self): context.state = State() config = { - 'default_backend': 'default_backend', - 'target_mapping': { - 'mapped_idp.example.org': 'mapped_backend', + "default_backend": "default_backend", + "target_mapping": { + "mapped_idp.example.org": "mapped_backend", }, } plugin = DecideBackendByTargetIssuer( config=config, - name='test_decide_service', - base_url='https://satosa.example.org', + name="test_decide_service", + base_url="https://satosa.example.org", ) plugin.next = lambda ctx, data: (ctx, data) @@ -187,22 +180,22 @@ def setUp(self): self.plugin = plugin def test_when_target_is_not_set_do_skip(self): - data = InternalData(requester='test_requester') + data = InternalData(requester="test_requester") newctx, newdata = self.plugin.process(self.context, data) assert not newctx.target_backend def test_when_target_is_not_mapped_choose_default_backend(self): - self.context.decorate(Context.KEY_TARGET_ENTITYID, 'idp.example.org') - data = InternalData(requester='test_requester') + self.context.decorate(Context.KEY_TARGET_ENTITYID, "idp.example.org") + data = InternalData(requester="test_requester") newctx, newdata = self.plugin.process(self.context, data) - assert newctx.target_backend == 'default_backend' + assert newctx.target_backend == "default_backend" def test_when_target_is_mapped_choose_mapping_backend(self): - self.context.decorate(Context.KEY_TARGET_ENTITYID, 'mapped_idp.example.org') - data = InternalData(requester='test_requester') - data.requester = 'somebody else' + self.context.decorate(Context.KEY_TARGET_ENTITYID, "mapped_idp.example.org") + data = InternalData(requester="test_requester") + data.requester = "somebody else" newctx, newdata = self.plugin.process(self.context, data) - assert newctx.target_backend == 'mapped_backend' + assert newctx.target_backend == "mapped_backend" class TestDecideBackendByRequester(TestCase): @@ -211,15 +204,15 @@ def setUp(self): context.state = State() config = { - 'requester_mapping': { - 'test_requester': 'mapped_backend', + "requester_mapping": { + "test_requester": "mapped_backend", }, } plugin = DecideBackendByRequester( config=config, - name='test_decide_service', - base_url='https://satosa.example.org', + name="test_decide_service", + base_url="https://satosa.example.org", ) plugin.next = lambda ctx, data: (ctx, data) @@ -228,26 +221,26 @@ def setUp(self): self.plugin = plugin def test_when_requester_is_not_mapped_and_no_default_backend_skip(self): - data = InternalData(requester='other_test_requester') + data = InternalData(requester="other_test_requester") newctx, newdata = self.plugin.process(self.context, data) assert not newctx.target_backend def test_when_requester_is_not_mapped_choose_default_backend(self): # override config to set default backend - self.config['default_backend'] = 'default_backend' + self.config["default_backend"] = "default_backend" self.plugin = DecideBackendByRequester( config=self.config, - name='test_decide_service', - base_url='https://satosa.example.org', + name="test_decide_service", + base_url="https://satosa.example.org", ) self.plugin.next = lambda ctx, data: (ctx, data) - data = InternalData(requester='other_test_requester') + data = InternalData(requester="other_test_requester") newctx, newdata = self.plugin.process(self.context, data) - assert newctx.target_backend == 'default_backend' + assert newctx.target_backend == "default_backend" def test_when_requester_is_mapped_choose_mapping_backend(self): - data = InternalData(requester='test_requester') - data.requester = 'test_requester' + data = InternalData(requester="test_requester") + data.requester = "test_requester" newctx, newdata = self.plugin.process(self.context, data) - assert newctx.target_backend == 'mapped_backend' + assert newctx.target_backend == "mapped_backend" diff --git a/tests/satosa/micro_services/test_disco.py b/tests/satosa/micro_services/test_disco.py index ac2c3c5c2..7d7ffbcb1 100644 --- a/tests/satosa/micro_services/test_disco.py +++ b/tests/satosa/micro_services/test_disco.py @@ -3,9 +3,8 @@ import pytest from satosa.context import Context +from satosa.micro_services.disco import DiscoToTargetIssuer, DiscoToTargetIssuerError from satosa.state import State -from satosa.micro_services.disco import DiscoToTargetIssuer -from satosa.micro_services.disco import DiscoToTargetIssuerError class TestDiscoToTargetIssuer(TestCase): @@ -14,15 +13,15 @@ def setUp(self): context.state = State() config = { - 'disco_endpoints': [ - '.*/disco', + "disco_endpoints": [ + ".*/disco", ], } plugin = DiscoToTargetIssuer( config=config, - name='test_disco_to_target_issuer', - base_url='https://satosa.example.org', + name="test_disco_to_target_issuer", + base_url="https://satosa.example.org", ) plugin.next = lambda ctx, data: (ctx, data) @@ -36,9 +35,9 @@ def test_when_entity_id_is_not_set_raise_error(self): self.plugin._handle_disco_response(self.context) def test_when_entity_id_is_set_target_issuer_is_set(self): - entity_id = 'idp.example.org' + entity_id = "idp.example.org" self.context.request = { - 'entityID': entity_id, + "entityID": entity_id, } newctx, newdata = self.plugin._handle_disco_response(self.context) assert newctx.get_decoration(Context.KEY_TARGET_ENTITYID) == entity_id diff --git a/tests/satosa/micro_services/test_idp_hinting.py b/tests/satosa/micro_services/test_idp_hinting.py index 2fa454253..480837c66 100644 --- a/tests/satosa/micro_services/test_idp_hinting.py +++ b/tests/satosa/micro_services/test_idp_hinting.py @@ -2,8 +2,8 @@ from satosa.context import Context from satosa.internal import InternalData -from satosa.state import State from satosa.micro_services.idp_hinting import IdpHinting +from satosa.state import State class TestIdpHinting(TestCase): @@ -12,14 +12,12 @@ def setUp(self): context.state = State() internal_data = InternalData() - config = { - 'allowed_params': ["idp_hinting", "idp_hint", "idphint"] - } + config = {"allowed_params": ["idp_hinting", "idp_hint", "idphint"]} plugin = IdpHinting( config=config, - name='test_idphinting', - base_url='https://satosa.example.org', + name="test_idphinting", + base_url="https://satosa.example.org", ) plugin.next = lambda ctx, data: (ctx, data) @@ -34,22 +32,25 @@ def test_no_query_params(self): assert not new_context.get_decoration(Context.KEY_TARGET_ENTITYID) def test_hint_in_params(self): - _target = 'https://localhost:8080' - self.context.qs_params = {'idphint': _target} + _target = "https://localhost:8080" + self.context.qs_params = {"idphint": _target} new_context, new_data = self.plugin.process(self.context, self.data) assert new_context.get_decoration(Context.KEY_TARGET_ENTITYID) == _target def test_no_hint_in_params(self): - _target = 'https://localhost:8080' - self.context.qs_params = {'param_not_in_allowed_params': _target} + _target = "https://localhost:8080" + self.context.qs_params = {"param_not_in_allowed_params": _target} new_context, new_data = self.plugin.process(self.context, self.data) assert not new_context.get_decoration(Context.KEY_TARGET_ENTITYID) def test_issuer_already_set(self): - _pre_selected_target = 'https://local.localhost:8080' + _pre_selected_target = "https://local.localhost:8080" self.context.decorate(Context.KEY_TARGET_ENTITYID, _pre_selected_target) - _target = 'https://localhost:8080' - self.context.qs_params = {'idphint': _target} + _target = "https://localhost:8080" + self.context.qs_params = {"idphint": _target} new_context, new_data = self.plugin.process(self.context, self.data) - assert new_context.get_decoration(Context.KEY_TARGET_ENTITYID) == _pre_selected_target + assert ( + new_context.get_decoration(Context.KEY_TARGET_ENTITYID) + == _pre_selected_target + ) assert new_context.get_decoration(Context.KEY_TARGET_ENTITYID) != _target diff --git a/tests/satosa/micro_services/test_ldap_attribute_store.py b/tests/satosa/micro_services/test_ldap_attribute_store.py index e3af1a7f5..0277e50c1 100644 --- a/tests/satosa/micro_services/test_ldap_attribute_store.py +++ b/tests/satosa/micro_services/test_ldap_attribute_store.py @@ -1,76 +1,75 @@ -import pytest - +import logging from copy import deepcopy -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData -from satosa.micro_services.ldap_attribute_store import LdapAttributeStore +import pytest + from satosa.context import Context +from satosa.internal import AuthenticationInformation, InternalData +from satosa.micro_services.ldap_attribute_store import LdapAttributeStore -import logging logging.basicConfig(level=logging.DEBUG) + class TestLdapAttributeStore: ldap_attribute_store_config = { - 'default': { - 'auto_bind': 'AUTO_BIND_NO_TLS', - 'client_strategy': 'MOCK_SYNC', - 'ldap_url': 'ldap://satosa.example.com', - 'bind_dn': 'uid=readonly_user,ou=system,dc=example,dc=com', - 'bind_password': 'password', - 'search_base': 'ou=people,dc=example,dc=com', - 'query_return_attributes': [ - 'givenName', - 'sn', - 'mail', - 'employeeNumber' - ], - 'ldap_to_internal_map': { - 'givenName': 'givenname', - 'sn': 'sn', - 'mail': 'mail', - 'employeeNumber': 'employeenumber' + "default": { + "auto_bind": "AUTO_BIND_NO_TLS", + "client_strategy": "MOCK_SYNC", + "ldap_url": "ldap://satosa.example.com", + "bind_dn": "uid=readonly_user,ou=system,dc=example,dc=com", + "bind_password": "password", + "search_base": "ou=people,dc=example,dc=com", + "query_return_attributes": ["givenName", "sn", "mail", "employeeNumber"], + "ldap_to_internal_map": { + "givenName": "givenname", + "sn": "sn", + "mail": "mail", + "employeeNumber": "employeenumber", }, - 'clear_input_attributes': True, - 'ordered_identifier_candidates': [ - {'attribute_names': ['uid']} - ], - 'ldap_identifier_attribute': 'uid' + "clear_input_attributes": True, + "ordered_identifier_candidates": [{"attribute_names": ["uid"]}], + "ldap_identifier_attribute": "uid", } } ldap_person_records = [ - ['employeeNumber=1000,ou=people,dc=example,dc=com', { - 'employeeNumber': '1000', - 'cn': 'Jane Baxter', - 'givenName': 'Jane', - 'sn': 'Baxter', - 'uid': 'jbaxter', - 'mail': 'jbaxter@example.com' - } - ], - ['employeeNumber=1001,ou=people,dc=example,dc=com', { - 'employeeNumber': '1001', - 'cn': 'Booker Lawson', - 'givenName': 'Booker', - 'sn': 'Lawson', - 'uid': 'booker.lawson', - 'mail': 'blawson@example.com' - } - ], + [ + "employeeNumber=1000,ou=people,dc=example,dc=com", + { + "employeeNumber": "1000", + "cn": "Jane Baxter", + "givenName": "Jane", + "sn": "Baxter", + "uid": "jbaxter", + "mail": "jbaxter@example.com", + }, + ], + [ + "employeeNumber=1001,ou=people,dc=example,dc=com", + { + "employeeNumber": "1001", + "cn": "Booker Lawson", + "givenName": "Booker", + "sn": "Lawson", + "uid": "booker.lawson", + "mail": "blawson@example.com", + }, + ], ] @pytest.fixture def ldap_attribute_store(self): - store = LdapAttributeStore(self.ldap_attribute_store_config, - name="test_ldap_attribute_store", - base_url="https://satosa.example.com") + store = LdapAttributeStore( + self.ldap_attribute_store_config, + name="test_ldap_attribute_store", + base_url="https://satosa.example.com", + ) # Mock up the 'next' microservice to be called. store.next = lambda ctx, data: data # We need to explicitly bind when using the MOCK_SYNC client strategy. - connection = store.config['default']['connection'] + connection = store.config["default"]["connection"] connection.bind() # Populate example records. @@ -81,8 +80,9 @@ def ldap_attribute_store(self): return store def test_attributes_general(self, ldap_attribute_store): - ldap_to_internal_map = (self.ldap_attribute_store_config['default'] - ['ldap_to_internal_map']) + ldap_to_internal_map = self.ldap_attribute_store_config["default"][ + "ldap_to_internal_map" + ] for dn, attributes in self.ldap_person_records: # Mock up the internal response the LDAP attribute store is @@ -91,8 +91,8 @@ def test_attributes_general(self, ldap_attribute_store): # The LDAP attribute store configuration and the mock records # expect to use a LDAP search filter for the uid attribute. - uid = attributes['uid'] - response.attributes = {'uid': uid} + uid = attributes["uid"] + response.attributes = {"uid": uid} context = Context() context.state = dict() @@ -106,4 +106,4 @@ def test_attributes_general(self, ldap_attribute_store): if ldap_attr in ldap_to_internal_map: internal_attr = ldap_to_internal_map[ldap_attr] response_attr = response.attributes[internal_attr] - assert(ldap_value in response_attr) + assert ldap_value in response_attr diff --git a/tests/satosa/scripts/test_satosa_saml_metadata.py b/tests/satosa/scripts/test_satosa_saml_metadata.py index f76f5d990..3ea13ddd9 100644 --- a/tests/satosa/scripts/test_satosa_saml_metadata.py +++ b/tests/satosa/scripts/test_satosa_saml_metadata.py @@ -19,67 +19,111 @@ def oidc_frontend_config(signing_key_path): "issuer": "https://proxy-op.example.com", "signing_key_path": signing_key_path, "provider": {"response_types_supported": ["id_token"]}, - } + }, } return data -@mongomock.patch(servers=(('localhost', 27017),)) +@mongomock.patch(servers=(("localhost", 27017),)) class TestConstructSAMLMetadata: - def test_saml_saml(self, tmpdir, cert_and_key, satosa_config_dict, saml_frontend_config, - saml_backend_config): + def test_saml_saml( + self, + tmpdir, + cert_and_key, + satosa_config_dict, + saml_frontend_config, + saml_backend_config, + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] - create_and_write_saml_metadata(satosa_config_dict, cert_and_key[1], cert_and_key[0], str(tmpdir), None) + create_and_write_saml_metadata( + satosa_config_dict, cert_and_key[1], cert_and_key[0], str(tmpdir), None + ) conf = Config() conf.cert_file = cert_and_key[0] security_ctx = security_context(conf) metadata_files = ["frontend.xml", "backend.xml"] for file in metadata_files: - md = MetaDataFile(None, os.path.join(str(tmpdir), file), security=security_ctx) + md = MetaDataFile( + None, os.path.join(str(tmpdir), file), security=security_ctx + ) assert md.load() - def test_saml_oidc(self, tmpdir, cert_and_key, satosa_config_dict, saml_frontend_config, - oidc_backend_config): + def test_saml_oidc( + self, + tmpdir, + cert_and_key, + satosa_config_dict, + saml_frontend_config, + oidc_backend_config, + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [oidc_backend_config] - create_and_write_saml_metadata(satosa_config_dict, cert_and_key[1], cert_and_key[0], str(tmpdir), None) + create_and_write_saml_metadata( + satosa_config_dict, cert_and_key[1], cert_and_key[0], str(tmpdir), None + ) conf = Config() conf.cert_file = cert_and_key[0] security_ctx = security_context(conf) - md = MetaDataFile(None, os.path.join(str(tmpdir), "frontend.xml"), security=security_ctx) + md = MetaDataFile( + None, os.path.join(str(tmpdir), "frontend.xml"), security=security_ctx + ) assert md.load() assert not os.path.isfile(os.path.join(str(tmpdir), "backend.xml")) - def test_oidc_saml(self, tmpdir, cert_and_key, satosa_config_dict, oidc_frontend_config, - saml_backend_config): + def test_oidc_saml( + self, + tmpdir, + cert_and_key, + satosa_config_dict, + oidc_frontend_config, + saml_backend_config, + ): satosa_config_dict["FRONTEND_MODULES"] = [oidc_frontend_config] satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config] - create_and_write_saml_metadata(satosa_config_dict, cert_and_key[1], cert_and_key[0], str(tmpdir), None) + create_and_write_saml_metadata( + satosa_config_dict, cert_and_key[1], cert_and_key[0], str(tmpdir), None + ) conf = Config() conf.cert_file = cert_and_key[0] security_ctx = security_context(conf) - md = MetaDataFile(None, os.path.join(str(tmpdir), "backend.xml"), security=security_ctx) + md = MetaDataFile( + None, os.path.join(str(tmpdir), "backend.xml"), security=security_ctx + ) assert md.load() assert not os.path.isfile(os.path.join(str(tmpdir), "frontend.xml")) - def test_split_frontend_metadata_to_separate_files(self, tmpdir, cert_and_key, satosa_config_dict, - saml_mirror_frontend_config, saml_backend_config, - oidc_backend_config): - + def test_split_frontend_metadata_to_separate_files( + self, + tmpdir, + cert_and_key, + satosa_config_dict, + saml_mirror_frontend_config, + saml_backend_config, + oidc_backend_config, + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_mirror_frontend_config] - satosa_config_dict["BACKEND_MODULES"] = [oidc_backend_config, saml_backend_config] - - create_and_write_saml_metadata(satosa_config_dict, cert_and_key[1], cert_and_key[0], str(tmpdir), None, - split_frontend_metadata=True) + satosa_config_dict["BACKEND_MODULES"] = [ + oidc_backend_config, + saml_backend_config, + ] + + create_and_write_saml_metadata( + satosa_config_dict, + cert_and_key[1], + cert_and_key[0], + str(tmpdir), + None, + split_frontend_metadata=True, + ) conf = Config() conf.cert_file = cert_and_key[0] @@ -92,20 +136,41 @@ def test_split_frontend_metadata_to_separate_files(self, tmpdir, cert_and_key, s md = MetaDataFile(None, file, security=security_ctx) assert md.load() - def test_split_backend_metadata_to_separate_files(self, tmpdir, cert_and_key, satosa_config_dict, - saml_frontend_config, saml_backend_config): - + def test_split_backend_metadata_to_separate_files( + self, + tmpdir, + cert_and_key, + satosa_config_dict, + saml_frontend_config, + saml_backend_config, + ): satosa_config_dict["FRONTEND_MODULES"] = [saml_frontend_config] - satosa_config_dict["BACKEND_MODULES"] = [saml_backend_config, saml_backend_config] - - create_and_write_saml_metadata(satosa_config_dict, cert_and_key[1], cert_and_key[0], str(tmpdir), None, - split_backend_metadata=True) + satosa_config_dict["BACKEND_MODULES"] = [ + saml_backend_config, + saml_backend_config, + ] + + create_and_write_saml_metadata( + satosa_config_dict, + cert_and_key[1], + cert_and_key[0], + str(tmpdir), + None, + split_backend_metadata=True, + ) conf = Config() conf.cert_file = cert_and_key[0] security_ctx = security_context(conf) - written_metadata_files = [saml_backend_config["name"], saml_backend_config["name"]] + written_metadata_files = [ + saml_backend_config["name"], + saml_backend_config["name"], + ] for file in written_metadata_files: - md = MetaDataFile(None, os.path.join(str(tmpdir), "{}_0.xml".format(file)), security=security_ctx) + md = MetaDataFile( + None, + os.path.join(str(tmpdir), "{}_0.xml".format(file)), + security=security_ctx, + ) assert md.load() diff --git a/tests/satosa/test_attribute_mapping.py b/tests/satosa/test_attribute_mapping.py index 93a3dff78..b42d7a451 100644 --- a/tests/satosa/test_attribute_mapping.py +++ b/tests/satosa/test_attribute_mapping.py @@ -9,12 +9,8 @@ class TestAttributeMapperNestedDataDifferentAttrProfile: def test_nested_mapping_nested_data_to_internal(self): mapping = { "attributes": { - "name": { - "openid": ["name"] - }, - "givenname": { - "openid": ["given_name", "name.firstName"] - }, + "name": {"openid": ["name"]}, + "givenname": {"openid": ["given_name", "name.firstName"]}, }, } @@ -31,16 +27,11 @@ def test_nested_mapping_nested_data_to_internal(self): assert internal_repr["name"] == [data["name"]] assert internal_repr["givenname"] == [data["name"]["firstName"]] - def test_nested_mapping_simple_data_to_internal(self): mapping = { "attributes": { - "name": { - "openid": ["name"] - }, - "givenname": { - "openid": ["given_name", "name.firstName"] - }, + "name": {"openid": ["name"]}, + "givenname": {"openid": ["given_name", "name.firstName"]}, }, } @@ -99,10 +90,7 @@ def test_deeply_nested_attribute_to_internal(self): def test_mapping_from_nested_attribute(self): mapping = { "attributes": { - "address": { - "openid": ["address.formatted"], - "saml": ["postaladdress"] - }, + "address": {"openid": ["address.formatted"], "saml": ["postaladdress"]}, }, } @@ -122,7 +110,7 @@ def test_mapping_from_deeply_nested_attribute(self): "attributes": { "address": { "openid": ["address.formatted.text.value"], - "saml": ["postaladdress"] + "saml": ["postaladdress"], }, }, } @@ -140,21 +128,19 @@ def test_mapping_from_deeply_nested_attribute(self): converter = AttributeMapper(mapping) internal_repr = converter.to_internal("openid", data) external_repr = converter.from_internal("saml", internal_repr) - assert external_repr["postaladdress"] == data["address"]["formatted"]["text"]["value"] + assert ( + external_repr["postaladdress"] + == data["address"]["formatted"]["text"]["value"] + ) def test_mapping_to_nested_attribute(self): mapping = { "attributes": { - "address": { - "openid": ["address.formatted"], - "saml": ["postaladdress"] - }, + "address": {"openid": ["address.formatted"], "saml": ["postaladdress"]}, }, } - data = { - "postaladdress": ["100 Universal City Plaza, Hollywood CA 91608, USA"] - } + data = {"postaladdress": ["100 Universal City Plaza, Hollywood CA 91608, USA"]} converter = AttributeMapper(mapping) internal_repr = converter.to_internal("saml", data) @@ -166,26 +152,25 @@ def test_mapping_to_deeply_nested_attribute(self): "attributes": { "address": { "openid": ["address.formatted.text.value"], - "saml": ["postaladdress"] + "saml": ["postaladdress"], }, }, } - data = { - "postaladdress": ["100 Universal City Plaza, Hollywood CA 91608, USA"] - } + data = {"postaladdress": ["100 Universal City Plaza, Hollywood CA 91608, USA"]} converter = AttributeMapper(mapping) internal_repr = converter.to_internal("saml", data) external_repr = converter.from_internal("openid", internal_repr) - assert external_repr["address"]["formatted"]["text"]["value"] == data["postaladdress"] + assert ( + external_repr["address"]["formatted"]["text"]["value"] + == data["postaladdress"] + ) def test_multiple_source_attribute_values(self): mapping = { "attributes": { - "mail": { - "saml": ["mail", "emailAddress", "email"] - }, + "mail": {"saml": ["mail", "emailAddress", "email"]}, }, } @@ -195,13 +180,17 @@ def test_multiple_source_attribute_values(self): "emailAddress": ["test3@example.com"], } - expected = Counter(["test1@example.com", "test2@example.com", "test3@example.com"]) + expected = Counter( + ["test1@example.com", "test2@example.com", "test3@example.com"] + ) converter = AttributeMapper(mapping) internal_repr = converter.to_internal("saml", data) assert Counter(internal_repr["mail"]) == expected external_repr = converter.from_internal("saml", internal_repr) - assert Counter(external_repr[mapping["attributes"]["mail"]["saml"][0]]) == expected + assert ( + Counter(external_repr[mapping["attributes"]["mail"]["saml"][0]]) == expected + ) def test_to_internal_filter(self): mapping = { @@ -246,7 +235,10 @@ def test_map_one_source_attribute_to_multiple_internal_attributes(self): converter = AttributeMapper(mapping) internal_repr = converter.to_internal("p1", {"email": ["test@example.com"]}) - assert internal_repr == {"mail": ["test@example.com"], "identifier": ["test@example.com"]} + assert internal_repr == { + "mail": ["test@example.com"], + "identifier": ["test@example.com"], + } def test_to_internal_profile_missing_attribute_mapping(self): mapping = { @@ -257,13 +249,17 @@ def test_to_internal_profile_missing_attribute_mapping(self): "id": { "foo": ["id"], "bar": ["uid"], - } + }, }, } converter = AttributeMapper(mapping) - internal_repr = converter.to_internal("bar", {"email": ["test@example.com"], "uid": ["uid"]}) - assert "mail" not in internal_repr # no mapping for the 'mail' attribute in the 'bar' profile + internal_repr = converter.to_internal( + "bar", {"email": ["test@example.com"], "uid": ["uid"]} + ) + assert ( + "mail" not in internal_repr + ) # no mapping for the 'mail' attribute in the 'bar' profile assert internal_repr["id"] == ["uid"] def test_to_internal_filter_profile_missing_attribute_mapping(self): @@ -275,13 +271,15 @@ def test_to_internal_filter_profile_missing_attribute_mapping(self): "id": { "foo": ["id"], "bar": ["uid"], - } + }, }, } converter = AttributeMapper(mapping) filter = converter.to_internal_filter("bar", ["email", "uid"]) - assert filter == ["id"] # mail should not included since its missing in 'bar' profile + assert filter == [ + "id" + ] # mail should not included since its missing in 'bar' profile def test_to_internal_with_unknown_attribute_profile(self): mapping = { @@ -325,28 +323,19 @@ def test_from_internal_with_unknown_profile(self): def test_simple_template_mapping(self): mapping = { "attributes": { - "last_name": { - "p1": ["sn"], - "p2": ["sn"] - }, - "first_name": { - "p1": ["givenName"], - "p2": ["givenName"] - }, - "name": { - "p2": ["cn"] - } - + "last_name": {"p1": ["sn"], "p2": ["sn"]}, + "first_name": {"p1": ["givenName"], "p2": ["givenName"]}, + "name": {"p2": ["cn"]}, }, "template_attributes": { - "name": { - "p2": ["${first_name[0]} ${last_name[0]}"] - } - } + "name": {"p2": ["${first_name[0]} ${last_name[0]}"]} + }, } converter = AttributeMapper(mapping) - internal_repr = converter.to_internal("p2", {"givenName": ["Valfrid"], "sn": ["Lindeman"]}) + internal_repr = converter.to_internal( + "p2", {"givenName": ["Valfrid"], "sn": ["Lindeman"]} + ) assert "name" in internal_repr assert len(internal_repr["name"]) == 1 assert internal_repr["name"][0] == "Valfrid Lindeman" @@ -356,27 +345,25 @@ def test_simple_template_mapping(self): def test_scoped_template_mapping(self): mapping = { "attributes": { - "unscoped_affiliation": { - "p1": ["eduPersonAffiliation"] - }, + "unscoped_affiliation": {"p1": ["eduPersonAffiliation"]}, "uid": { "p1": ["eduPersonPrincipalName"], }, - "affiliation": { - "p1": ["eduPersonScopedAffiliation"] - } + "affiliation": {"p1": ["eduPersonScopedAffiliation"]}, }, "template_attributes": { - "affiliation": { - "p1": ["${unscoped_affiliation[0]}@${uid[0] | scope}"] - } - } + "affiliation": {"p1": ["${unscoped_affiliation[0]}@${uid[0] | scope}"]} + }, } converter = AttributeMapper(mapping) - internal_repr = converter.to_internal("p1", { - "eduPersonAffiliation": ["student"], - "eduPersonPrincipalName": ["valfrid@lindeman.com"]}) + internal_repr = converter.to_internal( + "p1", + { + "eduPersonAffiliation": ["student"], + "eduPersonPrincipalName": ["valfrid@lindeman.com"], + }, + ) assert "affiliation" in internal_repr assert len(internal_repr["affiliation"]) == 1 assert internal_repr["affiliation"][0] == "student@lindeman.com" @@ -390,28 +377,24 @@ def test_template_attribute_overrides_existing_attribute(self): "first_name": { "p1": ["givenName"], }, - "name": { - "p1": ["cn"] - } + "name": {"p1": ["cn"]}, }, "template_attributes": { - "name": { - "p1": ["${first_name[0]} ${last_name[0]}"] - } - } + "name": {"p1": ["${first_name[0]} ${last_name[0]}"]} + }, } converter = AttributeMapper(mapping) - data = {"sn": ["Surname"], - "givenName": ["Given"], - "cn": ["Common Name"]} + data = {"sn": ["Surname"], "givenName": ["Given"], "cn": ["Common Name"]} internal_repr = converter.to_internal("p1", data) external_repr = converter.from_internal("p1", internal_repr) assert len(internal_repr["name"]) == 1 assert internal_repr["name"][0] == "Given Surname" assert external_repr["cn"][0] == "Given Surname" - def test_template_attribute_preserves_existing_attribute_if_template_cant_be_rendered(self): + def test_template_attribute_preserves_existing_attribute_if_template_cant_be_rendered( + self, + ): mapping = { "attributes": { "last_name": { @@ -420,21 +403,13 @@ def test_template_attribute_preserves_existing_attribute_if_template_cant_be_ren "first_name": { "p1": ["givenName"], }, - "name": { - "p1": ["cn"] - } + "name": {"p1": ["cn"]}, }, - "template_attributes": { - "name": { - "p1": ["${unknown[0]} ${last_name[0]}"] - } - } + "template_attributes": {"name": {"p1": ["${unknown[0]} ${last_name[0]}"]}}, } converter = AttributeMapper(mapping) - data = {"sn": ["Surname"], - "givenName": ["Given"], - "cn": ["Common Name"]} + data = {"sn": ["Surname"], "givenName": ["Given"], "cn": ["Common Name"]} internal_repr = converter.to_internal("p1", data) assert len(internal_repr["name"]) == 1 assert internal_repr["name"][0] == "Common Name" @@ -448,22 +423,26 @@ def test_template_attribute_with_multiple_templates_tries_them_all_templates(sel "first_name": { "p1": ["givenName"], }, - "name": { - "p1": ["cn"] - } + "name": {"p1": ["cn"]}, }, "template_attributes": { "name": { - "p1": ["${first_name[0]} ${last_name[0]}", "${unknown[0]} ${unknown[1]}", - "${first_name[1]} ${last_name[1]}", "${foo} ${bar}"] + "p1": [ + "${first_name[0]} ${last_name[0]}", + "${unknown[0]} ${unknown[1]}", + "${first_name[1]} ${last_name[1]}", + "${foo} ${bar}", + ] } - } + }, } converter = AttributeMapper(mapping) - data = {"sn": ["Surname1", "Surname2"], - "givenName": ["Given1", "Given2"], - "cn": ["Common Name"]} + data = { + "sn": ["Surname1", "Surname2"], + "givenName": ["Given1", "Given2"], + "cn": ["Common Name"], + } internal_repr = converter.to_internal("p1", data) assert len(internal_repr["name"]) == 2 assert internal_repr["name"][0] == "Given1 Surname1" @@ -478,26 +457,24 @@ def test_template_attribute_fail_does_not_insert_None_attribute_value(self): "first_name": { "p1": ["givenName"], }, - "name": { - "p1": ["cn"] - } + "name": {"p1": ["cn"]}, }, "template_attributes": { - "name": { - "p1": ["${first_name[0]} ${last_name[0]}"] - } - } + "name": {"p1": ["${first_name[0]} ${last_name[0]}"]} + }, } converter = AttributeMapper(mapping) internal_repr = converter.to_internal("p1", {}) assert len(internal_repr) == 0 - @pytest.mark.parametrize("attribute_value", [ - {"email": "test@example.com"}, - {"email": ["test@example.com"]} - ]) - def test_to_internal_same_attribute_value_from_list_and_single_value(self, attribute_value): + @pytest.mark.parametrize( + "attribute_value", + [{"email": "test@example.com"}, {"email": ["test@example.com"]}], + ) + def test_to_internal_same_attribute_value_from_list_and_single_value( + self, attribute_value + ): mapping = { "attributes": { "mail": { diff --git a/tests/satosa/test_base.py b/tests/satosa/test_base.py index 0f2a35f50..29ec3284f 100644 --- a/tests/satosa/test_base.py +++ b/tests/satosa/test_base.py @@ -4,8 +4,7 @@ import satosa from satosa.base import SATOSABase -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.internal import AuthenticationInformation, InternalData from satosa.satosa_config import SATOSAConfig @@ -22,38 +21,52 @@ def test_full_initialisation(self, satosa_config): assert len(base.request_micro_services) == 1 assert len(base.response_micro_services) == 1 - def test_auth_resp_callback_func_user_id_from_attrs_is_used_to_override_user_id(self, context, satosa_config): - satosa_config["INTERNAL_ATTRIBUTES"]["user_id_from_attrs"] = ["user_id", "domain"] + def test_auth_resp_callback_func_user_id_from_attrs_is_used_to_override_user_id( + self, context, satosa_config + ): + satosa_config["INTERNAL_ATTRIBUTES"]["user_id_from_attrs"] = [ + "user_id", + "domain", + ] base = SATOSABase(satosa_config) internal_resp = InternalData(auth_info=AuthenticationInformation("", "", "")) internal_resp.attributes = {"user_id": ["user"], "domain": ["@example.com"]} internal_resp.requester = "test_requester" context.state[satosa.base.STATE_KEY] = {"requester": "test_requester"} - context.state[satosa.routing.STATE_KEY] = satosa_config["FRONTEND_MODULES"][0]["name"] + context.state[satosa.routing.STATE_KEY] = satosa_config["FRONTEND_MODULES"][0][ + "name" + ] base._auth_resp_callback_func(context, internal_resp) expected_user_id = "user@example.com" assert internal_resp.subject_id == expected_user_id - def test_auth_resp_callback_func_respects_user_id_to_attr(self, context, satosa_config): + def test_auth_resp_callback_func_respects_user_id_to_attr( + self, context, satosa_config + ): satosa_config["INTERNAL_ATTRIBUTES"]["user_id_to_attr"] = "user_id" base = SATOSABase(satosa_config) internal_resp = InternalData(auth_info=AuthenticationInformation("", "", "")) internal_resp.subject_id = "user1234" context.state[satosa.base.STATE_KEY] = {"requester": "test_requester"} - context.state[satosa.routing.STATE_KEY] = satosa_config["FRONTEND_MODULES"][0]["name"] + context.state[satosa.routing.STATE_KEY] = satosa_config["FRONTEND_MODULES"][0][ + "name" + ] base._auth_resp_callback_func(context, internal_resp) assert internal_resp.attributes["user_id"] == [internal_resp.subject_id] - @pytest.mark.parametrize("micro_services", [ - [Mock()], - [Mock(), Mock()], - [Mock(), Mock(), Mock()], - ]) + @pytest.mark.parametrize( + "micro_services", + [ + [Mock()], + [Mock(), Mock()], + [Mock(), Mock(), Mock()], + ], + ) def test_link_micro_services(self, satosa_config, micro_services): base = SATOSABase(satosa_config) finish_callable = Mock() @@ -63,11 +76,10 @@ def test_link_micro_services(self, satosa_config, micro_services): assert micro_services[i].next == micro_services[i + 1].process assert micro_services[-1].next == finish_callable - @pytest.mark.parametrize("micro_services", [ - [], - None - ]) - def test_link_micro_services_with_invalid_input(self, satosa_config, micro_services): + @pytest.mark.parametrize("micro_services", [[], None]) + def test_link_micro_services_with_invalid_input( + self, satosa_config, micro_services + ): base = SATOSABase(satosa_config) finish_callable = Mock() # should not raise exception diff --git a/tests/satosa/test_plugin_loader.py b/tests/satosa/test_plugin_loader.py index ef11c961b..7eb48d3f5 100644 --- a/tests/satosa/test_plugin_loader.py +++ b/tests/satosa/test_plugin_loader.py @@ -7,7 +7,13 @@ from satosa.exception import SATOSAConfigurationError from satosa.frontends.base import FrontendModule from satosa.micro_services.base import RequestMicroService, ResponseMicroService -from satosa.plugin_loader import backend_filter, frontend_filter, _request_micro_service_filter, _response_micro_service_filter, _load_plugin_config +from satosa.plugin_loader import ( + _load_plugin_config, + _request_micro_service_filter, + _response_micro_service_filter, + backend_filter, + frontend_filter, +) class TestFilters(object): diff --git a/tests/satosa/test_response.py b/tests/satosa/test_response.py index 49836a734..32907adf2 100644 --- a/tests/satosa/test_response.py +++ b/tests/satosa/test_response.py @@ -9,10 +9,11 @@ def test_constructor_adding_content_type_header(self): headers = dict(resp.headers) assert headers["Content-Type"] == "bar" - @pytest.mark.parametrize("data, expected", [ - ("foobar", ["foobar"]), - (["foobar"], ["foobar"]) - ]) - def test_call_should_always_return_flat_list_to_comply_with_wsgi(self, data, expected): + @pytest.mark.parametrize( + "data, expected", [("foobar", ["foobar"]), (["foobar"], ["foobar"])] + ) + def test_call_should_always_return_flat_list_to_comply_with_wsgi( + self, data, expected + ): resp = Response(data) assert resp({}, lambda x, y: None) == expected diff --git a/tests/satosa/test_routing.py b/tests/satosa/test_routing.py index be23456ad..011582f21 100644 --- a/tests/satosa/test_routing.py +++ b/tests/satosa/test_routing.py @@ -2,7 +2,12 @@ from satosa.context import Context from satosa.routing import ModuleRouter, SATOSANoBoundEndpointError -from tests.util import TestBackend, TestFrontend, TestRequestMicroservice, TestResponseMicroservice +from tests.util import ( + TestBackend, + TestFrontend, + TestRequestMicroservice, + TestResponseMicroservice, +) FRONTEND_NAMES = ["Saml2IDP", "VOPaaSSaml2IDP"] BACKEND_NAMES = ["Saml2SP", "VOPaaSSaml2SP"] @@ -17,30 +22,44 @@ def create_router(self): frontends = [] for receiver in FRONTEND_NAMES: - frontends.append(TestFrontend(None, {"attributes": {}}, None, None, receiver)) + frontends.append( + TestFrontend(None, {"attributes": {}}, None, None, receiver) + ) request_micro_service_name = "RequestService" response_micro_service_name = "ResponseService" - microservices = [TestRequestMicroservice(request_micro_service_name, base_url="https://satosa.example.com"), - TestResponseMicroservice(response_micro_service_name, base_url="https://satosa.example.com")] + microservices = [ + TestRequestMicroservice( + request_micro_service_name, base_url="https://satosa.example.com" + ), + TestResponseMicroservice( + response_micro_service_name, base_url="https://satosa.example.com" + ), + ] self.router = ModuleRouter(frontends, backends, microservices) - @pytest.mark.parametrize('url_path, expected_frontend, expected_backend', [ - ("%s/%s/request" % (provider, receiver), receiver, provider) - for receiver in FRONTEND_NAMES - for provider in BACKEND_NAMES - ]) - def test_endpoint_routing_to_frontend(self, url_path, expected_frontend, expected_backend): + @pytest.mark.parametrize( + "url_path, expected_frontend, expected_backend", + [ + ("%s/%s/request" % (provider, receiver), receiver, provider) + for receiver in FRONTEND_NAMES + for provider in BACKEND_NAMES + ], + ) + def test_endpoint_routing_to_frontend( + self, url_path, expected_frontend, expected_backend + ): context = Context() context.path = url_path self.router.endpoint_routing(context) assert context.target_frontend == expected_frontend assert context.target_backend == expected_backend - @pytest.mark.parametrize('url_path, expected_backend', [ - ("%s/response" % (provider,), provider) for provider in BACKEND_NAMES - ]) + @pytest.mark.parametrize( + "url_path, expected_backend", + [("%s/response" % (provider,), provider) for provider in BACKEND_NAMES], + ) def test_endpoint_routing_to_backend(self, url_path, expected_backend): context = Context() context.path = url_path @@ -48,25 +67,36 @@ def test_endpoint_routing_to_backend(self, url_path, expected_backend): assert context.target_backend == expected_backend assert context.target_frontend is None - @pytest.mark.parametrize('url_path, expected_micro_service', [ - ("request_microservice/callback", "RequestService"), - ("response_microservice/callback", "ResponseService") - ]) + @pytest.mark.parametrize( + "url_path, expected_micro_service", + [ + ("request_microservice/callback", "RequestService"), + ("response_microservice/callback", "ResponseService"), + ], + ) def test_endpoint_routing_to_microservice(self, url_path, expected_micro_service): context = Context() context.path = url_path microservice_callable = self.router.endpoint_routing(context) assert context.target_micro_service == expected_micro_service - assert microservice_callable == self.router.micro_services[expected_micro_service]["instance"].callback + assert ( + microservice_callable + == self.router.micro_services[expected_micro_service]["instance"].callback + ) assert context.target_backend is None assert context.target_frontend is None - @pytest.mark.parametrize('url_path, expected_frontend, expected_backend', [ - ("%s/%s/request" % (provider, receiver), receiver, provider) - for receiver in FRONTEND_NAMES - for provider in BACKEND_NAMES - ]) - def test_module_routing(self, url_path, expected_frontend, expected_backend, context): + @pytest.mark.parametrize( + "url_path, expected_frontend, expected_backend", + [ + ("%s/%s/request" % (provider, receiver), receiver, provider) + for receiver in FRONTEND_NAMES + for provider in BACKEND_NAMES + ], + ) + def test_module_routing( + self, url_path, expected_frontend, expected_backend, context + ): context.path = url_path self.router.endpoint_routing(context) @@ -83,10 +113,13 @@ def test_endpoint_routing_with_unknown_endpoint(self, context): with pytest.raises(SATOSANoBoundEndpointError): self.router.endpoint_routing(context) - @pytest.mark.parametrize(("frontends", "backends", "micro_services"), [ - (None, None, {}), - ({}, {}, {}), - ]) + @pytest.mark.parametrize( + ("frontends", "backends", "micro_services"), + [ + (None, None, {}), + ({}, {}, {}), + ], + ) def test_bad_init(self, frontends, backends, micro_services): with pytest.raises(ValueError): ModuleRouter(frontends, backends, micro_services) diff --git a/tests/satosa/test_satosa_config.py b/tests/satosa/test_satosa_config.py index fd5045a93..e61258449 100644 --- a/tests/satosa/test_satosa_config.py +++ b/tests/satosa/test_satosa_config.py @@ -9,6 +9,7 @@ TEST_RESOURCE_BASE_PATH = os.path.join(os.path.dirname(__file__), "../test_resources") + class TestSATOSAConfig: @pytest.fixture def non_sensitive_config_dict(self): @@ -18,31 +19,35 @@ def non_sensitive_config_dict(self): "COOKIE_STATE_NAME": "TEST_STATE", "BACKEND_MODULES": [], "FRONTEND_MODULES": [], - "INTERNAL_ATTRIBUTES": {"attributes": {}} + "INTERNAL_ATTRIBUTES": {"attributes": {}}, } return config - def test_read_senstive_config_data_from_env_var(self, monkeypatch, non_sensitive_config_dict): + def test_read_senstive_config_data_from_env_var( + self, monkeypatch, non_sensitive_config_dict + ): monkeypatch.setenv("SATOSA_STATE_ENCRYPTION_KEY", "state_encryption_key") config = SATOSAConfig(non_sensitive_config_dict) assert config["STATE_ENCRYPTION_KEY"] == "state_encryption_key" - def test_senstive_config_data_from_env_var_overrides_config(self, monkeypatch, non_sensitive_config_dict): + def test_senstive_config_data_from_env_var_overrides_config( + self, monkeypatch, non_sensitive_config_dict + ): non_sensitive_config_dict["STATE_ENCRYPTION_KEY"] = "bar" monkeypatch.setenv("SATOSA_STATE_ENCRYPTION_KEY", "state_encryption_key") config = SATOSAConfig(non_sensitive_config_dict) assert config["STATE_ENCRYPTION_KEY"] == "state_encryption_key" - def test_constructor_should_raise_exception_if_sensitive_keys_are_missing(self, non_sensitive_config_dict): + def test_constructor_should_raise_exception_if_sensitive_keys_are_missing( + self, non_sensitive_config_dict + ): with pytest.raises(SATOSAConfigurationError): SATOSAConfig(non_sensitive_config_dict) - @pytest.mark.parametrize("modules_key", [ - "BACKEND_MODULES", - "FRONTEND_MODULES", - "MICRO_SERVICES" - ]) + @pytest.mark.parametrize( + "modules_key", ["BACKEND_MODULES", "FRONTEND_MODULES", "MICRO_SERVICES"] + ) def test_can_read_endpoint_configs_from_dict(self, satosa_config_dict, modules_key): expected_config = [{"foo": "bar"}, {"abc": "xyz"}] satosa_config_dict[modules_key] = expected_config @@ -50,11 +55,9 @@ def test_can_read_endpoint_configs_from_dict(self, satosa_config_dict, modules_k config = SATOSAConfig(satosa_config_dict) assert config[modules_key] == expected_config - @pytest.mark.parametrize("modules_key", [ - "BACKEND_MODULES", - "FRONTEND_MODULES", - "MICRO_SERVICES" - ]) + @pytest.mark.parametrize( + "modules_key", ["BACKEND_MODULES", "FRONTEND_MODULES", "MICRO_SERVICES"] + ) def test_can_read_endpoint_configs_from_file(self, satosa_config_dict, modules_key): satosa_config_dict[modules_key] = ["/fake_file_path"] expected_config = {"foo": "bar"} @@ -73,10 +76,10 @@ def test_can_substitute_from_environment_variable(self, monkeypatch): os.path.join(TEST_RESOURCE_BASE_PATH, "proxy_conf_environment_test.yaml") ) - assert config["COOKIE_STATE_NAME"] == 'oatmeal_raisin' + assert config["COOKIE_STATE_NAME"] == "oatmeal_raisin" def test_can_substitute_from_environment_variable_file(self, monkeypatch): - cookie_file = os.path.join(TEST_RESOURCE_BASE_PATH, 'cookie_state_name') + cookie_file = os.path.join(TEST_RESOURCE_BASE_PATH, "cookie_state_name") monkeypatch.setenv("SATOSA_COOKIE_STATE_NAME_FILE", cookie_file) config = SATOSAConfig( os.path.join( @@ -84,4 +87,4 @@ def test_can_substitute_from_environment_variable_file(self, monkeypatch): ) ) - assert config["COOKIE_STATE_NAME"] == 'chocolate_chip' + assert config["COOKIE_STATE_NAME"] == "chocolate_chip" diff --git a/tests/satosa/test_state.py b/tests/satosa/test_state.py index eadee2182..5ddefc361 100644 --- a/tests/satosa/test_state.py +++ b/tests/satosa/test_state.py @@ -8,7 +8,7 @@ import pytest -from satosa.state import State, state_to_cookie, cookie_to_state, SATOSAStateError +from satosa.state import SATOSAStateError, State, cookie_to_state, state_to_cookie def get_dict(size, key_prefix, value_preix): @@ -43,7 +43,7 @@ def get_str(length): :param length: The length of the string. :return: A string with the assigned length. """ - return ''.join(random.choice(string.ascii_lowercase) for x in range(length)) + return "".join(random.choice(string.ascii_lowercase) for x in range(length)) class TestState: @@ -100,13 +100,17 @@ def test_encode_decode_of_state(self): path = "/" encrypt_key = "2781y4hef90" - cookie = state_to_cookie(state, name=cookie_name, path=path, encryption_key=encrypt_key) + cookie = state_to_cookie( + state, name=cookie_name, path=path, encryption_key=encrypt_key + ) cookie_str = cookie[cookie_name].OutputString() loaded_state = cookie_to_state(cookie_str, cookie_name, encrypt_key) assert loaded_state[state_key] == saved_data - def test_state_to_cookie_produces_cookie_without_max_age_for_state_that_should_be_deleted(self): + def test_state_to_cookie_produces_cookie_without_max_age_for_state_that_should_be_deleted( + self, + ): state_key = "27614gjkrn" saved_data = "data" state = State() @@ -117,34 +121,41 @@ def test_state_to_cookie_produces_cookie_without_max_age_for_state_that_should_b path = "/" encrypt_key = "2781y4hef90" - cookie = state_to_cookie(state, name=cookie_name, path=path, encryption_key=encrypt_key) + cookie = state_to_cookie( + state, name=cookie_name, path=path, encryption_key=encrypt_key + ) cookie_str = cookie[cookie_name].OutputString() parsed_cookie = SimpleCookie(cookie_str) assert not parsed_cookie[cookie_name].value - assert parsed_cookie[cookie_name]["max-age"] == '0' + assert parsed_cookie[cookie_name]["max-age"] == "0" - @pytest.mark.parametrize("cookie_str, name, encryption_key, expected_exception", [ - ( # Test wrong encryption_key + @pytest.mark.parametrize( + "cookie_str, name, encryption_key, expected_exception", + [ + ( # Test wrong encryption_key 'Set-Cookie: state_cookie="_Td6WFoAAATm1rRGAgAhARYAAAB0L-WjAQCXYWt4NU9ZLWF5amdVVDdSUjhWdnkyUHE5MFhJV0J4Uzg5di1EVW1nNTR0WHZKakFsaWJmN2JMOUtlNEltMkJ0dmxOakRyUDJXZE53d0dwSGNqYnBzVng5YjVVeUYyUzkwcWVSMU42U2VNNHZDQTktUXdCQWx0WUh6LVBPX1pBYnZ1M1RsV09Qc2lKS3VpelB5a0FsMG93PT0AmlSCX0Pk2WoAAbABmAEAAGRNyZ2xxGf7AgAAAAAEWVo="; Max-Age=600; Path=/; Secure', "state_cookie", "wrong_encrypt_key", Exception, - ), - ( # Test wrong cookie_name + ), + ( # Test wrong cookie_name 'Set-Cookie: state_cookie="_Td6WFoAAATm1rRGAgAhARYAAAB0L-WjAQCXYWt4NU9ZLWF5amdVVDdSUjhWdnkyUHE5MFhJV0J4Uzg5di1EVW1nNTR0WHZKakFsaWJmN2JMOUtlNEltMkJ0dmxOakRyUDJXZE53d0dwSGNqYnBzVng5YjVVeUYyUzkwcWVSMU42U2VNNHZDQTktUXdCQWx0WUh6LVBPX1pBYnZ1M1RsV09Qc2lKS3VpelB5a0FsMG93PT0AmlSCX0Pk2WoAAbABmAEAAGRNyZ2xxGf7AgAAAAAEWVo="; Max-Age=600; Path=/; Secure', "wrong_name", "2781y4hef90", SATOSAStateError, - ), - ( # Test bad cookie str - 'not_a_cookie', + ), + ( # Test bad cookie str + "not_a_cookie", "state_cookie", "2781y4hef90", SATOSAStateError, - ), - ]) - def test_cookie_to_state_handle_broken_cookies(self, cookie_str, name, encryption_key, expected_exception): + ), + ], + ) + def test_cookie_to_state_handle_broken_cookies( + self, cookie_str, name, encryption_key, expected_exception + ): """ Test that the cookie_to_state raises exception if the input is bad """ diff --git a/tests/users.py b/tests/users.py index 22e8e309a..c018c2d39 100644 --- a/tests/users.py +++ b/tests/users.py @@ -20,7 +20,7 @@ "schacHomeOrganization": ["example.com"], "email": ["test@example.com"], "displayName": ["Test Testsson"], - "norEduPersonNIN": ["SE199012315555"] + "norEduPersonNIN": ["SE199012315555"], } } diff --git a/tests/util.py b/tests/util.py index c26c796fe..dd81afa51 100644 --- a/tests/util.py +++ b/tests/util.py @@ -6,21 +6,24 @@ from datetime import datetime from urllib.parse import parse_qsl, urlparse -from Cryptodome.PublicKey import RSA from bs4 import BeautifulSoup -from saml2 import server, BINDING_HTTP_POST, BINDING_HTTP_REDIRECT -from saml2.authn_context import AuthnBroker, authn_context_class_ref, PASSWORD +from Cryptodome.PublicKey import RSA +from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT, server +from saml2.authn_context import PASSWORD, AuthnBroker, authn_context_class_ref from saml2.cert import OpenSSLWrapper from saml2.client import Saml2Client from saml2.config import Config from saml2.metadata import entity_descriptor -from saml2.saml import name_id_from_string, NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_PERSISTENT +from saml2.saml import ( + NAMEID_FORMAT_PERSISTENT, + NAMEID_FORMAT_TRANSIENT, + name_id_from_string, +) from saml2.samlp import NameIDPolicy from satosa.backends.base import BackendModule from satosa.frontends.base import FrontendModule -from satosa.internal import AuthenticationInformation -from satosa.internal import InternalData +from satosa.internal import AuthenticationInformation, InternalData from satosa.micro_services.base import RequestMicroService, ResponseMicroService from satosa.response import Response @@ -37,9 +40,15 @@ def __init__(self, config): """ Saml2Client.__init__(self, config) - def make_auth_req(self, entity_id, nameid_format=None, relay_state="relay_state", - request_binding=BINDING_HTTP_REDIRECT, response_binding=BINDING_HTTP_REDIRECT, - subject=None): + def make_auth_req( + self, + entity_id, + nameid_format=None, + relay_state="relay_state", + request_binding=BINDING_HTTP_REDIRECT, + response_binding=BINDING_HTTP_REDIRECT, + subject=None, + ): """ :type entity_id: str :rtype: str @@ -49,23 +58,20 @@ def make_auth_req(self, entity_id, nameid_format=None, relay_state="relay_state" """ # Picks a binding to use for sending the Request to the IDP _binding, destination = self.pick_binding( - 'single_sign_on_service', - [request_binding], 'idpsso', - entity_id=entity_id) + "single_sign_on_service", [request_binding], "idpsso", entity_id=entity_id + ) kwargs = {} if subject: - kwargs['subject'] = subject + kwargs["subject"] = subject req_id, req = self.create_authn_request( - destination, - binding=response_binding, - nameid_format=nameid_format, - **kwargs + destination, binding=response_binding, nameid_format=nameid_format, **kwargs ) - ht_args = self.apply_binding(_binding, '%s' % req, destination, - relay_state=relay_state) + ht_args = self.apply_binding( + _binding, "%s" % req, destination, relay_state=relay_state + ) if _binding == BINDING_HTTP_POST: form_post_html = "\n".join(ht_args["data"]) @@ -94,8 +100,14 @@ def __init__(self, user_db, config): server.Server.__init__(self, config=config) self.user_db = user_db - def __create_authn_response(self, saml_request, relay_state, binding, - userid, response_binding=BINDING_HTTP_POST): + def __create_authn_response( + self, + saml_request, + relay_state, + binding, + userid, + response_binding=BINDING_HTTP_POST, + ): """ Handles a SAML request, validates and creates a SAML response but does not apply the binding to encode it. @@ -116,28 +128,29 @@ def __create_authn_response(self, saml_request, relay_state, binding, """ auth_req = self.parse_authn_request(saml_request, binding) binding_out, destination = self.pick_binding( - 'assertion_consumer_service', + "assertion_consumer_service", bindings=[response_binding], - entity_id=auth_req.message.issuer.text, request=auth_req.message) + entity_id=auth_req.message.issuer.text, + request=auth_req.message, + ) resp_args = self.response_args(auth_req.message) authn_broker = AuthnBroker() - authn_broker.add(authn_context_class_ref(PASSWORD), lambda: None, 10, - 'unittest_idp.xml') + authn_broker.add( + authn_context_class_ref(PASSWORD), lambda: None, 10, "unittest_idp.xml" + ) authn_broker.get_authn_by_accr(PASSWORD) - resp_args['authn'] = authn_broker.get_authn_by_accr(PASSWORD) + resp_args["authn"] = authn_broker.get_authn_by_accr(PASSWORD) - resp = self.create_authn_response(self.user_db[userid], - userid=userid, - **resp_args) + resp = self.create_authn_response( + self.user_db[userid], userid=userid, **resp_args + ) return destination, resp - def __apply_binding_to_authn_response(self, - resp, - response_binding, - relay_state, - destination): + def __apply_binding_to_authn_response( + self, resp, response_binding, relay_state, destination + ): """ Applies the binding to the response. """ @@ -146,18 +159,22 @@ def __apply_binding_to_authn_response(self, resp = {"SAMLResponse": saml_response, "RelayState": relay_state} elif response_binding == BINDING_HTTP_REDIRECT: http_args = self.apply_binding( - response_binding, - '%s' % resp, - destination, - relay_state, - response=True) - resp = dict(parse_qsl(urlparse( - dict(http_args["headers"])["Location"]).query)) + response_binding, "%s" % resp, destination, relay_state, response=True + ) + resp = dict( + parse_qsl(urlparse(dict(http_args["headers"])["Location"]).query) + ) return resp - def handle_auth_req(self, saml_request, relay_state, binding, userid, - response_binding=BINDING_HTTP_POST): + def handle_auth_req( + self, + saml_request, + relay_state, + binding, + userid, + response_binding=BINDING_HTTP_POST, + ): """ Handles a SAML request, validates and creates a SAML response. :type saml_request: str @@ -176,22 +193,23 @@ def handle_auth_req(self, saml_request, relay_state, binding, userid, """ destination, _resp = self.__create_authn_response( - saml_request, - relay_state, - binding, - userid, - response_binding) + saml_request, relay_state, binding, userid, response_binding + ) resp = self.__apply_binding_to_authn_response( - _resp, - response_binding, - relay_state, - destination) + _resp, response_binding, relay_state, destination + ) return destination, resp - def handle_auth_req_no_name_id(self, saml_request, relay_state, binding, - userid, response_binding=BINDING_HTTP_POST): + def handle_auth_req_no_name_id( + self, + saml_request, + relay_state, + binding, + userid, + response_binding=BINDING_HTTP_POST, + ): """ Handles a SAML request, validates and creates a SAML response but without a element. @@ -211,20 +229,15 @@ def handle_auth_req_no_name_id(self, saml_request, relay_state, binding, """ destination, _resp = self.__create_authn_response( - saml_request, - relay_state, - binding, - userid, - response_binding) + saml_request, relay_state, binding, userid, response_binding + ) # Remove the element from the response. _resp.assertion.subject.name_id = None resp = self.__apply_binding_to_authn_response( - _resp, - response_binding, - relay_state, - destination) + _resp, response_binding, relay_state, destination + ) return destination, resp @@ -242,7 +255,7 @@ def generate_cert(): "state": "ac", "city": "Umea", "organization": "ITS", - "organization_unit": "DIRG" + "organization_unit": "DIRG", } osw = OpenSSLWrapper() cert_str, key_str = osw.create_certificate(cert_info, request=False) @@ -261,11 +274,12 @@ class FileGenerator(object): """ Creates different types of temporary files that is useful for testing. """ + _instance = None def __init__(self): if FileGenerator._instance: - raise TypeError('Singletons must be accessed through `get_instance()`.') + raise TypeError("Singletons must be accessed through `get_instance()`.") else: FileGenerator._instance = self self.generate_certs = {} @@ -331,9 +345,9 @@ def create_metadata(self, config, code=None): def private_to_public_key(pk_file): - f = open(pk_file, 'r') + f = open(pk_file, "r") pk = RSA.importKey(f.read()) - return pk.publickey().exportKey('PEM') + return pk.publickey().exportKey("PEM") def create_name_id(): @@ -373,9 +387,14 @@ def create_name_id_policy_persistent(): class FakeBackend(BackendModule): - def __init__(self, start_auth_func=None, internal_attributes=None, - base_url="", name="FakeBackend", - register_endpoints_func=None): + def __init__( + self, + start_auth_func=None, + internal_attributes=None, + base_url="", + name="FakeBackend", + register_endpoints_func=None, + ): super().__init__(None, internal_attributes, base_url, name) self.start_auth_func = start_auth_func @@ -410,10 +429,15 @@ class FakeFrontend(FrontendModule): TODO comment """ - def __init__(self, handle_authn_request_func=None, internal_attributes=None, - base_url="", name="FakeFrontend", - handle_authn_response_func=None, - register_endpoints_func=None): + def __init__( + self, + handle_authn_request_func=None, + internal_attributes=None, + base_url="", + name="FakeFrontend", + handle_authn_response_func=None, + register_endpoints_func=None, + ): super().__init__(None, internal_attributes, base_url, name) self.handle_authn_request_func = handle_authn_request_func self.handle_authn_response_func = handle_authn_response_func @@ -468,7 +492,9 @@ def start_auth(self, context, internal_request): return Response("Auth request received, passed to test backend") def handle_response(self, context): - auth_info = AuthenticationInformation("test", str(datetime.now()), "test_issuer") + auth_info = AuthenticationInformation( + "test", str(datetime.now()), "test_issuer" + ) internal_resp = InternalData(auth_info=auth_info) internal_resp.attributes = context.request internal_resp.subject_id = "test_user" @@ -478,11 +504,16 @@ def handle_response(self, context): class TestFrontend(FrontendModule): __test__ = False - def __init__(self, auth_req_callback_func, internal_attributes, config, base_url, name): + def __init__( + self, auth_req_callback_func, internal_attributes, config, base_url, name + ): super().__init__(auth_req_callback_func, internal_attributes, base_url, name) def register_endpoints(self, backend_names): - url_map = [("^{}/{}/request$".format(p, self.name), self.handle_request) for p in backend_names] + url_map = [ + ("^{}/{}/request$".format(p, self.name), self.handle_request) + for p in backend_names + ] return url_map def handle_request(self, context):