diff --git a/oarepo_workflows/ext.py b/oarepo_workflows/ext.py index 10d4699..111c646 100644 --- a/oarepo_workflows/ext.py +++ b/oarepo_workflows/ext.py @@ -1,6 +1,7 @@ from functools import cached_property import importlib_metadata +from invenio_drafts_resources.services.records.uow import ParentRecordCommitOp from oarepo_workflows.errors import InvalidWorkflowError from oarepo_workflows.proxies import current_oarepo_workflows @@ -49,17 +50,27 @@ def set_state(self, identity, record, value, *args, uow=None, **kwargs): identity, record, previous_value, value, *args, uow=uow, **kwargs ) - def set_workflow(self, identity, record, value, *args, uow=None, **kwargs): - if value not in current_oarepo_workflows.record_workflows: + def set_workflow( + self, identity, record, new_workflow_id, *args, uow=None, commit=True, **kwargs + ): + if new_workflow_id not in current_oarepo_workflows.record_workflows: raise InvalidWorkflowError( - f"Workflow {value} does not exist in the configuration." + f"Workflow {new_workflow_id} does not exist in the configuration." ) previous_value = record.parent.workflow - record.parent.workflow = value + record.parent.workflow = new_workflow_id for workflow_changed_notifier in self.workflow_changed_notifiers: workflow_changed_notifier( - identity, record, previous_value, value, *args, uow=uow, **kwargs + identity, + record, + previous_value, + new_workflow_id, + *args, + uow=uow, + **kwargs, ) + if commit: + uow.register(ParentRecordCommitOp(record.parent)) def get_workflow_from_record(self, record, **kwargs): if hasattr(record, "parent"): diff --git a/oarepo_workflows/services/components/workflow.py b/oarepo_workflows/services/components/workflow.py index 4992c3e..13b1838 100644 --- a/oarepo_workflows/services/components/workflow.py +++ b/oarepo_workflows/services/components/workflow.py @@ -11,4 +11,6 @@ def create(self, identity, data=None, record=None, **kwargs): workflow_id = data["parent"]["workflow_id"] except KeyError: raise MissingWorkflowError("Workflow not defined in input.") - current_oarepo_workflows.set_workflow(identity, record, workflow_id) + current_oarepo_workflows.set_workflow( + identity, record, workflow_id, uow=self.uow, **kwargs + ) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 022bc0d..34f574c 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -130,8 +130,12 @@ def test_set_workflow( ): record = record_service.create(users[0].identity, default_workflow_json)._record with pytest.raises(InvalidWorkflowError): - workflow_change_function(users[0].identity, record, "egregore") - workflow_change_function(users[0].identity, record, "record_owners_can_read") + workflow_change_function( + users[0].identity, record, "invalid_workflow", commit=False + ) + workflow_change_function( + users[0].identity, record, "record_owners_can_read", commit=False + ) assert record.parent.workflow == "record_owners_can_read" @@ -153,6 +157,10 @@ def test_set_workflow_entrypoint_hookup( ): record = record_service.create(users[0].identity, default_workflow_json)._record with pytest.raises(InvalidWorkflowError): - workflow_change_function(users[0].identity, record, "egregore") - workflow_change_function(users[0].identity, record, "record_owners_can_read") + workflow_change_function( + users[0].identity, record, "invalid_workflow", commit=False + ) + workflow_change_function( + users[0].identity, record, "record_owners_can_read", commit=False + ) assert record.parent["workflow-change-notifier-called"]