From 5035c8f2a51c7102eecb99896859c0b8f8dfd29a Mon Sep 17 00:00:00 2001 From: Nico Virkki Date: Fri, 20 Dec 2024 10:12:06 +0200 Subject: [PATCH] feat: hide all information_system attributes from es search (#415) Previously we restricted the information system from basic filters and views for unauthenticated users. This adds hiding the InformationSystem attribues from unauthenticated users in elastic backed searches as well. Refs TIED-171 --- search_indices/serializers/base.py | 7 +- search_indices/serializers/utils.py | 17 +- search_indices/tests/conftest.py | 29 +- search_indices/tests/test_elastic_api.py | 549 ++++++++++++++++++----- search_indices/views/base.py | 24 +- 5 files changed, 500 insertions(+), 126 deletions(-) diff --git a/search_indices/serializers/base.py b/search_indices/serializers/base.py index 5c98efa6..559960c0 100644 --- a/search_indices/serializers/base.py +++ b/search_indices/serializers/base.py @@ -30,5 +30,10 @@ class Meta: def get_score(self, obj: Hit) -> int: return obj.meta.score + @property + def is_authenticated(self): + request = self.context.get("request") + return bool(request and request.user.is_authenticated) + def get_attributes(self, obj: Hit) -> Optional[dict]: - return get_attributes(obj, "attributes") + return get_attributes(obj, "attributes", self.is_authenticated) diff --git a/search_indices/serializers/utils.py b/search_indices/serializers/utils.py index eb17d972..4efd9e26 100644 --- a/search_indices/serializers/utils.py +++ b/search_indices/serializers/utils.py @@ -2,8 +2,18 @@ from elasticsearch_dsl.response.hit import Hit +attributes_for_authenticated = ( + "function_InformationSystem", + "action_InformationSystem", + "classification_InformationSystem", + "record_InformationSystem", + "phase_InformationSystem", +) -def get_attributes(obj: Hit, attribute_field_name: str) -> Optional[dict]: + +def get_attributes( + obj: Hit, attribute_field_name: str, authenticated: bool +) -> Optional[dict]: """ Fetch attributes from index and revert the attribute names that have "." replaced with "+". @@ -14,6 +24,11 @@ def get_attributes(obj: Hit, attribute_field_name: str) -> Optional[dict]: attrs = attrs.to_dict() for key, value in attrs.items(): key = key.replace("+", ".") + + if not authenticated and key in attributes_for_authenticated: + continue + attributes[key] = value + return attributes return None diff --git a/search_indices/tests/conftest.py b/search_indices/tests/conftest.py index a9a36ea5..27be13d4 100644 --- a/search_indices/tests/conftest.py +++ b/search_indices/tests/conftest.py @@ -7,6 +7,7 @@ from elasticsearch import Elasticsearch from elasticsearch_dsl.connections import add_connection from pytest import fixture +from rest_framework.test import APIClient from metarecord.models import Action, Classification, Function, Phase, Record from metarecord.tests.conftest import user, user_api_client # noqa @@ -57,7 +58,7 @@ def destroy_indices(): RecordDocument._index.delete(ignore=[400, 404]) -@fixture(scope="session", autouse=True) +@fixture(scope="class", autouse=True) def create_indices(): """ Initialize all indices with the custom analyzers. @@ -82,6 +83,11 @@ def es_connection(): yield es_connection +@fixture +def api_client(): + return APIClient() + + @fixture def action(phase): return Action.objects.create( @@ -89,6 +95,13 @@ def action(phase): ) +@fixture +def action_with_information_system(phase): + return Action.objects.create( + attributes={"InformationSystem": "xyz"}, phase=phase, index=1 + ) + + @fixture def action_2(phase_2): return Action.objects.create( @@ -124,6 +137,13 @@ def function(classification): ) +@fixture +def function_with_information_system(classification_2): + return Function.objects.create( + attributes={"InformationSystem": "xyz"}, classification=classification_2 + ) + + @fixture def function_2(classification_2): return Function.objects.create( @@ -139,6 +159,13 @@ def phase(function): ) +@fixture +def phase_with_information_system(function): + return Phase.objects.create( + attributes={"InformationSystem": "xyz"}, function=function, index=1 + ) + + @fixture def phase_2(function_2): return Phase.objects.create( diff --git a/search_indices/tests/test_elastic_api.py b/search_indices/tests/test_elastic_api.py index 93eeae65..7833fbb2 100644 --- a/search_indices/tests/test_elastic_api.py +++ b/search_indices/tests/test_elastic_api.py @@ -1,8 +1,7 @@ import pytest from rest_framework.reverse import reverse -from rest_framework.test import APIClient -from metarecord.models import Record +from metarecord.models import Phase, Record ACTION_LIST_URL = reverse("action_search-list") ALL_LIST_URL = reverse("all_search-list") @@ -13,136 +12,452 @@ @pytest.mark.django_db -def test_classification_search_exact(user_api_client, classification): - url = ALL_LIST_URL + "?search=testisana" - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert classification.uuid.hex in uuids - - -@pytest.mark.django_db -def test_classification_search_fuzzy1(user_api_client, classification): - url = ALL_LIST_URL + "?search=testi" - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert classification.uuid.hex in uuids - - -@pytest.mark.django_db -def test_classification_search_fuzzy2(user_api_client, classification): - url = ALL_LIST_URL + "?search=testisanojat" - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert classification.uuid.hex in uuids - - -@pytest.mark.django_db -def test_classification_search_query_string(user_api_client, classification_2): - url = ALL_LIST_URL + '?search_simple_query_string="testisana ja toinen testisana"' - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert classification_2.uuid.hex in uuids - - -@pytest.mark.django_db -def test_action_filter_attribute_exact(user_api_client, action): - url = ACTION_LIST_URL + "?action_AdditionalInformation=testisana" - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert action.uuid.hex in uuids +class TestClassificationSearch: + def test_classification_search_exact(self, user_api_client, classification): + url = ALL_LIST_URL + "?search=testisana" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert classification.uuid.hex in uuids + + def test_classification_search_fuzzy1(self, user_api_client, classification): + url = ALL_LIST_URL + "?search=testi" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert classification.uuid.hex in uuids + + def test_classification_search_fuzzy2(self, user_api_client, classification): + url = ALL_LIST_URL + "?search=testisanojat" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert classification.uuid.hex in uuids + + def test_classification_search_query_string( + self, user_api_client, classification_2 + ): + url = ( + ALL_LIST_URL + '?search_simple_query_string="testisana ja toinen testisana"' + ) + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert classification_2.uuid.hex in uuids @pytest.mark.django_db -def test_classification_filter_title_exact(user_api_client, classification): - url = CLASSIFICATION_LIST_URL + "?title=testisana" - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert classification.uuid.hex in uuids +class TestListFilters: + def test_classification_filter_title_exact(self, user_api_client, classification): + url = CLASSIFICATION_LIST_URL + "?title=testisana" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert classification.uuid.hex in uuids + + def test_action_filter_attribute_exact(self, user_api_client, action): + url = ACTION_LIST_URL + "?action_AdditionalInformation=testisana" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert action.uuid.hex in uuids + + def test_function_filter_attribute_exact(self, user_api_client, function): + url = FUNCTION_LIST_URL + "?function_AdditionalInformation=testisana" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert function.uuid.hex in uuids + + def test_phase_filter_attribute_exact(self, user_api_client, phase): + url = PHASE_LIST_URL + "?phase_AdditionalInformation=testisana" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert phase.uuid.hex in uuids + + def test_record_filter_attribute_exact(self, user_api_client, record, record_2): + assert Record.objects.count() == 2 + + url = RECORD_LIST_URL + "?record_AdditionalInformation=testisana" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert len(results) == 1 + assert record.uuid.hex in uuids + assert record_2.uuid.hex not in uuids @pytest.mark.django_db -def test_function_filter_attribute_exact(user_api_client, function): - url = FUNCTION_LIST_URL + "?function_AdditionalInformation=testisana" - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert function.uuid.hex in uuids +class TestFunctionAllSearchInformationSystem: + def test_not_match_for_unauthenticated( + self, api_client, function_with_information_system, phase + ): + phase.attributes = {"AdditionalInformation": "xyz"} + phase.save() + + response = api_client.get(ALL_LIST_URL + "?search=xyz") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + + assert function_with_information_system.uuid.hex not in uuids + assert phase.uuid.hex in uuids + + def test_information_system_not_visible_for_unauthenticated( + self, api_client, function_with_information_system, phase + ): + phase.attributes = {"AdditionalInformation": "testing"} + phase.save() + function_with_information_system.attributes = { + "InformationSystem": "xyz", + "AdditionalInformation": "testing", + } + function_with_information_system.save() + + response = api_client.get(ALL_LIST_URL + "?search=testing") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert function_with_information_system.uuid.hex in uuids + assert phase.uuid.hex in uuids + assert {"function_InformationSystem": "xyz"} not in attributes + + def test_information_system_is_visible_for_authenticated( + self, user_api_client, function_with_information_system, phase + ): + response = user_api_client.get(ALL_LIST_URL + "?search=xyz") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert function_with_information_system.uuid.hex in uuids + assert phase.uuid.hex not in uuids + assert {"function_InformationSystem": "xyz"} in attributes @pytest.mark.django_db -def test_phase_filter_attribute_exact(user_api_client, phase): - url = PHASE_LIST_URL + "?phase_AdditionalInformation=testisana" - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert phase.uuid.hex in uuids +class TestActionAllSearchInformationSystem: + def test_not_match_for_unauthenticated( + self, api_client, action_with_information_system, function_2 + ): + function_2.attributes = {"AdditionalInformation": "xyz"} + function_2.save() + + response = api_client.get(ALL_LIST_URL + "?search=xyz") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + + assert action_with_information_system.uuid.hex not in uuids + assert function_2.uuid.hex in uuids + + def test_information_system_not_visible_for_unauthenticated( + self, api_client, action_with_information_system, function_2 + ): + function_2.attributes = {"AdditionalInformation": "testing"} + function_2.save() + action_with_information_system.attributes = { + "InformationSystem": "xyz", + "AdditionalInformation": "testing", + } + action_with_information_system.save() + + response = api_client.get(ALL_LIST_URL + "?search=testing") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert action_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} not in attributes + + def test_information_system_is_visible_for_authenticated( + self, user_api_client, action_with_information_system, function_2 + ): + response = user_api_client.get(ALL_LIST_URL + "?search=xyz") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + assert action_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex not in uuids + assert {"action_InformationSystem": "xyz"} in attributes @pytest.mark.django_db -def test_record_filter_attribute_exact(user_api_client, record, record_2): - assert Record.objects.count() == 2 - - url = RECORD_LIST_URL + "?record_AdditionalInformation=testisana" - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert len(results) == 1 - assert record.uuid.hex in uuids - assert record_2.uuid.hex not in uuids +class TestRecordAllSearchInformationSystem: + def test_not_match_for_unauthenticated( + self, api_client, record_with_information_system, function_2 + ): + function_2.attributes = {"AdditionalInformation": "xyz"} + function_2.save() + + response = api_client.get(ALL_LIST_URL + "?search=xyz") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + + assert record_with_information_system.uuid.hex not in uuids + assert function_2.uuid.hex in uuids + + def test_information_system_not_visible_for_unauthenticated( + self, api_client, record_with_information_system, function_2 + ): + function_2.attributes = {"AdditionalInformation": "testing"} + function_2.save() + record_with_information_system.attributes = { + "InformationSystem": "xyz", + "AdditionalInformation": "testing", + } + record_with_information_system.save() + + response = api_client.get(ALL_LIST_URL + "?search=testing") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert record_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} not in attributes + + def test_information_system_is_visible_for_authenticated( + self, user_api_client, record_with_information_system, function_2 + ): + response = user_api_client.get(ALL_LIST_URL + "?search=xyz") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + assert record_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex not in uuids + assert {"record_InformationSystem": "xyz"} in attributes @pytest.mark.django_db -def test_record_filter_information_system_attribute_exact_filters_for_authenticated( - user_api_client, record_with_information_system, record_2 -): - assert Record.objects.count() == 2 - - url = RECORD_LIST_URL + "?record_InformationSystem=xyz" - response = user_api_client.get(url) - assert response.status_code == 200 - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert record_with_information_system.uuid.hex in uuids - assert record_2.uuid.hex not in uuids +class TestPhaseAllSearchInformationSystem: + def test_not_match_for_unauthenticated( + self, api_client, phase_with_information_system, function_2 + ): + function_2.attributes = {"AdditionalInformation": "xyz"} + function_2.save() + + response = api_client.get(ALL_LIST_URL + "?search=xyz") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + + assert phase_with_information_system.uuid.hex not in uuids + assert function_2.uuid.hex in uuids + + def test_information_system_not_visible_for_unauthenticated( + self, api_client, phase_with_information_system, function_2 + ): + function_2.attributes = {"AdditionalInformation": "testing"} + function_2.save() + phase_with_information_system.attributes = { + "InformationSystem": "xyz", + "AdditionalInformation": "testing", + } + phase_with_information_system.save() + + response = api_client.get(ALL_LIST_URL + "?search=testing") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert phase_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} not in attributes + + def test_information_system_is_visible_for_authenticated( + self, user_api_client, phase_with_information_system, function_2 + ): + response = user_api_client.get(ALL_LIST_URL + "?search=xyz") + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + assert phase_with_information_system.uuid.hex in uuids + assert function_2.uuid.hex not in uuids + assert {"phase_InformationSystem": "xyz"} in attributes @pytest.mark.django_db -def test_record_filter_information_system_attribute_exact_does_not_filter_for_unauthenticated( - record_with_information_system, record_2 -): - assert Record.objects.count() == 2 - - url = RECORD_LIST_URL + "?record_InformationSystem=xyz" - api_client = APIClient() - response = api_client.get(url) - - results = response.data["results"] if "results" in response.data else response.data - uuids = list(result["id"] for result in results) - assert response.status_code == 200 - assert record_with_information_system.uuid.hex in uuids - assert record_2.uuid.hex in uuids +class TestInformationSystemAttributeListUrls: + def test_action_information_system_attribute_for_authenticated( + self, user_api_client, action_with_information_system + ): + response = user_api_client.get(ACTION_LIST_URL) + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert response.status_code == 200 + assert action_with_information_system.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} in attributes + + def test_action_does_not_show_information_system_attribute_for_unauthenticated( + self, api_client, action_with_information_system + ): + response = api_client.get(ACTION_LIST_URL) + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert response.status_code == 200 + assert action_with_information_system.uuid.hex in uuids + assert {"action_InformationSystem": "xyz"} not in attributes + + def test_record_filter_information_system_attribute_exact_filters_for_authenticated( + self, user_api_client, record_with_information_system, record_2 + ): + assert Record.objects.count() == 2 + + url = RECORD_LIST_URL + "?record_InformationSystem=xyz" + response = user_api_client.get(url) + assert response.status_code == 200 + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert record_with_information_system.uuid.hex in uuids + assert record_2.uuid.hex not in uuids + + def test_record_does_shows_information_system_attribute_for_authenticated( + self, user_api_client, record_with_information_system, record_2 + ): + response = user_api_client.get(RECORD_LIST_URL) + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert response.status_code == 200 + assert record_with_information_system.uuid.hex in uuids + assert {"record_InformationSystem": "xyz"} in attributes + assert record_2.uuid.hex in uuids + + def test_record_does_not_show_information_system_attribute_for_unauthenticated( + self, api_client, record_with_information_system, record_2 + ): + response = api_client.get(RECORD_LIST_URL) + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + attributes = list(result["attributes"] for result in results) + + assert response.status_code == 200 + assert record_with_information_system.uuid.hex in uuids + assert {"record_InformationSystem": "xyz"} not in attributes + assert record_2.uuid.hex in uuids + + def test_record_filter_information_system_attribute_exact_does_not_filter_for_unauthenticated( + self, api_client, record_with_information_system, record_2 + ): + assert Record.objects.count() == 2 + + url = RECORD_LIST_URL + "?record_InformationSystem=xyz" + response = api_client.get(url) + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert response.status_code == 200 + assert record_with_information_system.uuid.hex in uuids + assert record_2.uuid.hex in uuids + + def test_phase_filter_information_system_attribute_exact_for_authenticated( + self, user_api_client, phase_with_information_system, phase_2 + ): + assert Phase.objects.count() == 2 + url = PHASE_LIST_URL + "?phase_InformationSystem=xyz" + response = user_api_client.get(url) + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert response.status_code == 200 + assert phase_with_information_system.uuid.hex in uuids + assert phase_2.uuid.hex not in uuids + + def test_phase_filter_information_system_attribute_exact_does_not_filter_for_unauthenticated( + self, api_client, phase_with_information_system, phase_2 + ): + assert Phase.objects.count() == 2 + + url = PHASE_LIST_URL + "?phase_InformationSystem=xyz" + response = api_client.get(url) + + results = ( + response.data["results"] if "results" in response.data else response.data + ) + uuids = list(result["id"] for result in results) + assert response.status_code == 200 + assert phase_with_information_system.uuid.hex in uuids + assert phase_2.uuid.hex in uuids diff --git a/search_indices/views/base.py b/search_indices/views/base.py index aa1d5f52..16a585e6 100644 --- a/search_indices/views/base.py +++ b/search_indices/views/base.py @@ -8,6 +8,7 @@ from metarecord.pagination import ESRecordPagination from search_indices.backends.faceted_attribute_backend import FacetedAttributeBackend +from search_indices.serializers.utils import attributes_for_authenticated from search_indices.views.utils import populate_filter_fields_with_attributes @@ -68,10 +69,21 @@ class BaseSearchDocumentViewSet(BaseDocumentViewSet): "_score", ) - def filter_queryset(self, queryset): - # Restrict querying information system queries to authenticated users. - # The information system field contents are not public. - if not self.request.user.is_authenticated: - self.filter_fields.pop("record_InformationSystem", None) + def _filter_search_fields_for_unauthenticated(self): + search_fields = [] + for field in self.search_fields: + if "InformationSystem" in field: + continue + search_fields.append(field) + self.search_fields = tuple(search_fields) - return super().filter_queryset(queryset) + for attribute in attributes_for_authenticated: + self.filter_fields.pop(attribute, None) + + def initial(self, request, *args, **kwargs): + if not request.user.is_authenticated: + # Restrict querying information system queries to authenticated users. + # The information system field contents are not public. + self._filter_search_fields_for_unauthenticated() + + super().initial(request, *args, **kwargs)