diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 4ed931276..7b9c78548 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -321,14 +321,24 @@ def spin_until_future_complete( future.add_done_callback(lambda x: self.wake()) if timeout_sec is None or timeout_sec < 0: - while self._context.ok() and not future.done() and not self._is_shutdown: + while ( + self._context.ok() and + not future.done() and + not future.cancelled() + and not self._is_shutdown + ): self._spin_once_until_future_complete(future, timeout_sec) else: start = time.monotonic() end = start + timeout_sec timeout_left = TimeoutObject(timeout_sec) - while self._context.ok() and not future.done() and not self._is_shutdown: + while ( + self._context.ok() and + not future.done() and + not future.cancelled() + and not self._is_shutdown + ): self._spin_once_until_future_complete(future, timeout_left) now = time.monotonic() @@ -610,6 +620,8 @@ def _wait_for_ready_callbacks( with self._tasks_lock: # Get rid of any tasks that are done self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks)) + # Get rid of any tasks that are cancelled + self._tasks = list(filter(lambda t_e_n: not t_e_n[0].cancelled(), self._tasks)) # Gather entities that can be waited on subscriptions: List[Subscription] = [] diff --git a/rclpy/rclpy/task.py b/rclpy/rclpy/task.py index 81a56ab5b..9ec996443 100644 --- a/rclpy/rclpy/task.py +++ b/rclpy/rclpy/task.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import auto, StrEnum import inspect import sys import threading @@ -31,14 +32,19 @@ def _fake_weakref() -> None: return None +class FutureState(StrEnum): + """States defining the lifecycle of a future.""" + + PENDING = auto() + CANCELLED = auto() + FINISHED = auto() + + class Future(Generic[T]): """Represent the outcome of a task in the future.""" def __init__(self, *, executor: Optional['Executor'] = None) -> None: - # true if the task is done or cancelled - self._done = False - # true if the task is cancelled - self._cancelled = False + self._state = FutureState.PENDING # the final return value of the handler self._result: Optional[T] = None # An exception raised by the handler when called @@ -61,15 +67,16 @@ def __del__(self) -> None: def __await__(self) -> Generator[None, None, Optional[T]]: # Yield if the task is not finished - while not self._done: + while not self.done(): yield return self.result() def cancel(self) -> None: """Request cancellation of the running task if it is not done already.""" with self._lock: - if not self._done: - self._cancelled = True + if not self.done(): + self._state = FutureState.CANCELLED + self._schedule_or_invoke_done_callbacks() def cancelled(self) -> bool: @@ -78,15 +85,15 @@ def cancelled(self) -> bool: :return: True if the task was cancelled """ - return self._cancelled + return self._state == FutureState.CANCELLED def done(self) -> bool: """ - Indicate if the task has finished executing. + Indicate if the task has finished or cancelled executing. - :return: True if the task is finished or raised while it was executing + :return: True if the task is finished, cancelled or raised while it was executing """ - return self._done + return self._state != FutureState.PENDING def result(self) -> Optional[T]: """ @@ -118,8 +125,7 @@ def set_result(self, result: T) -> None: """ with self._lock: self._result = result - self._done = True - self._cancelled = False + self._state = FutureState.FINISHED self._schedule_or_invoke_done_callbacks() def set_exception(self, exception: Exception) -> None: @@ -131,8 +137,7 @@ def set_exception(self, exception: Exception) -> None: with self._lock: self._exception = exception self._exception_fetched = False - self._done = True - self._cancelled = False + self._state = FutureState.FINISHED self._schedule_or_invoke_done_callbacks() def _schedule_or_invoke_done_callbacks(self) -> None: @@ -181,7 +186,7 @@ def add_done_callback(self, callback: Callable[['Future[T]'], None]) -> None: """ invoke = False with self._lock: - if self._done: + if self.done(): assert self._executor is not None executor = self._executor() if executor is not None: @@ -239,10 +244,14 @@ def __call__(self) -> None: The return value of the handler is stored as the task result. """ - if self._done or self._executing or not self._task_lock.acquire(blocking=False): + if ( + self.done() or + self._executing or + not self._task_lock.acquire(blocking=False) + ): return try: - if self._done: + if self.done(): return self._executing = True @@ -285,3 +294,9 @@ def executing(self) -> bool: :return: True if the task is currently executing. """ return self._executing + + def cancel(self) -> None: + if not self.done() and inspect.iscoroutine(self._handler): + self._handler.close() + + super().cancel() diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index 7b7027486..c27c6661d 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -273,6 +273,26 @@ async def coroutine(): self.assertTrue(future.done()) self.assertEqual('Sentinel Result', future.result()) + def test_create_task_coroutine_cancel(self) -> None: + self.assertIsNotNone(self.node.handle) + executor = SingleThreadedExecutor(context=self.context) + executor.add_node(self.node) + + async def coroutine(): + return 'Sentinel Result' + + future = executor.create_task(coroutine) + self.assertFalse(future.done()) + self.assertFalse(future.cancelled()) + + future.cancel() + self.assertTrue(future.cancelled()) + + executor.spin_until_future_complete(future) + self.assertTrue(future.done()) + self.assertTrue(future.cancelled()) + self.assertEqual(None, future.result()) + def test_create_task_normal_function(self) -> None: self.assertIsNotNone(self.node.handle) executor = SingleThreadedExecutor(context=self.context)