diff --git a/src/mavedb/routers/score_sets.py b/src/mavedb/routers/score_sets.py
index 44fd2a30..b5fb0dee 100644
--- a/src/mavedb/routers/score_sets.py
+++ b/src/mavedb/routers/score_sets.py
@@ -632,7 +632,7 @@ async def upload_score_set_variant_data(
     job = await worker.enqueue_job(
         "create_variants_for_score_set",
         correlation_id_for_context(),
-        item.urn,
+        item.id,
         user_data.user.id,
         scores_df,
         counts_df,
@@ -871,7 +871,7 @@ async def update_score_set(
     job = await worker.enqueue_job(
         "create_variants_for_score_set",
         correlation_id_for_context(),
-        item.urn,
+        item.id,
         user_data.user.id,
         scores_data,
         count_data,
diff --git a/src/mavedb/worker/jobs.py b/src/mavedb/worker/jobs.py
index 66a6e82d..39423648 100644
--- a/src/mavedb/worker/jobs.py
+++ b/src/mavedb/worker/jobs.py
@@ -54,7 +54,7 @@ async def mapping_in_execution(redis: ArqRedis, job_id: str):
         await redis.set(MAPPING_CURRENT_ID_NAME, "")


-def setup_job_state(ctx, invoker: int, resource: str, correlation_id: str):
+def setup_job_state(ctx, invoker: int, resource: Optional[str], correlation_id: str):
     ctx["state"][ctx["job_id"]] = {
         "application": "mavedb-worker",
         "user": invoker,
@@ -90,7 +90,7 @@ async def enqueue_job_with_backoff(


 async def create_variants_for_score_set(
-    ctx, correlation_id: str, score_set_urn: str, updater_id: int, scores: pd.DataFrame, counts: pd.DataFrame
+    ctx, correlation_id: str, score_set_id: int, updater_id: int, scores: pd.DataFrame, counts: pd.DataFrame
 ):
     """
     Create variants for a score set. Intended to be run within a worker.
@@ -99,14 +99,14 @@ async def create_variants_for_score_set(
     """
     logging_context = {}
     try:
-        logging_context = setup_job_state(ctx, updater_id, score_set_urn, correlation_id)
-        logger.info(msg="Began processing of score set variants.", extra=logging_context)
-
         db: Session = ctx["db"]
         hdp: RESTDataProvider = ctx["hdp"]
         redis: ArqRedis = ctx["redis"]

+        score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one()
+
+        logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id)
+        logger.info(msg="Began processing of score set variants.", extra=logging_context)
-        score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == score_set_urn)).one()
         updated_by = db.scalars(select(User).where(User.id == updater_id)).one()

         score_set.modified_by = updated_by
@@ -210,8 +210,8 @@ async def create_variants_for_score_set(
         logging_context["processing_state"] = score_set.processing_state.name
         logger.info(msg="Finished creating variants in score set.", extra=logging_context)

-        await redis.lpush(MAPPING_QUEUE_NAME, score_set_urn)  # type: ignore
-        await redis.enqueue_job("variant_mapper_manager", correlation_id, score_set_urn, updater_id)
+        await redis.lpush(MAPPING_QUEUE_NAME, score_set.id)  # type: ignore
+        await redis.enqueue_job("variant_mapper_manager", correlation_id, updater_id)
         score_set.mapping_state = MappingState.queued
     finally:
         db.add(score_set)
@@ -224,7 +224,7 @@ async def create_variants_for_score_set(


 async def map_variants_for_score_set(
-    ctx: dict, correlation_id: str, score_set_urn: str, updater_id: int, attempt: int = 1
+    ctx: dict, correlation_id: str, score_set_id: int, updater_id: int, attempt: int = 1
 ) -> dict:
     async with mapping_in_execution(redis=ctx["redis"], job_id=ctx["job_id"]):
         logging_context = {}
@@ -232,24 +232,27 @@ async def map_variants_for_score_set(
         try:
             db: Session = ctx["db"]
             redis: ArqRedis = ctx["redis"]
+            score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one()

-            logging_context = setup_job_state(ctx, updater_id, score_set_urn, correlation_id)
+            logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id)
             logging_context["attempt"] = attempt
             logger.info(msg="Started variant mapping", extra=logging_context)

-            score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == score_set_urn)).one()
             score_set.mapping_state = MappingState.processing
             score_set.mapping_errors = null()
             db.add(score_set)
             db.commit()

-            logging_context["current_mapping_resource"] = score_set.urn
+            mapping_urn = score_set.urn
+            assert mapping_urn, "A valid URN is needed to map this score set."
+
+            logging_context["current_mapping_resource"] = mapping_urn
             logging_context["mapping_state"] = score_set.mapping_state
             logger.debug(msg="Fetched score set metadata for mapping job.", extra=logging_context)

             # Do not block Worker event loop during mapping, see: https://arq-docs.helpmanual.io/#synchronous-jobs.
             vrs = vrs_mapper()
-            blocking = functools.partial(vrs.map_score_set, score_set_urn)
+            blocking = functools.partial(vrs.map_score_set, mapping_urn)
             loop = asyncio.get_running_loop()

         except Exception as e:
@@ -292,13 +295,13 @@ async def map_variants_for_score_set(
             new_job_id = None
             max_retries_exceeded = None
             try:
-                await redis.lpush(MAPPING_QUEUE_NAME, score_set_urn)  # type: ignore
+                await redis.lpush(MAPPING_QUEUE_NAME, score_set.id)  # type: ignore
                 new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff(
-                    redis, "variant_mapper_manager", attempt, correlation_id, score_set_urn, updater_id
+                    redis, "variant_mapper_manager", attempt, correlation_id, updater_id
                 )
                 # If we fail to enqueue a mapping manager for this score set, evict it from the queue.
                 if new_job_id is None:
-                    await redis.lpop(MAPPING_QUEUE_NAME, score_set_urn)  # type: ignore
+                    await redis.lrem(MAPPING_QUEUE_NAME, 1, score_set.id)  # type: ignore

                 logging_context["backoff_limit_exceeded"] = max_retries_exceeded
                 logging_context["backoff_deferred_in_seconds"] = backoff_time
@@ -377,7 +380,7 @@ async def map_variants_for_score_set(
                                 .join(ScoreSet)
                                 .join(TargetSequence)
                                 .where(
-                                    ScoreSet.urn == str(score_set_urn),
+                                    ScoreSet.id == score_set_id,
                                     # TargetSequence.sequence == target_sequence,
                                 )
                             ).one()
@@ -394,9 +397,7 @@ async def map_variants_for_score_set(
                                     },
                                     JSONB,
                                 )
-                                target_gene.post_mapped_metadata = cast(
-                                    {"genomic": mapped_genomic_ref}, JSONB
-                                )
+                                target_gene.post_mapped_metadata = cast({"genomic": mapped_genomic_ref}, JSONB)
                             elif computed_protein_ref and mapped_protein_ref:
                                 pre_mapped_metadata = computed_protein_ref
                                 target_gene.pre_mapped_metadata = cast(
@@ -408,9 +409,7 @@ async def map_variants_for_score_set(
                                     },
                                     JSONB,
                                 )
-                                target_gene.post_mapped_metadata = cast(
-                                    {"protein": mapped_protein_ref}, JSONB
-                                )
+                                target_gene.post_mapped_metadata = cast({"protein": mapped_protein_ref}, JSONB)
                             else:
                                 raise NonexistentMappingReferenceError()
@@ -486,13 +485,13 @@ async def map_variants_for_score_set(
             new_job_id = None
             max_retries_exceeded = None
             try:
-                await redis.lpush(MAPPING_QUEUE_NAME, score_set_urn)  # type: ignore
+                await redis.lpush(MAPPING_QUEUE_NAME, score_set.id)  # type: ignore
                 new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff(
-                    redis, "variant_mapper_manager", attempt, correlation_id, score_set_urn, updater_id
+                    redis, "variant_mapper_manager", attempt, correlation_id, updater_id
                 )
                 # If we fail to enqueue a mapping manager for this score set, evict it from the queue.
                 if new_job_id is None:
-                    await redis.lpop(MAPPING_QUEUE_NAME, score_set_urn)  # type: ignore
+                    await redis.lrem(MAPPING_QUEUE_NAME, 1, score_set.id)  # type: ignore

                 logging_context["backoff_limit_exceeded"] = max_retries_exceeded
                 logging_context["backoff_deferred_in_seconds"] = backoff_time
@@ -543,34 +542,35 @@ async def map_variants_for_score_set(
     return {"success": True}


-async def variant_mapper_manager(
-    ctx: dict, correlation_id: str, score_set_urn: str, updater_id: int, attempt: int = 1
-) -> dict:
+async def variant_mapper_manager(ctx: dict, correlation_id: str, updater_id: int, attempt: int = 1) -> dict:
     logging_context = {}
     mapping_job_id = None
     mapping_job_status = None
+    queued_score_set = None
     try:
         redis: ArqRedis = ctx["redis"]
         db: Session = ctx["db"]

-        logging_context = setup_job_state(ctx, updater_id, score_set_urn, correlation_id)
+        logging_context = setup_job_state(ctx, updater_id, None, correlation_id)
         logging_context["attempt"] = attempt
         logger.debug(msg="Variant mapping manager began execution", extra=logging_context)

         queue_length = await redis.llen(MAPPING_QUEUE_NAME)  # type: ignore
-        queued_urn = await redis.rpop(MAPPING_QUEUE_NAME)  # type: ignore
+        queued_id = await redis.rpop(MAPPING_QUEUE_NAME)  # type: ignore
         logging_context["variant_mapping_queue_length"] = queue_length

         # Setup the job id cache if it does not already exist.
         if not await redis.exists(MAPPING_CURRENT_ID_NAME):
             await redis.set(MAPPING_CURRENT_ID_NAME, "")

-        if not queued_urn:
+        if not queued_id:
             logger.debug(msg="No mapping jobs exist in the queue.", extra=logging_context)
             return {"success": True, "enqueued_job": None}
         else:
-            queued_urn = queued_urn.decode("utf-8")
-            logging_context["current_mapping_resource"] = queued_urn
+            queued_id = int(queued_id.decode("utf-8"))
+            queued_score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_id)).one()
+
+            logging_context["upcoming_mapping_resource"] = queued_score_set.urn
             logger.debug(msg="Found mapping job(s) still in queue.", extra=logging_context)

         mapping_job_id = await redis.get(MAPPING_CURRENT_ID_NAME)
@@ -589,15 +589,12 @@ async def variant_mapper_manager(

     new_job = None
     new_job_id = None
-    score_set = None
     try:
         if not mapping_job_id or mapping_job_status in (JobStatus.not_found, JobStatus.complete):
             logger.debug(msg="No mapping jobs are running, queuing a new one.", extra=logging_context)

-            # NOTE: the score_set_urn provided to this function is only used for logging context;
-            # get the urn from the queue and pass that urn to map_variants_for_score_set
             new_job = await redis.enqueue_job(
-                "map_variants_for_score_set", correlation_id, queued_urn, updater_id, attempt
+                "map_variants_for_score_set", correlation_id, queued_score_set.id, updater_id, attempt
             )

             if new_job:
@@ -616,7 +613,6 @@ async def variant_mapper_manager(
             new_job = await redis.enqueue_job(
                 "variant_mapper_manager",
                 correlation_id,
-                score_set_urn,
                 updater_id,
                 attempt,
                 _defer_by=timedelta(minutes=5),
@@ -624,7 +620,7 @@ async def variant_mapper_manager(
             )

             if new_job:
                 # Ensure this score set remains in the front of the queue.
-                queued_urn = await redis.rpush(MAPPING_QUEUE_NAME, score_set_urn)  # type: ignore
+                await redis.rpush(MAPPING_QUEUE_NAME, queued_score_set.id)  # type: ignore

                 new_job_id = new_job.job_id
                 logging_context["new_mapping_manager_job_id"] = new_job_id
@@ -645,11 +641,16 @@ async def variant_mapper_manager(
         )
         db.rollback()

-        score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == score_set_urn)).one_or_none()
-        if score_set:
-            score_set.mapping_state = MappingState.failed
-            score_set.mapping_errors = "Unable to queue a new mapping job or defer score set mapping."
-            db.add(score_set)
+
+        # This job no longer receives a score set id; only a score set we actually dequeued can be marked failed.
+        if not queued_score_set:
+            return {"success": False, "enqueued_job": new_job_id}
+
+        score_set_exc = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_score_set.id)).one_or_none()
+        if score_set_exc:
+            score_set_exc.mapping_state = MappingState.failed
+            score_set_exc.mapping_errors = "Unable to queue a new mapping job or defer score set mapping."
+            db.add(score_set_exc)
         db.commit()

         return {"success": False, "enqueued_job": new_job_id}
diff --git a/tests/worker/test_jobs.py b/tests/worker/test_jobs.py
index d4b6a1e6..15f684b2 100644
--- a/tests/worker/test_jobs.py
+++ b/tests/worker/test_jobs.py
@@ -129,6 +129,7 @@ async def test_create_variants_for_score_set_with_validation_error(
     input_score_set, validation_error, setup_worker_db, async_client, standalone_worker_context, session, data_files
 ):
     score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
+    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

     # This is invalid for both data sets.
     scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "c.1T>A"
@@ -141,7 +142,7 @@
         ) as hdp,
     ):
         success = await create_variants_for_score_set(
-            standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
+            standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
         )

     # Call data provider _get_transcript method if this is an accession based score set, otherwise do not.
@@ -169,12 +170,13 @@ async def test_create_variants_for_score_set_with_caught_exception(
     input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
 ):
     score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
+    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

     # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee
     # some exception will be raised no matter what in the async job.
     with (patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc,):
         success = await create_variants_for_score_set(
-            standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
+            standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
         )

     mocked_exc.assert_called()
@@ -197,12 +199,13 @@ async def test_create_variants_for_score_set_with_caught_base_exception(
     input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
 ):
     score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
+    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

     # This is somewhat (extra) dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee
     # some base exception will be handled no matter what in the async job.
     with (patch.object(pd.DataFrame, "isnull", side_effect=BaseException),):
         success = await create_variants_for_score_set(
-            standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
+            standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
         )

     db_variants = session.scalars(select(Variant)).all()
@@ -224,12 +227,13 @@ async def test_create_variants_for_score_set_with_existing_variants(
     input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
 ):
     score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
+    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

     with patch.object(
         cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT
     ) as hdp:
         success = await create_variants_for_score_set(
-            standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
+            standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
         )

     # Call data provider _get_transcript method if this is an accession based score set, otherwise do not.
@@ -249,7 +253,7 @@
         cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT
     ) as hdp:
         success = await create_variants_for_score_set(
-            standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
+            standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
         )

     db_variants = session.scalars(select(Variant)).all()
@@ -271,6 +275,7 @@ async def test_create_variants_for_score_set_with_existing_exceptions(
     input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
 ):
     score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
+    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

     # This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee
     # some exception will be raised no matter what in the async job.
@@ -280,7 +285,7 @@ async def test_create_variants_for_score_set_with_existing_exceptions(
         ) as mocked_exc,
     ):
         success = await create_variants_for_score_set(
-            standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
+            standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
         )

     mocked_exc.assert_called()
@@ -296,7 +301,7 @@
         cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT
     ) as hdp:
         success = await create_variants_for_score_set(
-            standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
+            standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
         )

     # Call data provider _get_transcript method if this is an accession based score set, otherwise do not.
@@ -324,12 +329,13 @@ async def test_create_variants_for_score_set(
     input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
 ):
     score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
+    score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

     with patch.object(
         cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT
     ) as hdp:
         success = await create_variants_for_score_set(
-            standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
+            standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
        )

     # Call data provider _get_transcript method if this is an accession based score set, otherwise do not.
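
The core change in this patch is that the Redis mapping queue now carries integer primary keys rather than URNs. Below is a minimal sketch of that handoff, assuming redis.asyncio and SQLAlchemy 2.x; the ScoreSet class here is a stand-in declarative model, the helper names are hypothetical, and the queue name is illustrative rather than mavedb's actual constant value.

# Sketch of the ID-based queue handoff adopted above (assumptions noted in the lead-in).
from typing import Optional

from redis.asyncio import Redis
from sqlalchemy import Integer, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

MAPPING_QUEUE_NAME = "vrs_mapping_queue"  # illustrative; mirrors the constant used in the worker module


class Base(DeclarativeBase):
    pass


class ScoreSet(Base):  # stand-in for mavedb's ScoreSet model
    __tablename__ = "scoresets"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    urn: Mapped[Optional[str]] = mapped_column(String, nullable=True)  # URNs may be unset


async def enqueue_for_mapping(redis: Redis, score_set_id: int) -> None:
    # Producers push the stable integer primary key; Redis stores it as bytes.
    await redis.lpush(MAPPING_QUEUE_NAME, score_set_id)


async def dequeue_for_mapping(redis: Redis, db: Session) -> Optional[ScoreSet]:
    # Consumers pop from the opposite end (FIFO) and resolve the row by id.
    raw = await redis.rpop(MAPPING_QUEUE_NAME)
    if raw is None:
        return None
    queued_id = int(raw.decode("utf-8"))  # bytes -> int before comparing to ScoreSet.id
    return db.scalars(select(ScoreSet).where(ScoreSet.id == queued_id)).one_or_none()

Keying the queue on the primary key sidesteps the nullable-URN problem that the new assert in map_variants_for_score_set guards against: a score set always has an id, while its URN may be absent or change before publication.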
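The body of enqueue_job_with_backoff is not part of this patch; its call sites unpack a (job id, retries-exhausted flag, deferral) triple. A sketch of the pattern those call sites imply, using arq's real _defer_by option, might look like the following, where BACKOFF_LIMIT and the doubling schedule are illustrative values, not mavedb's actual ones.

# Sketch of an exponential enqueue-with-backoff helper in the spirit of the
# call sites above (assumptions noted in the lead-in).
from datetime import timedelta
from typing import Any, Optional, Tuple

from arq.connections import ArqRedis

BACKOFF_LIMIT = 5  # illustrative retry ceiling


async def enqueue_job_with_backoff(
    redis: ArqRedis, job_name: str, attempt: int, *args: Any
) -> Tuple[Optional[str], bool, Optional[int]]:
    if attempt > BACKOFF_LIMIT:
        # Signal that retries are exhausted; the caller decides what to evict.
        return None, True, None

    backoff = 2**attempt  # seconds; doubles on every attempt
    job = await redis.enqueue_job(
        job_name, *args, attempt + 1, _defer_by=timedelta(seconds=backoff)
    )
    return (job.job_id if job else None), False, backoff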