diff --git a/rclpy/rclpy/__init__.py b/rclpy/rclpy/__init__.py index 6e94dce51..d53d66797 100644 --- a/rclpy/rclpy/__init__.py +++ b/rclpy/rclpy/__init__.py @@ -40,8 +40,10 @@ This will invalidate all entities derived from the context. """ +from types import TracebackType from typing import List from typing import Optional +from typing import Type from typing import TYPE_CHECKING from rclpy.context import Context @@ -62,13 +64,52 @@ from rclpy.node import Node # noqa: F401 +class InitContextManager: + """ + A proxy object for initialization. + + One of these is returned when calling `rclpy.init`, and can be used with context managers to + properly cleanup after initialization. + """ + + def __init__(self, + args: Optional[List[str]], + context: Optional[Context], + domain_id: Optional[int], + signal_handler_options: Optional[SignalHandlerOptions]) -> None: + self.context = get_default_context() if context is None else context + if signal_handler_options is None: + if context is None or context is get_default_context(): + signal_handler_options = SignalHandlerOptions.ALL + else: + signal_handler_options = SignalHandlerOptions.NO + + if signal_handler_options == SignalHandlerOptions.NO: + self.installed_signal_handlers = False + else: + self.installed_signal_handlers = True + install_signal_handlers(signal_handler_options) + self.context.init(args, domain_id=domain_id) + + def __enter__(self) -> 'InitContextManager': + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + shutdown(context=self.context, uninstall_handlers=self.installed_signal_handlers) + + def init( *, args: Optional[List[str]] = None, context: Optional[Context] = None, domain_id: Optional[int] = None, signal_handler_options: Optional[SignalHandlerOptions] = None, -) -> None: +) -> InitContextManager: """ Initialize ROS communications for a given context. @@ -78,15 +119,9 @@ def init( :param domain_id: ROS domain id. :param signal_handler_options: Indicate which signal handlers to install. If `None`, SIGINT and SIGTERM will be installed when initializing the default context. + :return: an InitContextManager that can be used with Python context managers to cleanup. """ - context = get_default_context() if context is None else context - if signal_handler_options is None: - if context is None or context is get_default_context(): - signal_handler_options = SignalHandlerOptions.ALL - else: - signal_handler_options = SignalHandlerOptions.NO - install_signal_handlers(signal_handler_options) - return context.init(args, domain_id=domain_id) + return InitContextManager(args, context, domain_id, signal_handler_options) # The global spin functions need an executor to do the work @@ -125,7 +160,7 @@ def shutdown( :param uninstall_handlers: If `None`, signal handlers will be uninstalled when shutting down the default context. If `True`, signal handlers will be uninstalled. - If not, signal handlers won't be uninstalled. + If `False`, signal handlers won't be uninstalled. """ _shutdown(context=context) if ( diff --git a/rclpy/rclpy/context.py b/rclpy/rclpy/context.py index c3a27953e..5563b09b1 100644 --- a/rclpy/rclpy/context.py +++ b/rclpy/rclpy/context.py @@ -22,11 +22,16 @@ from typing import Optional from typing import Protocol from typing import Type +from typing import TYPE_CHECKING from typing import Union -from weakref import WeakMethod +import warnings +import weakref from rclpy.destroyable import DestroyableType +if TYPE_CHECKING: + from rclpy.node import Node + class ContextHandle(DestroyableType, Protocol): @@ -60,10 +65,11 @@ class Context(ContextManager['Context']): """ def __init__(self) -> None: - self._lock = threading.Lock() - self._callbacks: List[Union['WeakMethod[MethodType]', Callable[[], None]]] = [] + self._lock = threading.RLock() + self._callbacks: List[Union['weakref.WeakMethod[MethodType]', Callable[[], None]]] = [] self._logging_initialized = False self.__context: Optional[ContextHandle] = None + self.__node_weak_ref_list: List[weakref.ReferenceType['Node']] = [] @property def handle(self) -> Optional[ContextHandle]: @@ -82,6 +88,10 @@ def init(self, Initialize ROS communications for a given context. :param args: List of command line arguments. + :param initialize_logging: Whether to initialize logging for the whole process. + The default is to initialize logging. + :param domain_id: Which domain ID to use for this context. + If None (the default), domain ID 0 is used. """ # imported locally to avoid loading extensions on module import from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy @@ -106,6 +116,36 @@ def init(self, _rclpy.rclpy_logging_configure(self.__context) self._logging_initialized = True + def track_node(self, node: 'Node') -> None: + """ + Track a Node associated with this Context. + + When the Context is destroyed, it will destroy every Node it tracks. + + :param node: The node to take a weak reference to. + """ + with self._lock: + self.__node_weak_ref_list.append(weakref.ref(node)) + + def untrack_node(self, node: 'Node') -> None: + """ + Stop tracking a Node associated with this Context. + + If a Node is destroyed before the context, we no longer need to track it for destruction of + the Context, so remove it here. + """ + with self._lock: + for index, weak_node in enumerate(self.__node_weak_ref_list): + node_in_list = weak_node() + if node_in_list is node: + found_index = index + break + else: + # Odd that we didn't find the node in the list, but just get out + return + + del self.__node_weak_ref_list[found_index] + def ok(self) -> bool: """Check if context hasn't been shut down.""" with self._lock: @@ -121,14 +161,22 @@ def _call_on_shutdown_callbacks(self) -> None: callback() self._callbacks = [] + def _cleanup(self) -> None: + for weak_node in self.__node_weak_ref_list: + node = weak_node() + if node is not None: + node.destroy_node() + + self.__context.shutdown() + self._call_on_shutdown_callbacks() + self._logging_fini() + def shutdown(self) -> None: """Shutdown this context.""" if self.__context is None: raise RuntimeError('Context must be initialized before it can be shutdown') with self.__context, self._lock: - self.__context.shutdown() - self._call_on_shutdown_callbacks() - self._logging_fini() + self._cleanup() def try_shutdown(self) -> None: """Shutdown this context, if not already shutdown.""" @@ -136,11 +184,9 @@ def try_shutdown(self) -> None: return with self.__context, self._lock: if self.__context.ok(): - self.__context.shutdown() - self._call_on_shutdown_callbacks() - self._logging_fini() + self._cleanup() - def _remove_callback(self, weak_method: 'WeakMethod[MethodType]') -> None: + def _remove_callback(self, weak_method: 'weakref.WeakMethod[MethodType]') -> None: self._callbacks.remove(weak_method) def on_shutdown(self, callback: Callable[[], None]) -> None: @@ -151,7 +197,7 @@ def on_shutdown(self, callback: Callable[[], None]) -> None: if self.__context is None: with self._lock: if ismethod(callback): - self._callbacks.append(WeakMethod(callback, self._remove_callback)) + self._callbacks.append(weakref.WeakMethod(callback, self._remove_callback)) else: self._callbacks.append(callback) return @@ -161,7 +207,7 @@ def on_shutdown(self, callback: Callable[[], None]) -> None: callback() else: if ismethod(callback): - self._callbacks.append(WeakMethod(callback, self._remove_callback)) + self._callbacks.append(weakref.WeakMethod(callback, self._remove_callback)) else: self._callbacks.append(callback) @@ -187,9 +233,12 @@ def get_domain_id(self) -> int: return self.__context.get_domain_id() def __enter__(self) -> 'Context': - # We do not accept parameters here. If one wants to customize the init() call, - # they would have to call it manually and not use the ContextManager convenience - self.init() + if self.__context is None: + # init() hasn't been called yet; for backwards compatibility, initialize and warn + warnings.warn('init() must be called on a Context before using it in a Python context ' + 'manager. Calling init() with no arguments, this usage is deprecated') + self.init() + return self def __exit__( diff --git a/rclpy/rclpy/node.py b/rclpy/rclpy/node.py index 6a3b4d843..3b7fd37b9 100644 --- a/rclpy/rclpy/node.py +++ b/rclpy/rclpy/node.py @@ -247,6 +247,8 @@ def __init__( self._type_description_service = TypeDescriptionService(self) + self._context.track_node(self) + @property def publishers(self) -> Iterator[Publisher]: """Get publishers that have been created on this node.""" @@ -1944,6 +1946,8 @@ def destroy_node(self): * :func:`create_guard_condition` """ + self._context.untrack_node(self) + # Drop extra reference to parameter event publisher. # It will be destroyed with other publishers below. self._parameter_event_publisher = None diff --git a/rclpy/test/test_context.py b/rclpy/test/test_context.py index f1335ef9f..7831597fd 100644 --- a/rclpy/test/test_context.py +++ b/rclpy/test/test_context.py @@ -61,6 +61,8 @@ def test_context_manager(): assert not context.ok(), 'the context should not be ok() before init() is called' + context.init() + with context as the_context: # Make sure the correct instance is returned assert the_context is context diff --git a/rclpy/test/test_init_shutdown.py b/rclpy/test/test_init_shutdown.py index 633142dec..43a514448 100644 --- a/rclpy/test/test_init_shutdown.py +++ b/rclpy/test/test_init_shutdown.py @@ -109,3 +109,9 @@ def test_signal_handlers(): def test_init_with_invalid_domain_id(): with pytest.raises(RuntimeError): rclpy.init(domain_id=-1) + + +def test_managed_init(): + with rclpy.init(domain_id=123) as init: + assert init.context.get_domain_id() == 123 + assert init.context.ok()