Use Unchanging Identifiers for Worker Processes #308

Merged
4 changes: 2 additions & 2 deletions src/mavedb/routers/score_sets.py
@@ -632,7 +632,7 @@ async def upload_score_set_variant_data(
job = await worker.enqueue_job(
"create_variants_for_score_set",
correlation_id_for_context(),
item.urn,
item.id,
user_data.user.id,
scores_df,
counts_df,
@@ -871,7 +871,7 @@ async def update_score_set(
job = await worker.enqueue_job(
"create_variants_for_score_set",
correlation_id_for_context(),
item.urn,
item.id,
user_data.user.id,
scores_data,
count_data,
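Both endpoints now hand the worker the score set's integer primary key rather than its URN. A score set's URN can change (for example when a temporary record is published), so a job that captured the URN at enqueue time could later fail to find the record under that name; the primary key stays fixed for the lifetime of the row. A minimal sketch of the enqueue side, assuming an arq connection pool and an already-persisted ScoreSet instance (the wrapper function name is illustrative, not part of the PR):

from arq.connections import ArqRedis

async def enqueue_variant_creation(
    worker: ArqRedis, score_set, user_id: int, scores_df, counts_df, correlation_id: str
):
    # Pass the immutable primary key; the worker re-reads the URN from the
    # database whenever it needs one for logging or mapping.
    return await worker.enqueue_job(
        "create_variants_for_score_set",
        correlation_id,
        score_set.id,
        user_id,
        scores_df,
        counts_df,
    )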
91 changes: 46 additions & 45 deletions src/mavedb/worker/jobs.py
@@ -54,7 +54,7 @@ async def mapping_in_execution(redis: ArqRedis, job_id: str):
await redis.set(MAPPING_CURRENT_ID_NAME, "")


def setup_job_state(ctx, invoker: int, resource: str, correlation_id: str):
def setup_job_state(ctx, invoker: int, resource: Optional[str], correlation_id: str):
ctx["state"][ctx["job_id"]] = {
"application": "mavedb-worker",
"user": invoker,
@@ -90,7 +90,7 @@ async def enqueue_job_with_backoff(


async def create_variants_for_score_set(
ctx, correlation_id: str, score_set_urn: str, updater_id: int, scores: pd.DataFrame, counts: pd.DataFrame
ctx, correlation_id: str, score_set_id: int, updater_id: int, scores: pd.DataFrame, counts: pd.DataFrame
):
"""
Create variants for a score set. Intended to be run within a worker.
@@ -99,14 +99,14 @@ async def create_variants_for_score_set(
"""
logging_context = {}
try:
logging_context = setup_job_state(ctx, updater_id, score_set_urn, correlation_id)
logger.info(msg="Began processing of score set variants.", extra=logging_context)

db: Session = ctx["db"]
hdp: RESTDataProvider = ctx["hdp"]
redis: ArqRedis = ctx["redis"]
score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one()

logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id)
logger.info(msg="Began processing of score set variants.", extra=logging_context)

score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == score_set_urn)).one()
updated_by = db.scalars(select(User).where(User.id == updater_id)).one()

score_set.modified_by = updated_by
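The variant-creation job now resolves the ScoreSet row by primary key before the logging context is built, and the URN used in logs is read from that freshly loaded row instead of being taken from the job arguments. A condensed view of the new ordering, assuming the imports already present in jobs.py:

db: Session = ctx["db"]
score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one()

# .one() raises sqlalchemy.exc.NoResultFound if the id no longer exists; the URN
# attached to the logging context always reflects the row as it is right now.
logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id)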
@@ -210,8 +210,8 @@ async def create_variants_for_score_set(
logging_context["processing_state"] = score_set.processing_state.name
logger.info(msg="Finished creating variants in score set.", extra=logging_context)

await redis.lpush(MAPPING_QUEUE_NAME, score_set_urn) # type: ignore
await redis.enqueue_job("variant_mapper_manager", correlation_id, score_set_urn, updater_id)
await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore
await redis.enqueue_job("variant_mapper_manager", correlation_id, updater_id)
score_set.mapping_state = MappingState.queued
finally:
db.add(score_set)
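The Redis mapping queue now carries score set ids rather than URNs, and variant_mapper_manager is enqueued without a score-set argument at all: it discovers its work by popping the queue. The producer side, condensed from the hunk above:

# Push the stable id onto the mapping queue and wake the manager job; the manager
# no longer needs to be told which score set to map.
await redis.lpush(MAPPING_QUEUE_NAME, score_set.id)
await redis.enqueue_job("variant_mapper_manager", correlation_id, updater_id)
score_set.mapping_state = MappingState.queued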
@@ -224,32 +224,35 @@ async def create_variants_for_score_set(


async def map_variants_for_score_set(
ctx: dict, correlation_id: str, score_set_urn: str, updater_id: int, attempt: int = 1
ctx: dict, correlation_id: str, score_set_id: int, updater_id: int, attempt: int = 1
) -> dict:
async with mapping_in_execution(redis=ctx["redis"], job_id=ctx["job_id"]):
logging_context = {}
score_set = None
try:
db: Session = ctx["db"]
redis: ArqRedis = ctx["redis"]
score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == score_set_id)).one()

logging_context = setup_job_state(ctx, updater_id, score_set_urn, correlation_id)
logging_context = setup_job_state(ctx, updater_id, score_set.urn, correlation_id)
logging_context["attempt"] = attempt
logger.info(msg="Started variant mapping", extra=logging_context)

score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == score_set_urn)).one()
score_set.mapping_state = MappingState.processing
score_set.mapping_errors = null()
db.add(score_set)
db.commit()

logging_context["current_mapping_resource"] = score_set.urn
mapping_urn = score_set.urn
assert mapping_urn, "A valid URN is needed to map this score set."

logging_context["current_mapping_resource"] = mapping_urn
logging_context["mapping_state"] = score_set.mapping_state
logger.debug(msg="Fetched score set metadata for mapping job.", extra=logging_context)

# Do not block Worker event loop during mapping, see: https://arq-docs.helpmanual.io/#synchronous-jobs.
vrs = vrs_mapper()
blocking = functools.partial(vrs.map_score_set, score_set_urn)
blocking = functools.partial(vrs.map_score_set, mapping_urn)
loop = asyncio.get_running_loop()

except Exception as e:
@@ -292,13 +295,13 @@ async def map_variants_for_score_set(
new_job_id = None
max_retries_exceeded = None
try:
await redis.lpush(MAPPING_QUEUE_NAME, score_set_urn) # type: ignore
await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore
new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff(
redis, "variant_mapper_manager", attempt, correlation_id, score_set_urn, updater_id
redis, "variant_mapper_manager", attempt, correlation_id, updater_id
)
# If we fail to enqueue a mapping manager for this score set, evict it from the queue.
if new_job_id is None:
await redis.lpop(MAPPING_QUEUE_NAME, score_set_urn) # type: ignore
await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore

logging_context["backoff_limit_exceeded"] = max_retries_exceeded
logging_context["backoff_deferred_in_seconds"] = backoff_time
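When mapping setup fails, the retry path keeps the queue keyed by id as well: the id is pushed back, the manager is re-enqueued with backoff, and if that enqueue fails the entry is taken back off the queue so it does not sit there unserviced. One detail worth flagging: if redis here is the standard redis-py client that arq wraps, lpop's second argument is a count rather than a value to match, so removing one specific entry would more conventionally use lrem. The sketch below shows that variant purely as a suggestion; it is not what the PR does:

await redis.lpush(MAPPING_QUEUE_NAME, score_set.id)
new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff(
    redis, "variant_mapper_manager", attempt, correlation_id, updater_id
)
if new_job_id is None:
    # Remove the single queue entry we just pushed for this score set.
    await redis.lrem(MAPPING_QUEUE_NAME, 1, score_set.id)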
@@ -377,7 +380,7 @@ async def map_variants_for_score_set(
.join(ScoreSet)
.join(TargetSequence)
.where(
ScoreSet.urn == str(score_set_urn),
ScoreSet.id == score_set_id,
# TargetSequence.sequence == target_sequence,
)
).one()
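The target-gene lookup inside the mapping loop makes the same switch: the join filter keys on the score set's primary key, so it keeps working even if the URN changes between enqueue and execution. Condensed, with the result name taken from how it is used later in this hunk:

target_gene = db.scalars(
    select(TargetGene)
    .join(ScoreSet)
    .join(TargetSequence)
    .where(ScoreSet.id == score_set_id)
).one()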
@@ -394,9 +397,7 @@ async def map_variants_for_score_set(
},
JSONB,
)
target_gene.post_mapped_metadata = cast(
{"genomic": mapped_genomic_ref}, JSONB
)
target_gene.post_mapped_metadata = cast({"genomic": mapped_genomic_ref}, JSONB)
elif computed_protein_ref and mapped_protein_ref:
pre_mapped_metadata = computed_protein_ref
target_gene.pre_mapped_metadata = cast(
@@ -408,9 +409,7 @@ async def map_variants_for_score_set(
},
JSONB,
)
target_gene.post_mapped_metadata = cast(
{"protein": mapped_protein_ref}, JSONB
)
target_gene.post_mapped_metadata = cast({"protein": mapped_protein_ref}, JSONB)
else:
raise NonexistentMappingReferenceError()

Expand Down Expand Up @@ -486,13 +485,13 @@ async def map_variants_for_score_set(
new_job_id = None
max_retries_exceeded = None
try:
await redis.lpush(MAPPING_QUEUE_NAME, score_set_urn) # type: ignore
await redis.lpush(MAPPING_QUEUE_NAME, score_set.id) # type: ignore
new_job_id, max_retries_exceeded, backoff_time = await enqueue_job_with_backoff(
redis, "variant_mapper_manager", attempt, correlation_id, score_set_urn, updater_id
redis, "variant_mapper_manager", attempt, correlation_id, updater_id
)
# If we fail to enqueue a mapping manager for this score set, evict it from the queue.
if new_job_id is None:
await redis.lpop(MAPPING_QUEUE_NAME, score_set_urn) # type: ignore
await redis.lpop(MAPPING_QUEUE_NAME, score_set.id) # type: ignore

logging_context["backoff_limit_exceeded"] = max_retries_exceeded
logging_context["backoff_deferred_in_seconds"] = backoff_time
@@ -543,34 +542,35 @@ async def map_variants_for_score_set(
return {"success": True}


async def variant_mapper_manager(
ctx: dict, correlation_id: str, score_set_urn: str, updater_id: int, attempt: int = 1
) -> dict:
async def variant_mapper_manager(ctx: dict, correlation_id: str, updater_id: int, attempt: int = 1) -> dict:
logging_context = {}
mapping_job_id = None
mapping_job_status = None
queued_score_set = None
try:
redis: ArqRedis = ctx["redis"]
db: Session = ctx["db"]

logging_context = setup_job_state(ctx, updater_id, score_set_urn, correlation_id)
logging_context = setup_job_state(ctx, updater_id, None, correlation_id)
logging_context["attempt"] = attempt
logger.debug(msg="Variant mapping manager began execution", extra=logging_context)

queue_length = await redis.llen(MAPPING_QUEUE_NAME) # type: ignore
queued_urn = await redis.rpop(MAPPING_QUEUE_NAME) # type: ignore
queued_id = await redis.rpop(MAPPING_QUEUE_NAME) # type: ignore
logging_context["variant_mapping_queue_length"] = queue_length

# Setup the job id cache if it does not already exist.
if not await redis.exists(MAPPING_CURRENT_ID_NAME):
await redis.set(MAPPING_CURRENT_ID_NAME, "")

if not queued_urn:
if not queued_id:
logger.debug(msg="No mapping jobs exist in the queue.", extra=logging_context)
return {"success": True, "enqueued_job": None}
else:
queued_urn = queued_urn.decode("utf-8")
logging_context["current_mapping_resource"] = queued_urn
queued_id = queued_id.decode("utf-8")
queued_score_set = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_id)).one()

logging_context["upcoming_mapping_resource"] = queued_score_set.urn
logger.debug(msg="Found mapping job(s) still in queue.", extra=logging_context)

mapping_job_id = await redis.get(MAPPING_CURRENT_ID_NAME)
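The manager no longer receives a score set at all, which is why setup_job_state now takes an optional resource: the logging context is created with None, and the dequeued score set's URN is attached once an id has been popped and resolved. The consumer side, condensed (an explicit int() cast is added here for readability; the diff passes the decoded string straight into the filter):

logging_context = setup_job_state(ctx, updater_id, None, correlation_id)

queued_id = await redis.rpop(MAPPING_QUEUE_NAME)
if not queued_id:
    return {"success": True, "enqueued_job": None}

queued_score_set = db.scalars(
    select(ScoreSet).where(ScoreSet.id == int(queued_id.decode("utf-8")))
).one()
logging_context["upcoming_mapping_resource"] = queued_score_set.urn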
@@ -589,15 +589,12 @@ async def variant_mapper_manager(

new_job = None
new_job_id = None
score_set = None
try:
if not mapping_job_id or mapping_job_status in (JobStatus.not_found, JobStatus.complete):
logger.debug(msg="No mapping jobs are running, queuing a new one.", extra=logging_context)

# NOTE: the score_set_urn provided to this function is only used for logging context;
# get the urn from the queue and pass that urn to map_variants_for_score_set
new_job = await redis.enqueue_job(
"map_variants_for_score_set", correlation_id, queued_urn, updater_id, attempt
"map_variants_for_score_set", correlation_id, queued_score_set.id, updater_id, attempt
)

if new_job:
@@ -616,15 +613,14 @@ async def variant_mapper_manager(
new_job = await redis.enqueue_job(
"variant_mapper_manager",
correlation_id,
score_set_urn,
updater_id,
attempt,
_defer_by=timedelta(minutes=5),
)

if new_job:
# Ensure this score set remains in the front of the queue.
queued_urn = await redis.rpush(MAPPING_QUEUE_NAME, score_set_urn) # type: ignore
queued_id = await redis.rpush(MAPPING_QUEUE_NAME, queued_score_set.id) # type: ignore
new_job_id = new_job.job_id

logging_context["new_mapping_manager_job_id"] = new_job_id
@@ -645,11 +641,16 @@ async def variant_mapper_manager(
)

db.rollback()
score_set = db.scalars(select(ScoreSet).where(ScoreSet.urn == score_set_urn)).one_or_none()
if score_set:
score_set.mapping_state = MappingState.failed
score_set.mapping_errors = "Unable to queue a new mapping job or defer score set mapping."
db.add(score_set)

# We shouldn't rely on the passed score set id matching the score set we are operating upon.
if not queued_score_set:
return {"success": False, "enqueued_job": new_job_id}

score_set_exc = db.scalars(select(ScoreSet).where(ScoreSet.id == queued_score_set.id)).one_or_none()
if score_set_exc:
score_set_exc.mapping_state = MappingState.failed
score_set_exc.mapping_errors = "Unable to queue a new mapping job or defer score set mapping."
db.add(score_set_exc)
db.commit()

return {"success": False, "enqueued_job": new_job_id}
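The manager's failure path follows the same principle: rather than trusting an identifier passed in from outside, it re-resolves the score set it actually dequeued, by id, before marking it failed, and bails out early if nothing was ever dequeued. Condensed from the hunk above:

db.rollback()

if not queued_score_set:
    # Nothing was dequeued, so there is no score set to mark as failed.
    return {"success": False, "enqueued_job": new_job_id}

score_set_exc = db.scalars(
    select(ScoreSet).where(ScoreSet.id == queued_score_set.id)
).one_or_none()
if score_set_exc:
    score_set_exc.mapping_state = MappingState.failed
    score_set_exc.mapping_errors = "Unable to queue a new mapping job or defer score set mapping."
    db.add(score_set_exc)
    db.commit()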
22 changes: 14 additions & 8 deletions tests/worker/test_jobs.py
@@ -129,6 +129,7 @@ async def test_create_variants_for_score_set_with_validation_error(
input_score_set, validation_error, setup_worker_db, async_client, standalone_worker_context, session, data_files
):
score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

# This is invalid for both data sets.
scores.loc[:, HGVS_NT_COLUMN].iloc[0] = "c.1T>A"
@@ -141,7 +142,7 @@
) as hdp,
):
success = await create_variants_for_score_set(
standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
)

# Call data provider _get_transcript method if this is an accession based score set, otherwise do not.
@@ -169,12 +170,13 @@ async def test_create_variants_for_score_set_with_caught_exception(
input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
):
score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

# This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee
# some exception will be raised no matter what in the async job.
with (patch.object(pd.DataFrame, "isnull", side_effect=Exception) as mocked_exc,):
success = await create_variants_for_score_set(
standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
)
mocked_exc.assert_called()

@@ -197,12 +199,13 @@ async def test_create_variants_for_score_set_with_caught_base_exception(
input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
):
score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

# This is somewhat (extra) dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee
# some base exception will be handled no matter what in the async job.
with (patch.object(pd.DataFrame, "isnull", side_effect=BaseException),):
success = await create_variants_for_score_set(
standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
)

db_variants = session.scalars(select(Variant)).all()
@@ -224,12 +227,13 @@ async def test_create_variants_for_score_set_with_existing_variants(
input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
):
score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

with patch.object(
cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT
) as hdp:
success = await create_variants_for_score_set(
standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
)

# Call data provider _get_transcript method if this is an accession based score set, otherwise do not.
@@ -249,7 +253,7 @@ async def test_create_variants_for_score_set_with_existing_variants(
cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT
) as hdp:
success = await create_variants_for_score_set(
standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
)

db_variants = session.scalars(select(Variant)).all()
@@ -271,6 +275,7 @@ async def test_create_variants_for_score_set_with_existing_exceptions(
input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
):
score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

# This is somewhat dumb and wouldn't actually happen like this, but it serves as an effective way to guarantee
# some exception will be raised no matter what in the async job.
@@ -280,7 +285,7 @@ async def test_create_variants_for_score_set_with_existing_exceptions(
) as mocked_exc,
):
success = await create_variants_for_score_set(
standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
)
mocked_exc.assert_called()

@@ -296,7 +301,7 @@ async def test_create_variants_for_score_set_with_existing_exceptions(
cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT
) as hdp:
success = await create_variants_for_score_set(
standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
)

# Call data provider _get_transcript method if this is an accession based score set, otherwise do not.
@@ -324,12 +329,13 @@ async def test_create_variants_for_score_set(
input_score_set, setup_worker_db, async_client, standalone_worker_context, session, data_files
):
score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
score_set = session.scalars(select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)).one()

with patch.object(
cdot.hgvs.dataproviders.RESTDataProvider, "_get_transcript", return_value=TEST_CDOT_TRANSCRIPT
) as hdp:
success = await create_variants_for_score_set(
standalone_worker_context, uuid4().hex, score_set_urn, 1, scores, counts
standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
)

# Call data provider _get_transcript method if this is an accession based score set, otherwise do not.
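The test updates are mechanical: the fixtures still create records through the API, which returns a URN, so each test resolves the database row by URN once and hands the job its id. The recurring pattern, condensed from the hunks above:

score_set_urn, scores, counts = await setup_records_and_files(async_client, data_files, input_score_set)
score_set = session.scalars(
    select(ScoreSetDbModel).where(ScoreSetDbModel.urn == score_set_urn)
).one()

success = await create_variants_for_score_set(
    standalone_worker_context, uuid4().hex, score_set.id, 1, scores, counts
)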