From fbd90ac9f67e021d8b6e87d29d10e21c9d70d01e Mon Sep 17 00:00:00 2001
From: superstar54 <xingwang1991@gmail.com>
Date: Fri, 13 Dec 2024 06:39:06 +0100
Subject: [PATCH] To provide a better user experience, we raise an exception
 explicitly when the timeout is exceeded in the wait method.

---
 src/aiida_workgraph/workgraph.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py
index 1181ed17..dba249e0 100644
--- a/src/aiida_workgraph/workgraph.py
+++ b/src/aiida_workgraph/workgraph.py
@@ -110,14 +110,14 @@ def submit(
         self,
         inputs: Optional[Dict[str, Any]] = None,
         wait: bool = False,
-        timeout: int = 60,
+        timeout: int = 600,
         interval: int = 5,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> aiida.orm.ProcessNode:
         """Submit the AiiDA workgraph process and optionally wait for it to finish.
         Args:
             wait (bool): Wait for the process to finish.
-            timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 60.
+            timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 600.
             restart (bool): Restart the process, and reset the modified tasks, then only re-run the modified tasks.
             new (bool): Submit a new process.
         """
@@ -228,11 +228,17 @@ def get_error_handlers(self) -> Dict[str, Any]:
                 task["exit_codes"] = exit_codes
         return error_handlers
 
-    def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None:
+    def wait(self, timeout: int = 600, tasks: dict = None, interval: int = 5) -> None:
         """
         Periodically checks and waits for the AiiDA workgraph process to finish until a given timeout.
+
         Args:
-            timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 50.
+            timeout (int): The maximum time in seconds to wait for the process to finish. Defaults to 600.
+            tasks (dict): Optional; specifies task states to wait for in the format {task_name: [acceptable_states]}.
+            interval (int): The time interval in seconds between checks. Defaults to 5.
+
+        Raises:
+            TimeoutError: If the process does not finish within the given timeout.
         """
         terminating_states = (
             "KILLED",
@@ -245,8 +251,10 @@ def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None
         start = time.time()
         self.update()
         finished = False
+
         while not finished:
             self.update()
+
             if tasks is not None:
                 states = []
                 for name, value in tasks.items():
@@ -255,9 +263,17 @@ def wait(self, timeout: int = 50, tasks: dict = None, interval: int = 5) -> None
                 finished = all(states)
             else:
                 finished = self.state in terminating_states
+
+            if finished:
+                print(f"Process {self.process.pk} finished with state: {self.state}")
+                return
+
             time.sleep(interval)
+
             if time.time() - start > timeout:
-                break
+                raise TimeoutError(
+                    f"Timeout reached after {timeout} seconds while waiting for the WorkGraph: {self.process.pk}. "
+                )
 
     def update(self) -> None:
         """