From 8a047786def03babb0d3c0dbcd824d2373c12661 Mon Sep 17 00:00:00 2001
From: Ronald Krist <ronald.krist@cesnet.cz>
Date: Mon, 29 Jul 2024 15:13:38 +0200
Subject: [PATCH] fixed current_workflows.set_workflow to commit the changed
 parent record

---
 oarepo_workflows/ext.py                       | 21 ++++++++++++++-----
 .../services/components/workflow.py           |  4 +++-
 tests/test_workflow.py                        | 16 ++++++++++----
 3 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/oarepo_workflows/ext.py b/oarepo_workflows/ext.py
index 10d4699..111c646 100644
--- a/oarepo_workflows/ext.py
+++ b/oarepo_workflows/ext.py
@@ -1,6 +1,7 @@
 from functools import cached_property
 
 import importlib_metadata
+from invenio_drafts_resources.services.records.uow import ParentRecordCommitOp
 
 from oarepo_workflows.errors import InvalidWorkflowError
 from oarepo_workflows.proxies import current_oarepo_workflows
@@ -49,17 +50,27 @@ def set_state(self, identity, record, value, *args, uow=None, **kwargs):
                 identity, record, previous_value, value, *args, uow=uow, **kwargs
             )
 
-    def set_workflow(self, identity, record, value, *args, uow=None, **kwargs):
-        if value not in current_oarepo_workflows.record_workflows:
+    def set_workflow(
+        self, identity, record, new_workflow_id, *args, uow=None, commit=True, **kwargs
+    ):
+        if new_workflow_id not in current_oarepo_workflows.record_workflows:
             raise InvalidWorkflowError(
-                f"Workflow {value} does not exist in the configuration."
+                f"Workflow {new_workflow_id} does not exist in the configuration."
             )
         previous_value = record.parent.workflow
-        record.parent.workflow = value
+        record.parent.workflow = new_workflow_id
         for workflow_changed_notifier in self.workflow_changed_notifiers:
             workflow_changed_notifier(
-                identity, record, previous_value, value, *args, uow=uow, **kwargs
+                identity,
+                record,
+                previous_value,
+                new_workflow_id,
+                *args,
+                uow=uow,
+                **kwargs,
             )
+        if commit:
+            uow.register(ParentRecordCommitOp(record.parent))
 
     def get_workflow_from_record(self, record, **kwargs):
         if hasattr(record, "parent"):
diff --git a/oarepo_workflows/services/components/workflow.py b/oarepo_workflows/services/components/workflow.py
index 4992c3e..13b1838 100644
--- a/oarepo_workflows/services/components/workflow.py
+++ b/oarepo_workflows/services/components/workflow.py
@@ -11,4 +11,6 @@ def create(self, identity, data=None, record=None, **kwargs):
             workflow_id = data["parent"]["workflow_id"]
         except KeyError:
             raise MissingWorkflowError("Workflow not defined in input.")
-        current_oarepo_workflows.set_workflow(identity, record, workflow_id)
+        current_oarepo_workflows.set_workflow(
+            identity, record, workflow_id, uow=self.uow, **kwargs
+        )
diff --git a/tests/test_workflow.py b/tests/test_workflow.py
index 022bc0d..34f574c 100644
--- a/tests/test_workflow.py
+++ b/tests/test_workflow.py
@@ -130,8 +130,12 @@ def test_set_workflow(
 ):
     record = record_service.create(users[0].identity, default_workflow_json)._record
     with pytest.raises(InvalidWorkflowError):
-        workflow_change_function(users[0].identity, record, "egregore")
-    workflow_change_function(users[0].identity, record, "record_owners_can_read")
+        workflow_change_function(
+            users[0].identity, record, "invalid_workflow", commit=False
+        )
+    workflow_change_function(
+        users[0].identity, record, "record_owners_can_read", commit=False
+    )
     assert record.parent.workflow == "record_owners_can_read"
 
 
@@ -153,6 +157,10 @@ def test_set_workflow_entrypoint_hookup(
 ):
     record = record_service.create(users[0].identity, default_workflow_json)._record
     with pytest.raises(InvalidWorkflowError):
-        workflow_change_function(users[0].identity, record, "egregore")
-    workflow_change_function(users[0].identity, record, "record_owners_can_read")
+        workflow_change_function(
+            users[0].identity, record, "invalid_workflow", commit=False
+        )
+    workflow_change_function(
+        users[0].identity, record, "record_owners_can_read", commit=False
+    )
     assert record.parent["workflow-change-notifier-called"]