diff --git a/lib/dl_api_lib/dl_api_lib_tests/db/control_api/test_rls.py b/lib/dl_api_lib/dl_api_lib_tests/db/control_api/test_rls.py index 9be5cac80..d3110e453 100644 --- a/lib/dl_api_lib/dl_api_lib_tests/db/control_api/test_rls.py +++ b/lib/dl_api_lib/dl_api_lib_tests/db/control_api/test_rls.py @@ -1,5 +1,14 @@ import pytest +from dl_api_commons.base_models import RequestContextInfo +from dl_api_lib.dataset.view import DatasetView +from dl_api_lib.query.formalization.block_formalizer import BlockFormalizer +from dl_api_lib.query.formalization.legend_formalizer import ResultLegendFormalizer +from dl_api_lib.query.formalization.raw_specs import ( + IdFieldRef, + RawQuerySpecUnion, + RawSelectFieldSpec, +) from dl_api_lib_testing.rls import ( RLS_CONFIG_CASES, config_to_comparable, @@ -59,3 +68,34 @@ def test_create_rls_from_invalid_config(self, control_api, saved_dataset): assert rls_resp.bi_status_code == "ERR.DS_API.RLS.PARSE" assert rls_resp.json["message"] == "RLS: Parsing failed at line 2" assert rls_resp.json["details"] == {"description": "Wrong format"} + + def test_rls_filter_expr(self, control_api, saved_dataset, sync_us_manager): + config = load_rls_config("dl_api_lib_test_config") + field_a, field_b = saved_dataset.result_schema[0].id, saved_dataset.result_schema[1].id + saved_dataset.rls = {field_a: config, field_b: config} + control_api.save_dataset(saved_dataset, fail_ok=False) + + ds = sync_us_manager.get_by_id(saved_dataset.id) + sync_us_manager.load_dependencies(ds) + + rci = RequestContextInfo(user_id="user1") + raw_query_spec_union = RawQuerySpecUnion( + select_specs=[ + RawSelectFieldSpec(ref=IdFieldRef(id=field_a)), + RawSelectFieldSpec(ref=IdFieldRef(id=field_b)), + ], + ) + legend = ResultLegendFormalizer(dataset=ds).make_legend(raw_query_spec_union=raw_query_spec_union) + block_legend = BlockFormalizer(dataset=ds).make_block_legend( + raw_query_spec_union=raw_query_spec_union, legend=legend + ) + ds_view = DatasetView( + ds, + us_manager=sync_us_manager, + block_spec=block_legend.blocks[0], + rci=rci, + ) + + exec_info = ds_view.build_exec_info() + src_query = next(iter(exec_info.translated_multi_query.iter_queries())) + assert len(src_query.where) == 2 # field_a in ... and field_b in ... diff --git a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/test_rls.py b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/test_rls.py index 90e5fcf2e..b6a87f4a4 100644 --- a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/test_rls.py +++ b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/test_rls.py @@ -1,7 +1,10 @@ +import re + import pytest +from dl_api_client.dsmaker.shortcuts.result_data import get_data_rows from dl_api_lib_testing.app import TestingSubjectResolver -from dl_api_lib_testing.rls import MAIN_TEST_CASE +from dl_api_lib_testing.rls import load_rls_config from dl_api_lib_tests.db.base import DefaultApiTestBase @@ -9,8 +12,8 @@ class TestRLS(DefaultApiTestBase): @pytest.fixture(scope="function") def dataset_with_rls(self, control_api, saved_dataset): ds = saved_dataset - field_guid = ds.result_schema[0].id - ds.rls = {field_guid: MAIN_TEST_CASE["config"]} + field_guid = ds.result_schema[1].id + ds.rls = {field_guid: load_rls_config("dl_api_lib_test_config")} control_api.save_dataset(ds, fail_ok=False) resp = control_api.load_dataset(ds) @@ -25,13 +28,16 @@ def get_subjects_by_names_mock(self, names): rls_val_modifier = "\n'x': *\n" if modify_rls else "" ds.rls = {key: val + rls_val_modifier for key, val in ds.rls.items() if val} monkeypatch.setattr(TestingSubjectResolver, "get_subjects_by_names", get_subjects_by_names_mock) - return data_api.get_preview(dataset=ds, limit=13, fail_ok=True) + return data_api.get_preview(dataset=ds, fail_ok=True) def test_preview_with_saved_rls(self, dataset_with_rls, data_api, monkeypatch): resp = self._get_rls_preview_response(dataset_with_rls, data_api, monkeypatch, modify_rls=False) - resp_data = resp.json - assert resp.status_code == 200, resp_data - assert resp_data + assert resp.status_code == 200, resp.json + + rls_data = [row[1] for row in get_data_rows(resp)] + assert rls_data + # ensure all values from the RLS config are presented and no other values are + assert set(rls_data) == set(re.findall("'(.+?)'", load_rls_config("dl_api_lib_test_config"))) def test_preview_with_updated_rls(self, dataset_with_rls, data_api, monkeypatch): resp = self._get_rls_preview_response(dataset_with_rls, data_api, monkeypatch, modify_rls=True) diff --git a/lib/dl_api_lib_testing/dl_api_lib_testing/app.py b/lib/dl_api_lib_testing/dl_api_lib_testing/app.py index 9daa920f7..8043eb680 100644 --- a/lib/dl_api_lib_testing/dl_api_lib_testing/app.py +++ b/lib/dl_api_lib_testing/dl_api_lib_testing/app.py @@ -105,15 +105,29 @@ def rqe_config_subprocess_cm(self) -> Generator[RQEConfig, None, None]: @attr.s class TestingSubjectResolver(BaseSubjectResolver): def get_subjects_by_names(self, names: list[str]) -> list[RLSSubject]: - """Mock resolver. Considers a user real if his name starts with 'user'""" - return [ - RLSSubject( - subject_id="", - subject_type=RLSSubjectType.user if name.startswith("user") else RLSSubjectType.notfound, - subject_name=name if name.startswith("user") else RLS_FAILED_USER_NAME_PREFIX + name, - ) - for name in names - ] + """ + Mock resolver. Considers a user real if the name starts with a 'user' or + if it's equals to '_the_tests_asyncapp_user_name_' + """ + subjects = [] + for name in names: + if name == "_the_tests_asyncapp_user_name_": + subjects.append( + RLSSubject( + subject_id="_the_tests_asyncapp_user_id_", + subject_type=RLSSubjectType.user, + subject_name=name, + ) + ) + else: + subjects.append( + RLSSubject( + subject_id="", + subject_type=RLSSubjectType.user if name.startswith("user") else RLSSubjectType.notfound, + subject_name=name if name.startswith("user") else RLS_FAILED_USER_NAME_PREFIX + name, + ) + ) + return subjects @attr.s diff --git a/lib/dl_api_lib_testing/dl_api_lib_testing/rls.py b/lib/dl_api_lib_testing/dl_api_lib_testing/rls.py index 00b77206e..ff264406f 100644 --- a/lib/dl_api_lib_testing/dl_api_lib_testing/rls.py +++ b/lib/dl_api_lib_testing/dl_api_lib_testing/rls.py @@ -62,7 +62,6 @@ def load_rls(name: str) -> list[RLSEntry]: rls_entries_updated=load_rls("missing_login_updated.json"), ), ] -MAIN_TEST_CASE = RLS_CONFIG_CASES[0] def config_to_comparable(conf: str): diff --git a/lib/dl_api_lib_testing/dl_api_lib_testing/test_data/rls_configs/dl_api_lib_test_config b/lib/dl_api_lib_testing/dl_api_lib_testing/test_data/rls_configs/dl_api_lib_test_config new file mode 100644 index 000000000..274ed7279 --- /dev/null +++ b/lib/dl_api_lib_testing/dl_api_lib_testing/test_data/rls_configs/dl_api_lib_test_config @@ -0,0 +1,2 @@ +'Naperville': _the_tests_asyncapp_user_name_ +'Philadelphia': *