diff --git a/alchemiscale/compute/api.py b/alchemiscale/compute/api.py
index db21d5b8..f3bff55c 100644
--- a/alchemiscale/compute/api.py
+++ b/alchemiscale/compute/api.py
@@ -12,6 +12,7 @@
 from fastapi import FastAPI, APIRouter, Body, Depends
 from fastapi.middleware.gzip import GZipMiddleware
 from gufe.tokenization import GufeTokenizable, JSON_HANDLER
+from gufe.protocols import ProtocolDAGResult
 
 from ..base.api import (
     QueryGUFEHandler,
@@ -248,7 +249,7 @@ def set_task_result(
     validate_scopes(task_sk.scope, token)
 
     pdr = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder)
-    pdr = GufeTokenizable.from_dict(pdr)
+    pdr: ProtocolDAGResult = GufeTokenizable.from_dict(pdr)
 
     tf_sk, _ = n4js.get_task_transformation(
         task=task_scoped_key,
@@ -270,7 +271,11 @@ def set_task_result(
     if protocoldagresultref.ok:
         n4js.set_task_complete(tasks=[task_sk])
     else:
+        n4js.add_protocol_dag_result_ref_tracebacks(
+            pdr.protocol_unit_failures, result_sk
+        )
         n4js.set_task_error(tasks=[task_sk])
+        n4js.resolve_task_restarts(tasks=[task_sk])
 
     return result_sk
 
diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py
index 3dc69e0d..1d8e1679 100644
--- a/alchemiscale/storage/models.py
+++ b/alchemiscale/storage/models.py
@@ -202,9 +202,11 @@ def __eq__(self, other):
         return self.pattern == other.pattern
 
 
-class Traceback(GufeTokenizable):
+class Tracebacks(GufeTokenizable):
 
-    def __init__(self, tracebacks: List[str]):
+    def __init__(
+        self, tracebacks: List[str], source_keys: List[str], failure_keys: List[str]
+    ):
         value_error = ValueError(
             "`tracebacks` must be a non-empty list of string values"
         )
@@ -216,21 +218,25 @@ def __init__(self, tracebacks: List[str]):
         if not all_string_values or "" in tracebacks:
             raise value_error
 
+        # TODO: validate
         self.tracebacks = tracebacks
-
-    def _gufe_tokenize(self):
-        return hashlib.md5(str(self.tracebacks).encode()).hexdigest()
+        self.source_keys = source_keys
+        self.failure_keys = failure_keys
 
     @classmethod
     def _defaults(cls):
-        raise NotImplementedError
+        return super()._defaults()
 
     @classmethod
     def _from_dict(cls, dct):
-        return Traceback(**dct)
+        return cls(**dct)
 
     def _to_dict(self):
-        return {"tracebacks": self.tracebacks}
+        return {
+            "tracebacks": self.tracebacks,
+            "source_keys": self.source_keys,
+            "failure_keys": self.failure_keys,
+        }
 
 
 class TaskHub(GufeTokenizable):
diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py
index 07d05d02..03c8f6d5 100644
--- a/alchemiscale/storage/statestore.py
+++ b/alchemiscale/storage/statestore.py
@@ -8,14 +8,18 @@
 from datetime import datetime
 from contextlib import contextmanager
 import json
-from functools import lru_cache
+import re
+from functools import lru_cache, update_wrapper
 from typing import Dict, List, Optional, Union, Tuple, Set
+from collections import defaultdict
+from collections.abc import Iterable
 import weakref
 
 import numpy as np
 import networkx as nx
 from gufe import AlchemicalNetwork, Transformation, NonTransformation, Settings
 from gufe.tokenization import GufeTokenizable, GufeKey, JSON_HANDLER
+from gufe.protocols import ProtocolUnitFailure
 
 from neo4j import Transaction, GraphDatabase, Driver
@@ -29,6 +33,7 @@
     TaskHub,
     TaskRestartPattern,
     TaskStatusEnum,
+    Tracebacks,
 )
 from ..strategies import Strategy
 from ..models import Scope, ScopedKey
@@ -173,6 +178,19 @@ def transaction(self, ignore_exceptions=False) -> Transaction:
             else:
                 tx.commit()
 
+    def chainable(func):
+        def inner(self, *args, **kwargs):
+            if kwargs.get("tx") is not None:
+                return func(self, *args, **kwargs)
+
+            with self.transaction() as tx:
+                kwargs.update(tx=tx)
+                return func(self, *args, **kwargs)
+
+        update_wrapper(inner, func)
+
+        return inner
+
     def execute_query(self, *args, **kwargs):
         kwargs.update({"database_": self.db_name})
         return self.graph.execute_query(*args, **kwargs)
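The `chainable` decorator above is what lets a storage method either open its own transaction or join one handed in by its caller. The following standalone sketch illustrates that idea; `FakeStore` and `FakeTransaction` are illustrative stand-ins, not alchemiscale APIs.

from contextlib import contextmanager
from functools import update_wrapper


def chainable(func):
    # reuse a caller-supplied transaction when present, otherwise open one
    def inner(self, *args, **kwargs):
        if kwargs.get("tx") is not None:
            return func(self, *args, **kwargs)
        with self.transaction() as tx:
            kwargs.update(tx=tx)
            return func(self, *args, **kwargs)

    update_wrapper(inner, func)
    return inner


class FakeTransaction:
    # hypothetical stand-in for a neo4j Transaction
    def __init__(self, store):
        self.store = store

    def run(self, statement):
        self.store.statements.append(statement)


class FakeStore:
    # hypothetical stand-in for Neo4jStore
    def __init__(self):
        self.statements = []
        self.transactions_opened = 0

    @contextmanager
    def transaction(self):
        self.transactions_opened += 1
        yield FakeTransaction(self)

    @chainable
    def outer(self, *, tx=None):
        tx.run("outer")
        self.inner(tx=tx)  # chained call joins the same transaction

    @chainable
    def inner(self, *, tx=None):
        tx.run("inner")


store = FakeStore()
store.outer()
assert store.transactions_opened == 1  # outer and inner shared one transaction
store.inner()
assert store.transactions_opened == 2  # a standalone call opens its own

This mirrors how `cancel_tasks` and `resolve_task_restarts` below can run inside the transaction of whichever method invoked them, keeping restart resolution and cancellation within a single transaction.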
@@ -1588,10 +1606,12 @@ def get_task_weights(
 
         return weights
 
+    @chainable
     def cancel_tasks(
         self,
         tasks: List[ScopedKey],
         taskhub: ScopedKey,
+        tx=None,
     ) -> List[Union[ScopedKey, None]]:
         """Remove Tasks from the TaskHub for a given AlchemicalNetwork.
 
@@ -1602,31 +1622,30 @@
 
         """
         canceled_sks = []
-        with self.transaction() as tx:
-            for task in tasks:
-                query = """
-                // get our task hub, as well as the task :ACTIONS relationship we want to remove
-                MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: $task_scoped_key})
-                DELETE ar
-
-                WITH task
-                CALL {
-                    WITH task
-                    MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern)
-                    DELETE applies
-                }
-
-                RETURN task
-                """
-                _task = tx.run(
-                    query, taskhub_scoped_key=str(taskhub), task_scoped_key=str(task)
-                ).to_eager_result()
+        for task in tasks:
+            query = """
+            // get our task hub, as well as the task :ACTIONS relationship we want to remove
+            MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: $task_scoped_key})
+            DELETE ar
 
-                if _task.records:
-                    sk = _task.records[0].data()["task"]["_scoped_key"]
-                    canceled_sks.append(ScopedKey.from_str(sk))
-                else:
-                    canceled_sks.append(None)
+            WITH task, th
+            CALL {
+                WITH task, th
+                MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern)-[:ENFORCES]->(th)
+                DELETE applies
+            }
+
+            RETURN task
+            """
+            _task = tx.run(
+                query, taskhub_scoped_key=str(taskhub), task_scoped_key=str(task)
+            ).to_eager_result()
+
+            if _task.records:
+                sk = _task.records[0].data()["task"]["_scoped_key"]
+                canceled_sks.append(ScopedKey.from_str(sk))
+            else:
+                canceled_sks.append(None)
 
         return canceled_sks
 
@@ -2416,6 +2435,59 @@ def get_task_failures(self, task: ScopedKey) -> List[ProtocolDAGResultRef]:
         """
         return self._get_protocoldagresultrefs(q, task)
 
+    def add_protocol_dag_result_ref_tracebacks(
+        self,
+        protocol_unit_failures: List[ProtocolUnitFailure],
+        protocol_dag_result_ref_scoped_key: ScopedKey,
+    ):
+        subgraph = Subgraph()
+
+        with self.transaction() as tx:
+
+            query = """
+            MATCH (pdrr:ProtocolDAGResultRef {`_scoped_key`: $protocol_dag_result_ref_scoped_key})
+            RETURN pdrr
+            """
+
+            pdrr_result = tx.run(
+                query,
+                protocol_dag_result_ref_scoped_key=str(
+                    protocol_dag_result_ref_scoped_key
+                ),
+            ).to_eager_result()
+
+            try:
+                protocol_dag_result_ref_node = record_data_to_node(
+                    pdrr_result.records[0]["pdrr"]
+                )
+            except IndexError:
+                raise KeyError("Could not find ProtocolDAGResultRef in database.")
+
+            failure_keys = []
+            source_keys = []
+            tracebacks = []
+
+            for puf in protocol_unit_failures:
+                failure_keys.append(puf.key)
+                source_keys.append(puf.source_key)
+                tracebacks.append(puf.traceback)
+
+            traceback = Tracebacks(tracebacks, source_keys, failure_keys)
+
+            _, traceback_node, _ = self._gufe_to_subgraph(
+                traceback.to_shallow_dict(),
+                labels=["GufeTokenizable", traceback.__class__.__name__],
+                gufe_key=traceback.key,
+                scope=protocol_dag_result_ref_scoped_key.scope,
+            )
+
+            subgraph |= Relationship.type("DETAILS")(
+                traceback_node,
+                protocol_dag_result_ref_node,
+            )
+
+            merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key")
+
     def set_task_status(
         self, tasks: List[ScopedKey], status: TaskStatusEnum, raise_error: bool = False
     ) -> List[Optional[ScopedKey]]:
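Before the DETAILS relationship is merged, `add_protocol_dag_result_ref_tracebacks` folds the individual unit failures into a single `Tracebacks` object. A small sketch of that folding step is below; it uses the same `ProtocolUnitFailure` constructor arguments as the test utilities later in this diff, and the keys and traceback strings are placeholders.

from gufe.protocols import ProtocolUnitFailure

from alchemiscale.storage.models import Tracebacks

# placeholder failures, built the same way as in the test utilities
protocol_unit_failures = [
    ProtocolUnitFailure(
        source_key=f"FakeProtocolUnitKey-{i}",
        inputs={},
        outputs={},
        exception=RuntimeError,
        traceback=f"Traceback ...: RuntimeError: failure {i}",
    )
    for i in range(2)
]

# same folding as the loop in add_protocol_dag_result_ref_tracebacks
failure_keys = [puf.key for puf in protocol_unit_failures]
source_keys = [puf.source_key for puf in protocol_unit_failures]
tracebacks = [puf.traceback for puf in protocol_unit_failures]

tb = Tracebacks(tracebacks, source_keys, failure_keys)
assert tb.to_dict()["tracebacks"] == tracebacks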
@@ -2778,15 +2850,17 @@ def add_task_restart_patterns(
             RETURN task
             """
 
+            actioned_task_records = (
+                tx.run(actioned_tasks_query, taskhub_scoped_key=str(taskhub))
+                .to_eager_result()
+                .records
+            )
+
             subgraph = Subgraph()
 
             actioned_task_nodes = []
 
-            for actioned_tasks_record in (
-                tx.run(actioned_tasks_query, taskhub_scoped_key=str(taskhub))
-                .to_eager_result()
-                .records
-            ):
+            for actioned_tasks_record in actioned_task_records:
                 actioned_task_nodes.append(
                     record_data_to_node(actioned_tasks_record["task"])
                 )
@@ -2821,6 +2895,15 @@ def add_task_restart_patterns(
                 )
 
             merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key")
 
+            actioned_task_scoped_keys: List[ScopedKey] = []
+
+            for actioned_task_record in actioned_task_records:
+                actioned_task_scoped_keys.append(
+                    ScopedKey.from_str(actioned_task_record["task"]["_scoped_key"])
+                )
+
+            self.resolve_task_restarts(actioned_task_scoped_keys, tx=tx)
+
     # TODO: fill in docstring
     def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]):
         q = """
@@ -2854,6 +2937,7 @@ def set_task_restart_patterns_max_retries(
         )
 
     # TODO: fill in docstring
+    # TODO: validation of taskhubs variable, will fail in weird ways if not enforced
    def get_task_restart_patterns(
         self, taskhubs: List[ScopedKey]
     ) -> Dict[ScopedKey, Set[Tuple[str, int]]]:
@@ -2868,7 +2952,9 @@ def get_task_restart_patterns(
             q, taskhub_scoped_keys=list(map(str, taskhubs))
         ).records
 
-        data = {taskhub: set() for taskhub in taskhubs}
+        data: dict[ScopedKey, set[tuple[str, int]]] = {
+            taskhub: set() for taskhub in taskhubs
+        }
 
         for record in records:
             pattern = record["trp"]["pattern"]
@@ -2878,6 +2964,117 @@ def get_task_restart_patterns(
 
         return data
 
+    # TODO: docstrings
+    @chainable
+    def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=None):
+
+        # Given the scoped keys of a list of Tasks, find all tasks that have an
+        # error status and have a TaskRestartPattern applied. A subquery is executed
+        # to optionally get the latest traceback associated with the task
+        query = """
+        UNWIND $task_scoped_keys AS task_scoped_key
+        MATCH (task:Task {status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern)-[:ENFORCES]->(taskhub:TaskHub)
+        CALL {
+            WITH task
+            OPTIONAL MATCH (task:Task)-[:RESULTS_IN]->(pdrr:ProtocolDAGResultRef)<-[:DETAILS]-(tracebacks:Tracebacks)
+            RETURN tracebacks
+            ORDER BY pdrr.datetime_created DESCENDING
+            LIMIT 1
+        }
+        WITH task, tracebacks, trp, app, taskhub
+        RETURN task, tracebacks, trp, app, taskhub
+        """
+
+        results = tx.run(
+            query,
+            task_scoped_keys=list(map(str, task_scoped_keys)),
+            error=TaskStatusEnum.error.value,
+        ).to_eager_result()
+
+        if not results:
+            return
+
+        # iterate over all of the results to determine if an applied pattern needs
+        # to be iterated or if the task needs to be cancelled outright
+
+        # Keep track of which task/taskhub pairs would need to be canceled
+        # None => the pair never had a matching restart pattern
+        # True => at least one pattern's max_retries was exceeded
+        # False => at least one regex matched, but no pattern's max_retries was exceeded
+        cancel_map: dict[Tuple[str, str], Optional[bool]] = {}
+        to_increment: List[Tuple[str, str]] = []
+        all_task_taskhub_pairs: set[Tuple[str, str]] = set()
+        for record in results.records:
+            task_restart_pattern = record["trp"]
+            applies_relationship = record["app"]
+            task = record["task"]
+            taskhub = record["taskhub"]
+            _tracebacks = record["tracebacks"]
+
+            task_taskhub_tuple = (task["_scoped_key"], taskhub["_scoped_key"])
+
+            all_task_taskhub_pairs.add(task_taskhub_tuple)
+
+            # TODO: remove in v1.0.0
+            # tasks that errored prior to the introduction of task restart policies
+            # will have no tracebacks in the database
+            if _tracebacks is None:
+                cancel_map[task_taskhub_tuple] = True
+
+            # we have already determined that the task is to be canceled;
+            # this is only ever truthy when we say a task needs to be canceled.
+            if cancel_map.get(task_taskhub_tuple):
+                continue
+
+            num_retries = applies_relationship["num_retries"]
+            max_retries = task_restart_pattern["max_retries"]
+            pattern = task_restart_pattern["pattern"]
+            tracebacks: List[str] = _tracebacks["tracebacks"]
+
+            compiled_pattern = re.compile(pattern)
+
+            if any([compiled_pattern.search(message) for message in tracebacks]):
+                if num_retries + 1 > max_retries:
+                    cancel_map[task_taskhub_tuple] = True
+                else:
+                    to_increment.append(
+                        (task["_scoped_key"], task_restart_pattern["_scoped_key"])
+                    )
+                    cancel_map[task_taskhub_tuple] = False
+
+        increment_query = """
+        UNWIND $trp_and_task_pairs as pairs
+        WITH pairs[0] as task_scoped_key, pairs[1] as task_restart_pattern_scoped_key
+        MATCH (:Task {`_scoped_key`: task_scoped_key})<-[app:APPLIES]-(:TaskRestartPattern {`_scoped_key`: task_restart_pattern_scoped_key})
+        SET app.num_retries = app.num_retries + 1
+        """
+
+        tx.run(increment_query, trp_and_task_pairs=to_increment)
+
+        # cancel all tasks that didn't trigger any restart patterns (None)
+        # or exceeded a pattern's max_retries value (True)
+        cancel_groups: defaultdict[str, list[str]] = defaultdict(list)
+        for task_taskhub_pair in all_task_taskhub_pairs:
+            cancel_result = cancel_map.get(task_taskhub_pair)
+            if cancel_result is True or cancel_result is None:
+                cancel_groups[task_taskhub_pair[1]].append(task_taskhub_pair[0])
+
+        for taskhub, tasks in cancel_groups.items():
+            self.cancel_tasks(tasks, taskhub, tx=tx)
+
+        # any remaining tasks must then be okay to switch to waiting
+        renew_waiting_status_query = """
+        UNWIND $task_scoped_keys AS task_scoped_key
+        MATCH (task:Task {status: $error, `_scoped_key`: task_scoped_key})<-[app:APPLIES]-(trp:TaskRestartPattern)-[:ENFORCES]->(taskhub:TaskHub)
+        SET task.status = $waiting
+        """
+
+        tx.run(
+            renew_waiting_status_query,
+            task_scoped_keys=list(map(str, task_scoped_keys)),
+            waiting=TaskStatusEnum.waiting.value,
+            error=TaskStatusEnum.error.value,
+        )
+
     ## authentication
 
     def create_credentialed_entity(self, entity: CredentialedEntity):
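The per-(Task, TaskHub) decision encoded in `cancel_map` above boils down to: a traceback matching a pattern that still has retries left sends the Task back to waiting and bumps the APPLIES counter; a match that would exceed `max_retries`, or no matching pattern at all, cancels the Task on that TaskHub. A small pure-Python sketch of that rule; `restart_decision` is a hypothetical helper for illustration and the values are made up.

import re
from typing import Optional


def restart_decision(
    tracebacks: list[str], pattern: str, num_retries: int, max_retries: int
) -> Optional[bool]:
    """Hypothetical helper: True to cancel, False to restart, None if no match."""
    if not any(re.search(pattern, message) for message in tracebacks):
        return None
    return True if num_retries + 1 > max_retries else False


tracebacks = ["Error message 0, round 1", "Error message 1, round 1"]

# matches and still has retries left -> restart (False)
assert restart_decision(tracebacks, r"Error message \d, round \d", 0, 2) is False
# matches but the next retry would exceed max_retries -> cancel (True)
assert restart_decision(tracebacks, r"Error message \d, round \d", 2, 2) is True
# never matches -> this pattern offers no protection (None)
assert restart_decision(tracebacks, r"OutOfMemory", 0, 2) is None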
diff --git a/alchemiscale/tests/integration/conftest.py b/alchemiscale/tests/integration/conftest.py
index 1875981e..ed9e2b31 100644
--- a/alchemiscale/tests/integration/conftest.py
+++ b/alchemiscale/tests/integration/conftest.py
@@ -177,6 +177,49 @@ def n4js_fresh(graph):
     return n4js
 
 
+@fixture
+def n4js_task_restart_policy(
+    n4js_fresh: Neo4jStore, network_tyk2: AlchemicalNetwork, scope_test
+):
+
+    n4js = n4js_fresh
+
+    _, taskhub_scoped_key_with_policy, _ = n4js.assemble_network(
+        network_tyk2, scope_test
+    )
+
+    _, taskhub_scoped_key_no_policy, _ = n4js.assemble_network(
+        network_tyk2.copy_with_replacements(name=network_tyk2.name + "_no_policy"),
+        scope_test,
+    )
+
+    transformation_1_scoped_key, transformation_2_scoped_key = map(
+        lambda transformation: n4js.get_scoped_key(transformation, scope_test),
+        list(network_tyk2.edges)[:2],
+    )
+
+    # create 4 tasks for each of the 2 selected transformations
+    task_scoped_keys = n4js.create_tasks(
+        [transformation_1_scoped_key] * 4 + [transformation_2_scoped_key] * 4
+    )
+
+    # action the tasks for transformation 1 on the taskhub with no policy
+    # action the tasks for both transformations on the taskhub with a policy
+    assert all(n4js.action_tasks(task_scoped_keys[:4], taskhub_scoped_key_no_policy))
+    assert all(n4js.action_tasks(task_scoped_keys, taskhub_scoped_key_with_policy))
+
+    patterns = [
+        r"Error message \d, round \d",
+        "This is an example pattern that will be used as a restart string.",
+    ]
+
+    n4js.add_task_restart_patterns(
+        taskhub_scoped_key_with_policy, patterns=patterns, number_of_retries=2
+    )
+
+    return n4js
+
+
 @fixture(scope="module")
 def s3objectstore_settings():
     os.environ["AWS_ACCESS_KEY_ID"] = "test-key-id"
diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py
index c7901840..9a5e71ef 100644
--- a/alchemiscale/tests/integration/storage/test_statestore.py
+++ b/alchemiscale/tests/integration/storage/test_statestore.py
@@ -2,12 +2,15 @@
 import random
 from typing import List, Dict
 from pathlib import Path
+from functools import reduce
 from itertools import chain
+import operator
 from collections import defaultdict
 
 import pytest
 from gufe import AlchemicalNetwork
 from gufe.tokenization import TOKENIZABLE_REGISTRY
+from gufe.protocols import ProtocolUnitFailure
 from gufe.protocols.protocoldag import execute_DAG
 
 from alchemiscale.storage.statestore import Neo4jStore
@@ -28,6 +31,14 @@
 )
 from alchemiscale.security.auth import hash_key
 
+from alchemiscale.tests.integration.storage.utils import (
+    complete_tasks,
+    fail_task,
+    tasks_are_errored,
+    tasks_are_not_actioned_on_taskhub,
+    tasks_are_waiting,
+)
+
 
 class TestStateStore: ...
 
@@ -1944,6 +1955,65 @@ def test_get_task_failures(
         assert pdr_ref_sk in failure_pdr_ref_sks
         assert pdr_ref2_sk in failure_pdr_ref_sks
 
+    @pytest.mark.parametrize("failure_count", (1, 2, 3, 4))
+    def test_add_protocol_dag_result_ref_traceback(
+        self,
+        network_tyk2_failure,
+        n4js,
+        scope_test,
+        transformation_failure,
+        protocoldagresults_failure,
+        failure_count: int,
+    ):
+
+        an = network_tyk2_failure.copy_with_replacements(
+            name=network_tyk2_failure.name
+            + "_test_add_protocol_dag_result_ref_traceback"
+        )
+        n4js.assemble_network(an, scope_test)
+        transformation_scoped_key = n4js.get_scoped_key(
+            transformation_failure, scope_test
+        )
+
+        # create a task; pretend we computed it, submit reference for pre-baked
+        # result
+        task_scoped_key = n4js.create_task(transformation_scoped_key)
+
+        protocol_unit_failure = protocoldagresults_failure[0].protocol_unit_failures[0]
+
+        pdrr = ProtocolDAGResultRef(
+            scope=task_scoped_key.scope,
+            obj_key=protocoldagresults_failure[0].key,
+            ok=protocoldagresults_failure[0].ok(),
+        )
+
+        # push the result
+        pdrr_scoped_key = n4js.set_task_result(task_scoped_key, pdrr)
+
+        # simulating many failures
+        protocol_unit_failures = []
+        for failure_index in range(failure_count):
+            protocol_unit_failures.append(
+                protocol_unit_failure.copy_with_replacements(
+                    traceback=protocol_unit_failure.traceback + "_" + str(failure_index)
+                )
+            )
+
+        n4js.add_protocol_dag_result_ref_tracebacks(
+            protocol_unit_failures, pdrr_scoped_key
+        )
+
+        query = """
+        MATCH (traceback:Tracebacks)-[:DETAILS]->(:ProtocolDAGResultRef {`_scoped_key`: $pdrr_scoped_key})
+        RETURN traceback
+        """
+
+        results = n4js.execute_query(query, pdrr_scoped_key=str(pdrr_scoped_key))
+
+        returned_tracebacks = results.records[0]["traceback"]["tracebacks"]
+
+        assert returned_tracebacks == [puf.traceback for puf in protocol_unit_failures]
+
 
 ### task restart policies
 
 
 class TestTaskRestartPolicy:
 
@@ -1951,7 +2021,7 @@ class TestTaskRestartPolicy:
     @pytest.mark.parametrize("status", ("complete", "invalid", "deleted"))
     def test_task_status_change(self, n4js, network_tyk2, scope_test, status):
         an = network_tyk2.copy_with_replacements(
-            name=network_tyk2.name + f"_test_task_status_change"
+            name=network_tyk2.name + "_test_task_status_change"
         )
         _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
         transformation = list(an.edges)[0]
@@ -2225,6 +2295,192 @@ def test_get_task_restart_patterns(self, n4js, network_tyk2, scope_test):
 
         assert taskhub_grouped_patterns == expected_results
 
+    def test_resolve_task_restarts(
+        self,
+        scope_test: Scope,
+        n4js_task_restart_policy: Neo4jStore,
+    ):
+        n4js = n4js_task_restart_policy
+
+        # get the actioned tasks for each taskhub
+        taskhub_actioned_tasks = {}
+        for taskhub_scoped_key in n4js.query_taskhubs():
+            taskhub_actioned_tasks[taskhub_scoped_key] = set(
+                n4js.get_taskhub_actioned_tasks([taskhub_scoped_key])[0]
+            )
+
+        restart_patterns = n4js.get_task_restart_patterns(
+            list(taskhub_actioned_tasks.keys())
+        )
+
+        # create a map of the transformations and all of the tasks that perform them
+        transformation_tasks: dict[ScopedKey, list[ScopedKey]] = defaultdict(list)
+        for task in n4js.query_tasks(status=TaskStatusEnum.waiting.value):
+            transformation_scoped_key, _ = n4js.get_task_transformation(
+                task, return_gufe=False
+            )
+            transformation_tasks[transformation_scoped_key].append(task)
+
+        # get a list of all tasks for more convenient calls of the resolve method
+        all_tasks = []
+        for task_group in transformation_tasks.values():
+            all_tasks.extend(task_group)
+
+        taskhub_scoped_key_no_policy = None
+        taskhub_scoped_key_with_policy = None
+
+        # bind taskhub scoped keys to variables for convenience later
+        for taskhub_scoped_key, patterns in restart_patterns.items():
+            if not patterns:
+                taskhub_scoped_key_no_policy = taskhub_scoped_key
+                continue
+            else:
+                taskhub_scoped_key_with_policy = taskhub_scoped_key
+                continue
+
+            if patterns and taskhub_scoped_key_with_policy:
+                raise AssertionError("More than one TaskHub has restart patterns")
+
+        assert (
+            taskhub_scoped_key_no_policy
+            and taskhub_scoped_key_with_policy
+            and (taskhub_scoped_key_no_policy != taskhub_scoped_key_with_policy)
+        )
+
+        # we first check the behavior involving tasks that are actioned by both taskhubs
+        # this involves confirming:
+        #
+        # 1. Completed Tasks do not have an actions relationship with either TaskHub
+        # 2. A Task entering the error state is switched back to waiting if any restart patterns apply
+        # 3. A Task entering the error state is left in the error state if no patterns apply and only the TaskHub without
+        #    an enforcing task restart policy actions the Task
+        #
+        # Tasks will be set to the error state with a spoofing method, which will create a fake ProtocolDAGResultRef
+        # and Tracebacks. This is done since making a protocol fail systematically in the testing environment is not
+        # obvious at this time.
+
+        # reduce down all tasks until only the common elements between taskhubs exist
+        tasks_actioned_by_all_taskhubs: List[ScopedKey] = list(
+            reduce(operator.and_, taskhub_actioned_tasks.values())
+        )
+
+        assert len(tasks_actioned_by_all_taskhubs) == 4
+
+        # we're going to just pass the first 2 and fail the second 2
+        tasks_to_complete = tasks_actioned_by_all_taskhubs[:2]
+        tasks_to_fail = tasks_actioned_by_all_taskhubs[2:]
+
+        complete_tasks(n4js, tasks_to_complete)
+
+        records = n4js.execute_query(
+            """
+            UNWIND $task_scoped_keys as task_scoped_key
+            MATCH (task:Task {_scoped_key: task_scoped_key})-[:RESULTS_IN]->(:ProtocolDAGResultRef)
+            RETURN count(task) as task_count
+            """,
+            task_scoped_keys=list(map(str, tasks_to_complete)),
+        ).records
+
+        assert records[0]["task_count"] == 2
+
+        # test the behavior of the compute API
+        for i, task in enumerate(tasks_to_fail):
+            error_messages = [
+                f"Error message {repeat}, round {i}" for repeat in range(3)
+            ]
+
+            fail_task(
+                n4js,
+                task,
+                resolve=False,
+                error_messages=error_messages,
+            )
+
+        n4js.resolve_task_restarts(all_tasks)
+
+        # both tasks should have the waiting status and the APPLIES
+        # relationship num_retries should have incremented by 1
+        query = """
+        UNWIND $task_scoped_keys as task_scoped_key
+        MATCH (task:Task {`_scoped_key`: task_scoped_key, status: $waiting})<-[:APPLIES {num_retries: 1}]-(:TaskRestartPattern {max_retries: 2})
+        RETURN count(DISTINCT task) as renewed_waiting_tasks
+        """
+
+        renewed_waiting = n4js.execute_query(
+            query,
+            task_scoped_keys=list(map(str, tasks_to_fail)),
+            waiting=TaskStatusEnum.waiting.value,
+        ).records[0]["renewed_waiting_tasks"]
+
+        assert renewed_waiting == 2
+
+        # we want the resolve restarts to cancel a task.
+        # deconstruct the tasks to fail, where the first
+        # one will be cancelled and the second will continue to wait
+        task_to_cancel, task_to_wait = tasks_to_fail
+
+        # error out the first task
+        for _ in range(2):
+            error_messages = [
+                f"Error message {repeat}, round {i}" for repeat in range(3)
+            ]
+
+            fail_task(
+                n4js,
+                task_to_cancel,
+                resolve=False,
+                error_messages=error_messages,
+            )
+
+            n4js.resolve_task_restarts(tasks_to_fail)
+
+        # check that it is no longer actioned on the enforced taskhub
+        assert tasks_are_not_actioned_on_taskhub(
+            n4js,
+            [task_to_cancel],
+            taskhub_scoped_key_with_policy,
+        )
+
+        # check that it is still actioned on the unenforced taskhub
+        assert not tasks_are_not_actioned_on_taskhub(
+            n4js,
+            [task_to_cancel],
+            taskhub_scoped_key_no_policy,
+        )
+
+        # it should still be errored though!
+        assert tasks_are_errored(n4js, [task_to_cancel])
+
+        # fail the second task one time
+        error_messages = [
+            f"Error message {repeat}, round {i}" for repeat in range(3)
+        ]
+
+        fail_task(
+            n4js,
+            task_to_wait,
+            resolve=False,
+            error_messages=error_messages,
+        )
+
+        n4js.resolve_task_restarts(tasks_to_fail)
+
+        # check that the waiting task is actioned on both taskhubs
+        assert not tasks_are_not_actioned_on_taskhub(
+            n4js,
+            [task_to_wait],
+            taskhub_scoped_key_with_policy,
+        )
+
+        assert not tasks_are_not_actioned_on_taskhub(
+            n4js,
+            [task_to_wait],
+            taskhub_scoped_key_no_policy,
+        )
+
+        # it should be waiting
+        assert tasks_are_waiting(n4js, [task_to_wait])
+
     @pytest.mark.xfail(raises=NotImplementedError)
     def test_task_actioning_applies_relationship(self):
         raise NotImplementedError
diff --git a/alchemiscale/tests/integration/storage/utils.py b/alchemiscale/tests/integration/storage/utils.py
new file mode 100644
index 00000000..40514a53
--- /dev/null
+++ b/alchemiscale/tests/integration/storage/utils.py
@@ -0,0 +1,106 @@
+from datetime import datetime
+
+from gufe.protocols import ProtocolUnitFailure
+
+from alchemiscale.storage.statestore import Neo4jStore
+from alchemiscale import ScopedKey
+from alchemiscale.storage.models import TaskStatusEnum, ProtocolDAGResultRef
+
+
+def tasks_are_not_actioned_on_taskhub(
+    n4js: Neo4jStore,
+    task_scoped_keys: list[ScopedKey],
+    taskhub_scoped_key: ScopedKey,
+) -> bool:
+
+    actioned_tasks = n4js.get_taskhub_actioned_tasks([taskhub_scoped_key])
+
+    for task in task_scoped_keys:
+        if task in actioned_tasks[0].keys():
+            return False
+    return True
+
+
+def tasks_are_errored(n4js: Neo4jStore, task_scoped_keys: list[ScopedKey]) -> bool:
+    query = """
+    UNWIND $task_scoped_keys as task_scoped_key
+    MATCH (task:Task {_scoped_key: task_scoped_key, status: $error})
+    RETURN task
+    """
+
+    results = n4js.execute_query(
+        query,
+        task_scoped_keys=list(map(str, task_scoped_keys)),
+        error=TaskStatusEnum.error.value,
+    )
+
+    return len(results.records) == len(task_scoped_keys)
+
+
+def tasks_are_waiting(n4js: Neo4jStore, task_scoped_keys: list[ScopedKey]) -> bool:
+    query = """
+    UNWIND $task_scoped_keys as task_scoped_key
+    MATCH (task:Task {_scoped_key: task_scoped_key, status: $waiting})
+    RETURN task
+    """
+
+    results = n4js.execute_query(
+        query,
+        task_scoped_keys=list(map(str, task_scoped_keys)),
+        waiting=TaskStatusEnum.waiting.value,
+    )
+
+    return len(results.records) == len(task_scoped_keys)
+
+
+def complete_tasks(
+    n4js: Neo4jStore,
+    tasks: list[ScopedKey],
+):
+    n4js.set_task_running(tasks)
+    for task in tasks:
+        ok_pdrr = ProtocolDAGResultRef(
+            ok=True,
+            datetime_created=datetime.utcnow(),
+            obj_key=task.gufe_key,
+            scope=task.scope,
+        )
+
+        _ = n4js.set_task_result(task, ok_pdrr)
+
+    n4js.set_task_complete(tasks)
+
+
+def fail_task(
+    n4js: Neo4jStore,
+    task: ScopedKey,
+    resolve: bool = False,
+    error_messages: list[str] = [],
+) -> None:
+    n4js.set_task_running([task])
+
+    not_ok_pdrr = ProtocolDAGResultRef(
+        ok=False,
+        datetime_created=datetime.utcnow(),
+        obj_key=task.gufe_key,
+        scope=task.scope,
+    )
+
+    protocol_unit_failures = []
+    for j, message in enumerate(error_messages):
+        puf = ProtocolUnitFailure(
+            source_key=f"FakeProtocolUnitKey-123{j}",
+            inputs={},
+            outputs={},
+            exception=RuntimeError,
+            traceback=message,
+        )
+        protocol_unit_failures.append(puf)
+
+    pdrr_scoped_key = n4js.set_task_result(task, not_ok_pdrr)
+
+    n4js.add_protocol_dag_result_ref_tracebacks(protocol_unit_failures, pdrr_scoped_key)
+    n4js.set_task_error([task])
+
+    if resolve:
+        n4js.resolve_task_restarts([task])
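As a condensed illustration of how these helpers combine with the `n4js_task_restart_policy` fixture, here is a hypothetical test in the same style as `test_resolve_task_restarts`; it is a sketch only, not an additional test shipped in this changeset.

from alchemiscale.storage.models import TaskStatusEnum
from alchemiscale.tests.integration.storage.utils import fail_task, tasks_are_waiting


# hypothetical test, for illustration only
def test_matching_failure_returns_task_to_waiting(n4js_task_restart_policy):
    n4js = n4js_task_restart_policy

    # every waiting task in the fixture is actioned on the policy-enforced taskhub
    task = n4js.query_tasks(status=TaskStatusEnum.waiting.value)[0]

    # the traceback matches the registered pattern r"Error message \d, round \d",
    # so restart resolution should send the task straight back to waiting
    fail_task(
        n4js,
        task,
        resolve=True,
        error_messages=["Error message 1, round 1"],
    )

    assert tasks_are_waiting(n4js, [task])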
diff --git a/alchemiscale/tests/unit/test_storage_models.py b/alchemiscale/tests/unit/test_storage_models.py
index 55dc872f..391a1063 100644
--- a/alchemiscale/tests/unit/test_storage_models.py
+++ b/alchemiscale/tests/unit/test_storage_models.py
@@ -4,7 +4,7 @@
     NetworkStateEnum,
     NetworkMark,
     TaskRestartPattern,
-    Traceback,
+    Tracebacks,
 )
 
 from alchemiscale import ScopedKey
@@ -137,40 +137,40 @@ def test_from_dict(self):
         assert trp_reconstructed.taskhub_scoped_key == original_taskhub_scoped_key
 
 
-class TestTraceback(object):
+class TestTracebacks(object):
 
     valid_entry = ["traceback1", "traceback2", "traceback3"]
     tracebacks_value_error = "`tracebacks` must be a non-empty list of string values"
 
     def test_empty_string_element(self):
         with pytest.raises(ValueError, match=self.tracebacks_value_error):
-            Traceback(self.valid_entry + [""])
+            Tracebacks(self.valid_entry + [""])
 
     def test_non_list_parameter(self):
         with pytest.raises(ValueError, match=self.tracebacks_value_error):
-            Traceback(None)
+            Tracebacks(None)
 
         with pytest.raises(ValueError, match=self.tracebacks_value_error):
-            Traceback(100)
+            Tracebacks(100)
 
         with pytest.raises(ValueError, match=self.tracebacks_value_error):
-            Traceback("not a list, but still an iterable that yields strings")
+            Tracebacks("not a list, but still an iterable that yields strings")
 
     def test_list_non_string_elements(self):
         with pytest.raises(ValueError, match=self.tracebacks_value_error):
-            Traceback(self.valid_entry + [None])
+            Tracebacks(self.valid_entry + [None])
 
     def test_empty_list(self):
         with pytest.raises(ValueError, match=self.tracebacks_value_error):
-            Traceback([])
+            Tracebacks([])
 
     def test_to_dict(self):
-        tb = Traceback(self.valid_entry)
+        tb = Tracebacks(self.valid_entry)
         tb_dict = tb.to_dict()
 
         assert len(tb_dict) == 4
-        assert tb_dict.pop("__qualname__") == "Traceback"
+        assert tb_dict.pop("__qualname__") == "Tracebacks"
         assert tb_dict.pop("__module__") == "alchemiscale.storage.models"
 
         # light test of the version key
@@ -184,7 +184,7 @@ def test_to_dict(self):
         assert expected == tb_dict
 
     def test_from_dict(self):
-        tb_orig = Traceback(self.valid_entry)
+        tb_orig = Tracebacks(self.valid_entry)
         tb_dict = tb_orig.to_dict()
 
         tb_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(tb_dict)
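Finally, a quick round-trip sketch for the renamed `Tracebacks` model, mirroring the unit tests above; the key and traceback strings are placeholders.

from alchemiscale.storage.models import Tracebacks

# placeholder values, in the same shape the statestore builds them
tracebacks = ["traceback1", "traceback2"]
source_keys = ["source_key1", "source_key2"]
failure_keys = ["failure_key1", "failure_key2"]

tb = Tracebacks(tracebacks, source_keys, failure_keys)
tb_dict = tb.to_dict()

assert tb_dict["tracebacks"] == tracebacks
assert tb_dict["source_keys"] == source_keys
assert tb_dict["failure_keys"] == failure_keys

# the GufeTokenizable machinery reconstructs the object from its dict form
tb_again = Tracebacks.from_dict(tb_dict)
assert tb_again.source_keys == source_keys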