diff --git a/lib/dl_core_testing/dl_core_testing/testcases/dataset.py b/lib/dl_core_testing/dl_core_testing/testcases/dataset.py index 5132168e6..c3ed9f167 100644 --- a/lib/dl_core_testing/dl_core_testing/testcases/dataset.py +++ b/lib/dl_core_testing/dl_core_testing/testcases/dataset.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import ( + AbstractSet, ClassVar, Generic, Optional, @@ -11,10 +12,13 @@ import sqlalchemy as sa from dl_constants.enums import ( + ConnectionType, DataSourceRole, DataSourceType, + JoinType, ) from dl_core.data_processing.stream_base import DataStream +from dl_core.dataset_capabilities import DatasetCapabilities from dl_core.query.bi_query import BIQuery from dl_core.query.expression import ExpressionCtx from dl_core.services_registry.top_level import ServicesRegistry @@ -22,8 +26,14 @@ from dl_core.us_dataset import Dataset from dl_core.us_manager.us_manager_sync import SyncUSManager from dl_core_testing.data import DataFetcher -from dl_core_testing.database import DbTable -from dl_core_testing.dataset import make_dataset +from dl_core_testing.database import ( + Db, + DbTable, +) +from dl_core_testing.dataset import ( + get_created_from, + make_dataset, +) from dl_core_testing.dataset_wrappers import ( DatasetTestWrapper, EditableDatasetTestWrapper, @@ -202,3 +212,42 @@ def test_get_param_hash( assert hash_from_dataset == hash_from_template assert found_template + + def _check_compatible_source_types(self, compat_source_types: AbstractSet[DataSourceType]) -> None: + assert self.source_type in compat_source_types + + def _check_compatible_connection_types(self, compat_conn_types: AbstractSet[ConnectionType]) -> None: + assert not compat_conn_types, "Multiple connections are not supported" + + def _check_supported_join_types(self, supp_join_types: AbstractSet[JoinType]) -> None: + assert set(supp_join_types).issuperset({JoinType.inner, JoinType.left}) + + def _allow_adding_sources(self, dataset: Dataset) -> bool: + return True + + def test_compatibility_info( + self, + db: Db, + saved_connection: ConnectionBase, + saved_dataset: Dataset, + sync_us_manager: SyncUSManager, + ) -> None: + dataset = saved_dataset + connection = saved_connection + + ds_wrapper = DatasetTestWrapper(dataset=dataset, us_manager=sync_us_manager) + capabilities = DatasetCapabilities(dataset=dataset, dsrc_coll_factory=ds_wrapper.dsrc_coll_factory) + + compat_source_types = capabilities.get_compatible_source_types() + self._check_compatible_source_types(compat_source_types) + + compat_conn_types = capabilities.get_compatible_connection_types() + self._check_compatible_connection_types(compat_conn_types) + + supp_join_types = capabilities.get_supported_join_types() + self._check_supported_join_types(supp_join_types) + + assert capabilities.get_effective_connection_id() == connection.uuid + + if self._allow_adding_sources(dataset=dataset): + assert capabilities.source_can_be_added(connection_id=connection.uuid, created_from=get_created_from(db=db))