Skip to content

Commit

Permalink
refactor: sort imports and apply the Black style
Browse files Browse the repository at this point in the history
  • Loading branch information
niheconomoum committed Jan 25, 2024
1 parent bcef9c7 commit 24e8ff3
Show file tree
Hide file tree
Showing 89 changed files with 4,599 additions and 2,810 deletions.
26 changes: 14 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
setup.py
"""

from setuptools import setup, find_packages
from setuptools import find_packages, setup

setup(
name='SATOSA',
version='8.4.0',
description='Protocol proxy (SAML/OIDC).',
author='DIRG',
author_email='[email protected]',
license='Apache 2.0',
url='https://github.com/SUNET/SATOSA',
packages=find_packages('src/'),
package_dir={'': 'src'},
name="SATOSA",
version="8.4.0",
description="Protocol proxy (SAML/OIDC).",
author="DIRG",
author_email="[email protected]",
license="Apache 2.0",
url="https://github.com/SUNET/SATOSA",
packages=find_packages("src/"),
package_dir={"": "src"},
install_requires=[
"pyop >= v3.4.0",
"pysaml2 >= 6.5.1",
Expand Down Expand Up @@ -44,6 +44,8 @@
"Programming Language :: Python :: 3.11",
],
entry_points={
"console_scripts": ["satosa-saml-metadata=satosa.scripts.satosa_saml_metadata:construct_saml_metadata"]
}
"console_scripts": [
"satosa-saml-metadata=satosa.scripts.satosa_saml_metadata:construct_saml_metadata"
]
},
)
89 changes: 65 additions & 24 deletions src/satosa/attribute_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def scope(s):
:param s: string to extract scope from (filtered string in mako template)
:return: the scope
"""
if '@' not in s:
if "@" not in s:
raise ValueError("Unscoped string")
(local_part, _, domain_part) = s.partition('@')
(local_part, _, domain_part) = s.partition("@")
return domain_part


Expand All @@ -31,16 +31,22 @@ def __init__(self, internal_attributes):
:param internal_attributes: A map of how to convert the attributes
(dict[internal_name, dict[attribute_profile, external_name]])
"""
self.separator = "." # separator for nested attribute values, e.g. address.street_address
self.multivalue_separator = ";" # separates multiple values, e.g. when using templates
self.separator = (
"." # separator for nested attribute values, e.g. address.street_address
)
self.multivalue_separator = (
";" # separates multiple values, e.g. when using templates
)
self.from_internal_attributes = internal_attributes["attributes"]
self.template_attributes = internal_attributes.get("template_attributes", None)

self.to_internal_attributes = defaultdict(dict)
for internal_attribute_name, mappings in self.from_internal_attributes.items():
for profile, external_attribute_names in mappings.items():
for external_attribute_name in external_attribute_names:
self.to_internal_attributes[profile][external_attribute_name] = internal_attribute_name
self.to_internal_attributes[profile][
external_attribute_name
] = internal_attribute_name

def to_internal_filter(self, attribute_profile, external_attribute_names):
"""
Expand All @@ -59,7 +65,11 @@ def to_internal_filter(self, attribute_profile, external_attribute_names):
try:
profile_mapping = self.to_internal_attributes[attribute_profile]
except KeyError:
logline = "no attribute mapping found for the given attribute profile {}".format(attribute_profile)
logline = (
"no attribute mapping found for the given attribute profile {}".format(
attribute_profile
)
)
logger.warn(logline)
# no attributes since the given profile is not configured
return []
Expand Down Expand Up @@ -103,7 +113,9 @@ def to_internal(self, attribute_profile, external_dict):
)
if attribute_values: # Only insert key if it has some values
logline = "backend attribute {external} mapped to {internal} ({value})".format(
external=external_attribute_name, internal=internal_attribute_name, value=attribute_values
external=external_attribute_name,
internal=internal_attribute_name,
value=attribute_values,
)
logger.debug(logline)
internal_dict[internal_attribute_name] = attribute_values
Expand All @@ -112,7 +124,9 @@ def to_internal(self, attribute_profile, external_dict):
external_attribute_name
)
logger.debug(logline)
internal_dict = self._handle_template_attributes(attribute_profile, internal_dict)
internal_dict = self._handle_template_attributes(
attribute_profile, internal_dict
)
return internal_dict

def _collate_attribute_values_by_priority_order(self, attribute_names, data):
Expand All @@ -128,7 +142,11 @@ def _collate_attribute_values_by_priority_order(self, attribute_names, data):
return result

def _render_attribute_template(self, template, data):
t = Template(template, cache_enabled=True, imports=["from satosa.attribute_mapping import scope"])
t = Template(
template,
cache_enabled=True,
imports=["from satosa.attribute_mapping import scope"],
)
try:
return t.render(**data).split(self.multivalue_separator)
except (NameError, TypeError):
Expand All @@ -144,11 +162,19 @@ def _handle_template_attributes(self, attribute_profile, internal_dict):
continue

external_attribute_name = mapping[attribute_profile]
templates = [t for t in external_attribute_name if "$" in t] # these looks like templates...
template_attribute_values = [self._render_attribute_template(template, internal_dict) for template in
templates]
flattened_attribute_values = list(chain.from_iterable(template_attribute_values))
attribute_values = flattened_attribute_values or internal_dict.get(internal_attribute_name, None)
templates = [
t for t in external_attribute_name if "$" in t
] # these looks like templates...
template_attribute_values = [
self._render_attribute_template(template, internal_dict)
for template in templates
]
flattened_attribute_values = list(
chain.from_iterable(template_attribute_values)
)
attribute_values = flattened_attribute_values or internal_dict.get(
internal_attribute_name, None
)
if attribute_values: # only insert key if it has some values
internal_dict[internal_attribute_name] = attribute_values

Expand All @@ -172,7 +198,9 @@ def _create_nested_attribute_value(self, nested_attribute_names, value):
return {nested_attribute_names[0]: value}

# keep digging further into the nested attribute names
child_dict = self._create_nested_attribute_value(nested_attribute_names[1:], value)
child_dict = self._create_nested_attribute_value(
nested_attribute_names[1:], value
)
return {nested_attribute_names[0]: child_dict}

def from_internal(self, attribute_profile, internal_dict):
Expand All @@ -190,10 +218,14 @@ def from_internal(self, attribute_profile, internal_dict):
external_dict = {}
for internal_attribute_name in internal_dict:
try:
attribute_mapping = self.from_internal_attributes[internal_attribute_name]
except KeyError:
logline = "no attribute mapping found for the internal attribute {}".format(
attribute_mapping = self.from_internal_attributes[
internal_attribute_name
]
except KeyError:
logline = (
"no attribute mapping found for the internal attribute {}".format(
internal_attribute_name
)
)
logger.debug(logline)
continue
Expand All @@ -206,20 +238,29 @@ def from_internal(self, attribute_profile, internal_dict):
logger.debug(logline)
continue

external_attribute_names = self.from_internal_attributes[internal_attribute_name][attribute_profile]
external_attribute_names = self.from_internal_attributes[
internal_attribute_name
][attribute_profile]
# select the first attribute name
external_attribute_name = external_attribute_names[0]
logline = "frontend attribute {external} mapped from {internal} ({value})".format(
external=external_attribute_name, internal=internal_attribute_name, value=internal_dict[internal_attribute_name]
logline = (
"frontend attribute {external} mapped from {internal} ({value})".format(
external=external_attribute_name,
internal=internal_attribute_name,
value=internal_dict[internal_attribute_name],
)
)
logger.debug(logline)

if self.separator in external_attribute_name:
nested_attribute_names = external_attribute_name.split(self.separator)
nested_dict = self._create_nested_attribute_value(nested_attribute_names[1:],
internal_dict[internal_attribute_name])
nested_dict = self._create_nested_attribute_value(
nested_attribute_names[1:], internal_dict[internal_attribute_name]
)
external_dict[nested_attribute_names[0]] = nested_dict
else:
external_dict[external_attribute_name] = internal_dict[internal_attribute_name]
external_dict[external_attribute_name] = internal_dict[
internal_attribute_name
]

return external_dict
10 changes: 6 additions & 4 deletions src/satosa/backends/apple.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""
Apple backend module.
"""
import json
import logging
from .openid_connect import OpenIDConnectBackend, STATE_KEY

import requests
from oic.oauth2.message import Message
from oic.oic.message import AuthorizationResponse

import satosa.logging_util as lu
from ..exception import SATOSAAuthenticationError
import json
import requests

from ..exception import SATOSAAuthenticationError
from .openid_connect import STATE_KEY, OpenIDConnectBackend

logger = logging.getLogger(__name__)

Expand Down
61 changes: 38 additions & 23 deletions src/satosa/backends/bitbucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
"""
import json
import logging
import requests

from oic.utils.authn.authn_context import UNSPECIFIED
import requests
from oic.oauth2.consumer import stateID
from oic.utils.authn.authn_context import UNSPECIFIED

from satosa.backends.oauth import _OAuthBackend
from satosa.internal import AuthenticationInformation
Expand Down Expand Up @@ -37,10 +37,17 @@ def __init__(self, outgoing, internal_attributes, config, base_url, name):
:type base_url: str
:type name: str
"""
config.setdefault('response_type', 'code')
config['verify_accesstoken_state'] = False
super().__init__(outgoing, internal_attributes, config, base_url,
name, 'bitbucket', 'account_id')
config.setdefault("response_type", "code")
config["verify_accesstoken_state"] = False
super().__init__(
outgoing,
internal_attributes,
config,
base_url,
name,
"bitbucket",
"account_id",
)

def get_request_args(self, get_state=stateID):
request_args = super().get_request_args(get_state=get_state)
Expand All @@ -58,28 +65,36 @@ def get_request_args(self, get_state=stateID):

def auth_info(self, request):
return AuthenticationInformation(
UNSPECIFIED, None,
self.config['server_info']['authorization_endpoint'])
UNSPECIFIED, None, self.config["server_info"]["authorization_endpoint"]
)

def user_information(self, access_token):
url = self.config['server_info']['user_endpoint']
url = self.config["server_info"]["user_endpoint"]
email_url = "{}/emails".format(url)
headers = {'Authorization': 'Bearer {}'.format(access_token)}
headers = {"Authorization": "Bearer {}".format(access_token)}
resp = requests.get(url, headers=headers)
data = json.loads(resp.text)
if 'email' in self.config['scope']:
if "email" in self.config["scope"]:
resp = requests.get(email_url, headers=headers)
emails = json.loads(resp.text)
data.update({
'email': [e for e in [d.get('email')
for d in emails.get('values')
if d.get('is_primary')
]
],
'email_confirmed': [e for e in [d.get('email')
for d in emails.get('values')
if d.get('is_confirmed')
]
]
})
data.update(
{
"email": [
e
for e in [
d.get("email")
for d in emails.get("values")
if d.get("is_primary")
]
],
"email_confirmed": [
e
for e in [
d.get("email")
for d in emails.get("values")
if d.get("is_confirmed")
]
],
}
)
return data
Loading

0 comments on commit 24e8ff3

Please sign in to comment.