diff --git a/test/integration_test.py b/test/integration_test.py
index ab2b38507e..7673d86a67 100644
--- a/test/integration_test.py
+++ b/test/integration_test.py
@@ -3,7 +3,6 @@
 )
 from collections.abc import (
     Iterable,
-    Iterator,
     Mapping,
     Sequence,
 )
@@ -19,10 +18,8 @@
     BytesIO,
     TextIOWrapper,
 )
-import itertools
 from itertools import (
     count,
-    starmap,
 )
 import json
 import os
@@ -78,6 +75,7 @@
     first,
     grouper,
     one,
+    only,
 )
 from openapi_spec_validator import (
     validate_spec,
@@ -105,6 +103,9 @@
 from azul.chalice import (
     AzulChaliceApp,
 )
+from azul.collections import (
+    alist,
+)
 from azul.drs import (
     AccessMethod,
 )
@@ -115,6 +116,7 @@
     http_client,
 )
 from azul.indexer import (
+    Prefix,
     SourceJSON,
     SourceRef,
     SourcedBundleFQID,
@@ -151,7 +153,6 @@
 )
 from azul.plugins.repository.tdr_anvil import (
     BundleType,
-    TDRAnvilBundleFQID,
     TDRAnvilBundleFQIDJSON,
 )
 from azul.portal_service import (
@@ -285,73 +286,28 @@ def managed_access_sources_by_catalog(self
             managed_access_sources[catalog].add(ref)
         return managed_access_sources

-    def _list_partitions(self,
-                         catalog: CatalogName,
-                         *,
-                         min_bundles: int,
-                         public_1st: bool
-                         ) -> Iterator[tuple[SourceRef, str, list[SourcedBundleFQID]]]:
-        """
-        Iterate through the sources in the given catalog and yield partitions of
-        bundle FQIDs until a desired minimum number of bundles are found. For
-        each emitted source, every partition is included, even if it's empty.
-        """
-        total_bundles = 0
-        sources = sorted(config.sources(catalog))
-        self.random.shuffle(sources)
-        if public_1st:
-            managed_access_sources = frozenset(
-                str(source.spec)
-                for source in self.managed_access_sources_by_catalog[catalog]
-            )
-            index = first(
-                i
-                for i, source in enumerate(sources)
-                if source not in managed_access_sources
-            )
-            sources[0], sources[index] = sources[index], sources[0]
-        plugin = self.azul_client.repository_plugin(catalog)
-        # This iteration prefers sources occurring first, so we shuffle them
-        # above to neutralize the bias.
-        for source in sources:
-            source = plugin.resolve_source(source)
-            source = plugin.partition_source(catalog, source)
-            for prefix in source.spec.prefix.partition_prefixes():
-                new_fqids = self.azul_client.list_bundles(catalog, source, prefix)
-                total_bundles += len(new_fqids)
-                yield source, prefix, new_fqids
-            # We postpone this check until after we've yielded all partitions in
-            # the current source to ensure test coverage for handling multiple
-            # partitions per source
-            if total_bundles >= min_bundles:
-                break
+    def _choose_source(self,
+                       catalog: CatalogName,
+                       *,
+                       public: bool | None = None
+                       ) -> SourceRef | None:
+        plugin = self.repository_plugin(catalog)
+        sources = set(config.sources(catalog))
+        managed_access_sources = {
+            str(source.spec)
+            for source in self.managed_access_sources_by_catalog[catalog]
+        }
+        self.assertIsSubset(managed_access_sources, sources)
+        if public is True:
+            sources -= managed_access_sources
+        elif public is False:
+            sources &= managed_access_sources
+        if len(sources) == 0:
+            assert public is False
+            return None
         else:
-            log.warning('Checked all sources and found only %d bundles instead of the '
-                        'expected minimum %d', total_bundles, min_bundles)
-
-    def _list_managed_access_bundles(self,
-                                     catalog: CatalogName
-                                     ) -> Iterator[tuple[SourceRef, str, list[SourcedBundleFQID]]]:
-        sources = self.azul_client.catalog_sources(catalog)
-        # We need at least one managed_access bundle per IT. To index them with
-        # remote_reindex and avoid collateral bundles, we use as specific a
-        # prefix as possible.
-        for source in self.managed_access_sources_by_catalog[catalog]:
-            assert str(source.spec) in sources
-            source = self.repository_plugin(catalog).partition_source(catalog, source)
-            bundle_fqids = sorted(
-                bundle_fqid
-                for bundle_fqid in self.azul_client.list_bundles(catalog, source, prefix='')
-                if not (
-                    # DUOS bundles are too sparse to fulfill the managed access tests
-                    config.is_anvil_enabled(catalog)
-                    and cast(TDRAnvilBundleFQID, bundle_fqid).table_name is BundleType.duos
-                )
-            )
-            bundle_fqid = self.random.choice(bundle_fqids)
-            prefix = bundle_fqid.uuid[:8]
-            new_fqids = self.azul_client.list_bundles(catalog, source, prefix)
-            yield source, prefix, new_fqids
+            source = self.random.choice(sorted(sources))
+            return plugin.resolve_source(source)


 class IndexingIntegrationTest(IntegrationTestCase, AlwaysTearDownTestCase):
@@ -427,6 +383,8 @@ class Catalog:
             name: CatalogName
             bundles: set[SourcedBundleFQID]
             notifications: list[JSON]
+            public_source: SourceRef | None
+            ma_source: SourceRef | None

         def _wait_for_indexer():
             self.azul_client.wait_for_indexer()
@@ -442,13 +400,16 @@ def _wait_for_indexer():

         catalogs: list[Catalog] = []
         for catalog in config.integration_test_catalogs:
-            if index:
-                notifications, fqids = self._prepare_notifications(catalog)
-            else:
-                notifications, fqids = [], set()
+            public_source, ma_source, notifications, fqids = (
+                self._prepare_notifications(catalog)
+                if index else
+                (None, None, [], set())
+            )
             catalogs.append(Catalog(name=catalog,
                                     bundles=fqids,
-                                    notifications=notifications))
+                                    notifications=notifications,
+                                    public_source=public_source,
+                                    ma_source=ma_source))

         if index:
             for catalog in catalogs:
@@ -465,11 +426,20 @@ def _wait_for_indexer():
                 self._test_dos_and_drs(catalog.name)
             self._test_repository_files(catalog.name)
             if index:
+                public_source = catalog.public_source
+                ma_source = catalog.ma_source
                 bundle_fqids = catalog.bundles
             else:
                 with self._service_account_credentials:
                     bundle_fqids = self._get_indexed_bundles(catalog.name)
-            self._test_managed_access(catalog=catalog.name, bundle_fqids=bundle_fqids)
+                indexed_sources = {fqid.source for fqid in bundle_fqids}
+                ma_sources = self.managed_access_sources_by_catalog[catalog.name]
+                public_source = one((s for s in indexed_sources if s not in ma_sources))
+                ma_source = only((s for s in indexed_sources if s in ma_sources))
+
+            self._test_managed_access(catalog=catalog.name,
+                                      public_source=public_source,
+                                      ma_source=ma_source)

         if index and delete:
             # FIXME: Test delete notifications
@@ -1221,33 +1191,22 @@ def _validate_fastq_content(self, content: ReadableFileObject):
         self.assertTrue(lines[2].startswith(b'+'))

     def _prepare_notifications(self,
-                               catalog: CatalogName
-                               ) -> tuple[JSONs, set[SourcedBundleFQID]]:
-        bundle_fqids: set[SourcedBundleFQID] = set()
+                               catalog: CatalogName,
+                               ) -> tuple[SourceRef, SourceRef | None, JSONs, set[SourcedBundleFQID]]:
+        public_source = self._choose_source(catalog, public=True)
+        ma_source = self._choose_source(catalog, public=False)
+        plugin = self.repository_plugin(catalog)
+        bundle_fqids = set()
         notifications = []
-
-        def update(source: SourceRef,
-                   prefix: str,
-                   partition_bundle_fqids: Iterable[SourcedBundleFQID]):
-            bundle_fqids.update(partition_bundle_fqids)
-            notifications.append(self.azul_client.reindex_message(catalog,
-                                                                  source,
-                                                                  prefix))
-
-        list(starmap(update, self._list_managed_access_bundles(catalog)))
-        num_bundles = max(self.min_bundles - len(bundle_fqids), 1)
-        log.info('Selected %d bundles to satisfy managed access coverage; '
-                 'selecting at least %d more', len(bundle_fqids), num_bundles)
-        # _list_partitions selects both public and managed access sources at random.
-        # If we don't index at least one public source, every request would need
-        # service account credentials and we couldn't compare the responses for
-        # public and managed access data. `public_1st` ensures that at least
-        # one of the sources will be public because sources are indexed starting
-        # with the first one yielded by the iteration.
-        list(starmap(update, self._list_partitions(catalog,
-                                                   min_bundles=num_bundles,
-                                                   public_1st=True)))
-
+        for source in alist(public_source, ma_source):
+            source = plugin.partition_source(catalog, source)
+            # Some partitions may be empty, but we include them anyway to
+            # ensure test coverage for handling multiple partitions per source
+            for partition_prefix in source.spec.prefix.partition_prefixes():
+                bundle_fqids.update(self.azul_client.list_bundles(catalog, source, partition_prefix))
+                notifications.append(self.azul_client.reindex_message(catalog,
+                                                                      source,
+                                                                      partition_prefix))
         # Index some bundles again to test that we handle duplicate additions.
         # Note: random.choices() may pick the same element multiple times so
         # some notifications may end up being sent three or more times.
@@ -1257,7 +1216,7 @@ def update(source: SourceRef,
             for bundle in self.random.choices(sorted(bundle_fqids), k=num_duplicates)
         ]
         notifications.extend(duplicate_bundles)
-        return notifications, bundle_fqids
+        return public_source, ma_source, notifications, bundle_fqids

     def _get_indexed_bundles(self,
                              catalog: CatalogName,
@@ -1386,40 +1345,34 @@ def _assert_indices_exist(self, catalog: CatalogName):

     def _test_managed_access(self,
                              catalog: CatalogName,
-                             bundle_fqids: set[SourcedBundleFQID]
+                             public_source: SourceRef,
+                             ma_source: SourceRef | None,
                              ) -> None:
         with self.subTest('managed_access', catalog=catalog):
-            indexed_source_ids = {fqid.source.id for fqid in bundle_fqids}
-            managed_access_sources = self.managed_access_sources_by_catalog[catalog]
-            managed_access_source_ids = {source.id for source in managed_access_sources}
-            self.assertIsSubset(managed_access_source_ids, indexed_source_ids)
-
-            if not managed_access_sources:
+            if ma_source is None:
                 if config.deployment_stage in ('dev', 'sandbox'):
                     # There should always be at least one managed-access source
                     # indexed and tested on the default catalog for these deployments
                     self.assertNotEqual(catalog, config.it_catalog_for(config.default_catalog))
                 self.skipTest(f'No managed access sources found in catalog {catalog!r}')
-
             with self.subTest('managed_access_indices', catalog=catalog):
-                self._test_managed_access_indices(catalog, managed_access_source_ids)
+                self._test_managed_access_indices(catalog, public_source, ma_source)
             with self.subTest('managed_access_repository_files', catalog=catalog):
-                files = self._test_managed_access_repository_files(catalog, managed_access_source_ids)
+                files = self._test_managed_access_repository_files(catalog, ma_source)
             with self.subTest('managed_access_summary', catalog=catalog):
                 self._test_managed_access_summary(catalog, files)
             with self.subTest('managed_access_repository_sources', catalog=catalog):
-                public_source_ids = self._test_managed_access_repository_sources(catalog,
-                                                                                 indexed_source_ids,
-                                                                                 managed_access_source_ids)
-            with self.subTest('managed_access_manifest', catalog=catalog):
-                source_id = self.random.choice(sorted(public_source_ids & indexed_source_ids))
-                self._test_managed_access_manifest(catalog, files, source_id)
+                self._test_managed_access_repository_sources(catalog,
+                                                             public_source,
+                                                             ma_source)
+            with self.subTest('managed_access_manifest', catalog=catalog):
+                self._test_managed_access_manifest(catalog, files, public_source)

     def _test_managed_access_repository_sources(self,
                                                  catalog: CatalogName,
-                                                 indexed_source_ids: set[str],
-                                                 managed_access_source_ids: set[str]
-                                                 ) -> set[str]:
+                                                 public_source: SourceRef,
+                                                 ma_source: SourceRef
+                                                 ) -> None:
         """
         Test the managed access controls for the /repository/sources endpoint
         :return: the set of public sources
@@ -1432,7 +1385,7 @@ def list_source_ids() -> set[str]:
             return {source['sourceId'] for source in cast(JSONs, response['sources'])}

         with self._service_account_credentials:
-            self.assertIsSubset(indexed_source_ids, list_source_ids())
+            self.assertIsSubset({public_source.id, ma_source.id}, list_source_ids())
         with self._public_service_account_credentials:
             public_source_ids = list_source_ids()
         with self._unregistered_service_account_credentials:
@@ -1444,13 +1397,13 @@ def list_source_ids() -> set[str]:
         invalid_client = OAuth2Client(credentials_provider=invalid_provider)
         with self._authorization_context(invalid_client):
             self.assertEqual(401, self._get_url_unchecked(GET, url).status)
-            self.assertEqual(set(), list_source_ids() & managed_access_source_ids)
+            self.assertEqual(set(), list_source_ids() & {ma_source.id})
         self.assertEqual(public_source_ids, list_source_ids())
-        return public_source_ids

     def _test_managed_access_indices(self,
                                      catalog: CatalogName,
-                                     managed_access_source_ids: set[str]
+                                     public_source: SourceRef,
+                                     ma_source: SourceRef
                                      ) -> JSONs:
         """
         Test the managed-access controls for the /index/bundles and
@@ -1460,11 +1413,6 @@ def _test_managed_access_indices(self,

         """
         special_fields = self.metadata_plugin(catalog).special_fields
-
-        def source_id_from_hit(hit: JSON) -> str:
-            sources: JSONs = hit['sources']
-            return one(sources)[special_fields.source_id]
-
         bundle_type = self._bundle_type(catalog)
         project_type = self._project_type(catalog)

@@ -1477,31 +1425,22 @@ def source_id_from_hit(hit: JSON) -> str:
                 hits = self._get_entities(catalog, project_type, filters=filters)
                 if accessible is None:
                     unfiltered_hits = hits
-                accessible_sources, inaccessible_sources = set(), set()
                 for hit in hits:
-                    source_id = source_id_from_hit(hit)
-                    source_accessible = source_id not in managed_access_source_ids
+                    source_id = one(hit['sources'])[special_fields.source_id]
+                    source_accessible = {public_source.id: True, ma_source.id: False}[source_id]
                     hit_accessible = one(hit[project_type])[special_fields.accessible]
                     self.assertEqual(source_accessible, hit_accessible, hit['entryId'])
                     if accessible is not None:
                         self.assertEqual(accessible, hit_accessible)
-                    if source_accessible:
-                        accessible_sources.add(source_id)
-                    else:
-                        inaccessible_sources.add(source_id)
-                self.assertIsDisjoint(accessible_sources, inaccessible_sources)
-                self.assertIsDisjoint(managed_access_source_ids, accessible_sources)
-                self.assertEqual(set() if accessible else managed_access_source_ids,
-                                 inaccessible_sources)
         self.assertIsNotNone(unfiltered_hits, 'Cannot recover from subtest failure')

         bundle_fqids = self._get_indexed_bundles(catalog)
         hit_source_ids = {fqid.source.id for fqid in bundle_fqids}
-        self.assertEqual(set(), hit_source_ids & managed_access_source_ids)
+        self.assertEqual(hit_source_ids, {public_source.id})

         source_filter = {
             special_fields.source_id: {
-                'is': list(managed_access_source_ids)
+                'is': [ma_source.id]
             }
         }
         params = {
@@ -1510,18 +1449,18 @@ def source_id_from_hit(hit: JSON) -> str:
         }
         url = config.service_endpoint.set(path=('index', bundle_type), args=params)
         response = self._get_url_unchecked(GET, url)
-        self.assertEqual(403 if managed_access_source_ids else 200, response.status)
+        self.assertEqual(403, response.status)

         with self._service_account_credentials:
             bundle_fqids = self._get_indexed_bundles(catalog, filters=source_filter)
         hit_source_ids = {fqid.source.id for fqid in bundle_fqids}
-        self.assertEqual(managed_access_source_ids, hit_source_ids)
+        self.assertEqual({ma_source.id}, hit_source_ids)

         return unfiltered_hits

     def _test_managed_access_repository_files(self,
                                               catalog: CatalogName,
-                                              managed_access_source_ids: set[str]
+                                              ma_source: SourceRef
                                               ) -> JSONs:
         """
         Test the managed access controls for the /repository/files endpoint
@@ -1531,7 +1470,7 @@ def _test_managed_access_repository_files(self,
         with self._service_account_credentials:
             files = self._get_entities(catalog, 'files', filters={
                 special_fields.source_id: {
-                    'is': list(managed_access_source_ids)
+                    'is': [ma_source.id]
                 }
             })
         managed_access_file_urls = {
@@ -1568,7 +1507,7 @@ def _get_summary_file_count() -> int:
     def _test_managed_access_manifest(self,
                                       catalog: CatalogName,
                                       files: JSONs,
-                                      source_id: str
+                                      public_source: SourceRef
                                       ) -> None:
         """
         Test the managed access controls for the /manifest/files endpoint and
@@ -1590,7 +1529,7 @@ def bundle_uuids(hit: JSON) -> set[str]:
             for file in files
             if len(file['sources']) == 1
         ))
-        filters = {special_fields.source_id: {'is': [source_id]}}
+        filters = {special_fields.source_id: {'is': [public_source.id]}}
         params = {'size': 1, 'catalog': catalog, 'filters': json.dumps(filters)}
         files_url = furl(url=endpoint, path='index/files', args=params)
         response = self._get_url_json(GET, files_url)
@@ -1958,13 +1897,10 @@ def test_can_bundle_canned_repository(self):
         self._test_catalog(mock_catalog)

     def bundle_fqid(self, catalog: CatalogName) -> SourcedBundleFQID:
-        # Skip through empty partitions
-        bundle_fqids = itertools.chain.from_iterable(
-            bundle_fqids
-            for _, _, bundle_fqids in self._list_partitions(catalog,
-                                                            min_bundles=1,
-                                                            public_1st=False)
-        )
+        source = self._choose_source(catalog)
+        # The plugin will raise an exception if the source lacks a prefix
+        source = source.with_prefix(Prefix.of_everything)
+        bundle_fqids = self.repository_plugin(catalog).list_bundles(source, '')
         return self.random.choice(sorted(bundle_fqids))

     def _can_bundle(self,