diff --git a/collect_coordinator/worker_queue.py b/collect_coordinator/worker_queue.py index 58d3073..acebd05 100644 --- a/collect_coordinator/worker_queue.py +++ b/collect_coordinator/worker_queue.py @@ -184,6 +184,7 @@ def handle_azure_subscription() -> None: "accounts": { "default": { "subscriptions": [az_subscription_id], + "collect_microsoft_graph": account.get("collect_microsoft_graph", False), "client_secret": { "tenant_id": az_tenant_id, "client_id": az_client_id, diff --git a/tests/worker_queue_test.py b/tests/worker_queue_test.py index 8239a59..e6f8d87 100644 --- a/tests/worker_queue_test.py +++ b/tests/worker_queue_test.py @@ -37,7 +37,7 @@ @fixture -def example_collect_definition() -> Json: +def example_aws_collect_definition() -> Json: return { "job_id": "uid", "tenant_id": "a", @@ -56,6 +56,27 @@ def example_collect_definition() -> Json: } +@fixture +def example_azure_collect_definition() -> Json: + return { + "job_id": "uid", + "tenant_id": "a", + "graphdb_server": "b", + "graphdb_database": "c", + "graphdb_username": "d", + "graphdb_password": "e", + "account": { + "kind": "azure_subscription_information", + "azure_subscription_id": "123", + "tenant_id": "123", + "client_id": "234", + "client_secret": "bombproof", + "collect_microsoft_graph": False, + }, + "env": {"test": "test"}, + } + + @fixture def example_post_collect_definition() -> Json: return { @@ -75,8 +96,20 @@ def example_post_collect_definition() -> Json: @pytest.mark.skipif(os.environ.get("REDIS_RUNNING", "false") != "true", reason="Redis not running") -def test_read_job_definition(worker_queue: WorkerQueue, example_collect_definition: Json) -> None: - job_def = worker_queue.parse_collect_definition_json(example_collect_definition) +def test_read_azure_job_definition(worker_queue: WorkerQueue, example_azure_collect_definition: Json) -> None: + job_def = worker_queue.parse_collect_definition_json(example_azure_collect_definition) + assert ( + job_def.env + and job_def.env["WORKER_CONFIG"] + == '{"azure": {"accounts": {"default": {"subscriptions": ["123"], "collect_microsoft_graph": false, ' + '"client_secret": {"tenant_id": "123", "client_id": "234", "client_secret": "bombproof"}}}}, ' + '"fixworker": {"collector": ["azure"]}}' + ) + + +@pytest.mark.skipif(os.environ.get("REDIS_RUNNING", "false") != "true", reason="Redis not running") +def test_read_aws_job_definition(worker_queue: WorkerQueue, example_aws_collect_definition: Json) -> None: + job_def = worker_queue.parse_collect_definition_json(example_aws_collect_definition) assert job_def.name.startswith("collect") assert job_def.image == "someengineering/fix-collect-single:0.0.1" # fmt: off @@ -149,12 +182,15 @@ def test_read_post_job_definition(worker_queue: WorkerQueue, example_post_collec @pytest.mark.asyncio @pytest.mark.skipif(os.environ.get("REDIS_RUNNING", "false") != "true", reason="Redis not running") async def test_enqueue_jobs( - arq_redis: ArqRedis, worker_queue: WorkerQueue, coordinator: LazyJobCoordinator, example_collect_definition: Json + arq_redis: ArqRedis, + worker_queue: WorkerQueue, + coordinator: LazyJobCoordinator, + example_aws_collect_definition: Json, ) -> None: async with worker_queue: - await arq_redis.enqueue_job("collect", example_collect_definition) - await arq_redis.enqueue_job("collect", example_collect_definition) - await arq_redis.enqueue_job("collect", example_collect_definition) + await arq_redis.enqueue_job("collect", example_aws_collect_definition) + await arq_redis.enqueue_job("collect", example_aws_collect_definition) + await arq_redis.enqueue_job("collect", example_aws_collect_definition) ping = await arq_redis.enqueue_job("ping") assert ping is not None