Skip to content

Commit

Permalink
feat: add middleware for checking allowed data fields
Browse files Browse the repository at this point in the history
When calling profile any kind there should be checked to which fields
the service has access rights by using the allowed data fields.
This adds middleware and mixin class for Profile model and
VerifiedPersonalInfo models for checking that the queried fields are
allowed for the service.

Refs HP-2319
  • Loading branch information
nicobav committed May 23, 2024
1 parent da711f1 commit ddda974
Show file tree
Hide file tree
Showing 20 changed files with 724 additions and 94 deletions.
19 changes: 19 additions & 0 deletions open_city_profile/graphene.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from django.conf import settings
from django.forms import MultipleChoiceField
from django_filters import MultipleChoiceFilter
from graphene.utils.str_converters import to_snake_case
from graphene_django import DjangoObjectType
from graphene_django.forms.converter import convert_form_field
from graphene_django.types import ALL_FIELDS
from graphql_sync_dataloaders import SyncDataLoader
from parler.models import TranslatableModel

from open_city_profile.exceptions import FieldNotAllowedError, ServiceNotIdentifiedError
from profiles.loaders import (
addresses_by_profile_id_loader,
emails_by_profile_id_loader,
Expand Down Expand Up @@ -178,3 +180,20 @@ def __init_subclass_with_meta__(
_meta=_meta,
**options,
)


class AllowedDataFieldsMiddleware:

def resolve(self, next, root, info, **kwargs):
if getattr(root, "check_allowed_data_fields", False):
field_name = to_snake_case(getattr(info, "field_name", ""))

if not getattr(info.context, "service", False):
raise ServiceNotIdentifiedError("Service not identified")

if not root.is_field_allowed_for_service(field_name, info.context.service):
raise FieldNotAllowedError(
"Field is not allowed for service.", field_name=field_name
)

return next(root, info, **kwargs)
1 change: 1 addition & 0 deletions open_city_profile/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@
"SCHEMA": "open_city_profile.schema.schema",
"MIDDLEWARE": [
# NOTE: Graphene runs its middlewares in reverse order!
"open_city_profile.graphene.AllowedDataFieldsMiddleware",
"open_city_profile.graphene.JWTMiddleware",
"open_city_profile.graphene.GQLDataLoaders",
],
Expand Down
15 changes: 13 additions & 2 deletions open_city_profile/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
UserFactory,
)
from open_city_profile.views import GraphQLView
from services.tests.factories import ServiceFactory
from services.models import Service
from services.tests.factories import AllowedDataFieldFactory, ServiceFactory

_not_provided = object()

Expand All @@ -34,6 +35,7 @@ def execute(
auth_token_payload=None,
service=_not_provided,
context=None,
allowed_data_fields: list[str] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -63,9 +65,18 @@ def execute(
context.service = None

if service is _not_provided:
context.service = ServiceFactory(name="profile", is_profile_service=True)
service = Service.objects.filter(is_profile_service=True).first()
if not service:
service = ServiceFactory(name="profile", is_profile_service=True)

context.service = service
elif service:
context.service = service
if allowed_data_fields:
for field_name in allowed_data_fields:
context.service.allowed_data_fields.add(
AllowedDataFieldFactory(field_name=field_name)
)

return super().execute(
*args,
Expand Down
20 changes: 19 additions & 1 deletion open_city_profile/tests/graphql_test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
from django.utils.crypto import get_random_string
from jose import jwt

from services.tests.factories import ServiceClientIdFactory
from services.tests.factories import (
AllowedDataFieldFactory,
ServiceClientIdFactory,
ServiceConnectionFactory,
)

from .conftest import get_unix_timestamp_now
from .keys import rsa_key

AUDIENCE = getattr(settings, "OIDC_API_TOKEN_AUTH")["AUDIENCE"]
ISSUER = getattr(settings, "OIDC_API_TOKEN_AUTH")["ISSUER"]
if isinstance(ISSUER, list):
ISSUER = ISSUER[0]

CONFIG_URL = f"{ISSUER}/.well-known/openid-configuration"
JWKS_URL = f"{ISSUER}/jwks"
Expand Down Expand Up @@ -129,6 +135,18 @@ def do_graphql_call_as_user(
service_client_id = ServiceClientIdFactory(
service__service_type=None, service__is_profile_service=True
)
if getattr(user, "profile", None):
ServiceConnectionFactory(
profile=user.profile, service=service_client_id.service
)
service_client_id.service.allowed_data_fields.add(
AllowedDataFieldFactory(field_name="name"),
AllowedDataFieldFactory(field_name="address"),
AllowedDataFieldFactory(field_name="email"),
AllowedDataFieldFactory(field_name="phone"),
AllowedDataFieldFactory(field_name="personalidentitycode"),
)

elif service:
service_client_id = service.client_ids.first()

Expand Down
79 changes: 76 additions & 3 deletions profiles/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from encrypted_fields import fields
from enumfields import EnumField

from services.models import ServiceConnection
from services.models import Service, ServiceConnection
from users.models import User
from utils.fields import (
NullToEmptyCharField,
Expand All @@ -28,7 +28,50 @@
)


class Profile(UUIDModel, SerializableMixin):
class AllowedDataFieldsMixin:
"""
Mixin class for checking allowed data fields per service.
`allowed_data_fields_map` is a dictionary where the key is the `field_name` of the allowed data field
`allowed_data_fields.json` and the value is an iterable of django model's field names that the `field_name`
describes. For example, if the `field_name` is `name`, the value could be `("first_name", "last_name")`.
e.g:
allowed_data_fields_map = {
"name": ("first_name", "last_name", "nickname"),
"personalidentitycode": ("national_identification_number",),
"address": ("address", "postal_code", "city", "country_code")
}
`always_allow_fields`: Since connections are not defined in `allowed_data_fields.json` they should be
defined here. If the field is connection and the node does not inherit this mixin the data will be available
to all services.
"""

allowed_data_fields_map = {}
always_allow_fields = ["id", "service_connections"]
check_allowed_data_fields = True

@classmethod
def is_field_allowed_for_service(cls, field_name: str, service: Service):
if not service:
raise ValueError("No service identified")

if field_name in cls.always_allow_fields:
return True

allowed_data_fields = service.allowed_data_fields.values_list(
"field_name", flat=True
)
return any(
field_name in cls.allowed_data_fields_map.get(allowed_data_field, [])
for allowed_data_field in allowed_data_fields
)

class Meta:
abstract = True


class Profile(UUIDModel, SerializableMixin, AllowedDataFieldsMixin):
user = models.OneToOneField(User, on_delete=models.PROTECT, null=True, blank=True)
first_name = NullToEmptyCharField(max_length=150, blank=True, db_index=True)
last_name = NullToEmptyCharField(max_length=150, blank=True, db_index=True)
Expand Down Expand Up @@ -63,6 +106,24 @@ class Meta:
)
audit_log = True

# AllowedDataField configs
allowed_data_fields_map = {
"name": (
"first_name",
"last_name",
"nickname",
),
"email": ("emails", "primary_email"),
"phone": ("phones", "primary_phone"),
"address": ("addresses", "primary_address"),
"personalidentitycode": ("sensitivedata",),
}
always_allow_fields = AllowedDataFieldsMixin.always_allow_fields + [
"verified_personal_information",
"language",
"contact_method",
]

def resolve_profile(self):
return self

Expand Down Expand Up @@ -178,7 +239,7 @@ def get_national_identification_number_hash_key():
return settings.SALT_NATIONAL_IDENTIFICATION_NUMBER


class VerifiedPersonalInformation(SerializableMixin):
class VerifiedPersonalInformation(SerializableMixin, AllowedDataFieldsMixin):
profile = models.OneToOneField(
Profile, on_delete=models.CASCADE, related_name="verified_personal_information"
)
Expand Down Expand Up @@ -237,6 +298,18 @@ class VerifiedPersonalInformation(SerializableMixin):
)
audit_log = True

allowed_data_fields_map = {
"name": ("first_name", "last_name", "given_name"),
"personalidentitycode": ("national_identification_number",),
"address": (
"municipality_of_residence",
"municipality_of_residence_number",
"permanent_address",
"temporary_address",
"permanent_foreign_address",
),
}

class Meta:
permissions = [
(
Expand Down
56 changes: 50 additions & 6 deletions profiles/tests/test_gql_claim_profile_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_user_can_claim_claimable_profile_without_existing_profile(
}
}
}
executed = user_gql_client.execute(query)
executed = user_gql_client.execute(query, allowed_data_fields=["name"])

assert "errors" not in executed
assert dict(executed["data"]) == expected_data
Expand Down Expand Up @@ -98,6 +98,42 @@ def test_user_can_claim_claimable_profile_without_existing_profile(
"""


def test_user_cant_get_fields_not_allowed_when_claiming_a_profile(user_gql_client):
profile = ProfileWithPrimaryEmailFactory(user=None)
claim_token = ClaimTokenFactory(profile=profile)

t = Template(
"""
mutation {
claimProfile(
input: {
token: "${claimToken}",
profile: {
firstName: "Joe",
nickname: "Joey"
}
}
) {
profile {
id
firstName
lastName
nickname
sensitivedata {
ssn
}
}
}
}
"""
)
query = t.substitute(claimToken=claim_token.token)
executed = user_gql_client.execute(query, allowed_data_fields=["name"])

assert "errors" in executed
assert_match_error_code(executed, "FIELD_NOT_ALLOWED_ERROR")


def test_can_not_change_primary_email_to_non_primary(user_gql_client):
profile = ProfileFactory(user=None)
email = EmailFactory(profile=profile, primary=True)
Expand All @@ -112,7 +148,9 @@ def test_can_not_change_primary_email_to_non_primary(user_gql_client):
},
}

executed = user_gql_client.execute(CLAIM_PROFILE_MUTATION, variables=variables)
executed = user_gql_client.execute(
CLAIM_PROFILE_MUTATION, variables=variables, allowed_data_fields=["email"]
)
assert_match_error_code(executed, "PROFILE_MUST_HAVE_PRIMARY_EMAIL")


Expand All @@ -128,6 +166,7 @@ def test_can_not_delete_primary_email(user_gql_client):
"token": str(claim_token.token),
"profileInput": {"removeEmails": email_deletes},
},
allowed_data_fields=["email"],
)
assert_match_error_code(executed, "PROFILE_MUST_HAVE_PRIMARY_EMAIL")

Expand Down Expand Up @@ -172,6 +211,7 @@ def test_changing_an_email_address_marks_it_unverified(
CLAIM_PROFILE_MUTATION,
variables=variables,
execution_context_class=execution_context_class,
allowed_data_fields=["email"],
)
assert "errors" not in executed
assert executed["data"] == expected_data
Expand All @@ -189,7 +229,9 @@ def execute_query(self, user_gql_client, profile_input):
"profileInput": profile_input,
}

return user_gql_client.execute(CLAIM_PROFILE_MUTATION, variables=variables)
return user_gql_client.execute(
CLAIM_PROFILE_MUTATION, variables=variables, allowed_data_fields=["email"]
)


def test_user_cannot_claim_claimable_profile_if_token_expired(user_gql_client):
Expand Down Expand Up @@ -223,7 +265,7 @@ def test_user_cannot_claim_claimable_profile_if_token_expired(user_gql_client):
"""
)
query = t.substitute(claimToken=expired_claim_token.token)
executed = user_gql_client.execute(query)
executed = user_gql_client.execute(query, allowed_data_fields=["name"])

assert "errors" in executed
assert executed["errors"][0]["extensions"]["code"] == TOKEN_EXPIRED_ERROR
Expand All @@ -237,7 +279,9 @@ def test_using_non_existing_token_produces_an_object_does_not_exist_error(
variables = {
"token": non_existing_token,
}
executed = user_gql_client.execute(CLAIM_PROFILE_MUTATION, variables=variables)
executed = user_gql_client.execute(
CLAIM_PROFILE_MUTATION, variables=variables, allowed_data_fields=["email"]
)

assert_match_error_code(executed, "OBJECT_DOES_NOT_EXIST_ERROR")

Expand Down Expand Up @@ -270,7 +314,7 @@ def test_user_cannot_claim_claimable_profile_with_existing_profile(user_gql_clie
"""
)
query = t.substitute(claimToken=claim_token.token)
executed = user_gql_client.execute(query)
executed = user_gql_client.execute(query, allowed_data_fields=["name"])

assert "errors" in executed
assert (
Expand Down
2 changes: 1 addition & 1 deletion profiles/tests/test_gql_claimable_profile_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_can_query_claimable_profile_with_token(user_gql_client):
"lastName": profile.last_name,
}
}
executed = user_gql_client.execute(query)
executed = user_gql_client.execute(query, allowed_data_fields=["name"])

assert "errors" not in executed
assert dict(executed["data"]) == expected_data
Expand Down
Loading

0 comments on commit ddda974

Please sign in to comment.