From 2c04d806fc284e1be7404cf3b0ed7356cd0f1ee6 Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sun, 28 Jul 2024 18:14:11 -0400 Subject: [PATCH] ctrl+c improvement --- bbot/core/engine.py | 74 +++++++++++++++++++++++------------------ bbot/scanner/scanner.py | 4 ++- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/bbot/core/engine.py b/bbot/core/engine.py index 20ef59a4a..c3897cbef 100644 --- a/bbot/core/engine.py +++ b/bbot/core/engine.py @@ -16,8 +16,8 @@ from bbot.core import CORE from bbot.errors import BBOTEngineError -from bbot.core.helpers.misc import rand_string from bbot.core.helpers.async_helpers import get_event_loop +from bbot.core.helpers.misc import rand_string, in_exception_chain error_sentinel = object() @@ -41,6 +41,7 @@ class EngineBase: ERROR_CLASS = BBOTEngineError def __init__(self): + self._shutdown_status = False self.log = logging.getLogger(f"bbot.core.{self.__class__.__name__.lower()}") def pickle(self, obj): @@ -62,7 +63,7 @@ def unpickle(self, binary): async def _infinite_retry(self, callback, *args, **kwargs): interval = kwargs.pop("_interval", 10) - while 1: + while not self._shutdown_status: try: return await asyncio.wait_for(callback(*args, **kwargs), timeout=interval) except (TimeoutError, asyncio.TimeoutError): @@ -107,7 +108,6 @@ class EngineClient(EngineBase): SERVER_CLASS = None def __init__(self, **kwargs): - self._shutdown = False super().__init__() self.name = f"EngineClient {self.__class__.__name__}" self.process = None @@ -135,7 +135,7 @@ def check_error(self, message): async def run_and_return(self, command, *args, **kwargs): fn_str = f"{command}({args}, {kwargs})" self.log.debug(f"{self.name}: executing run-and-return {fn_str}") - if self._shutdown and not command == "_shutdown": + if self._shutdown_status and not command == "_shutdown": self.log.verbose(f"{self.name} has been shut down and is not accepting new tasks") return async with self.new_socket() as socket: @@ -163,7 +163,7 @@ async def run_and_return(self, command, *args, **kwargs): async def run_and_yield(self, command, *args, **kwargs): fn_str = f"{command}({args}, {kwargs})" self.log.debug(f"{self.name}: executing run-and-yield {fn_str}") - if self._shutdown: + if self._shutdown_status: self.log.verbose("Engine has been shut down and is not accepting new tasks") return message = self.make_message(command, args=args, kwargs=kwargs) @@ -213,14 +213,16 @@ async def send_shutdown_message(self): async with self.new_socket() as socket: # -99 == special shutdown message message = pickle.dumps({"c": -99}) - await self._infinite_retry(socket.send, message) - while 1: - response = await self._infinite_retry(socket.recv) - response = pickle.loads(response) - if isinstance(response, dict): - response = response.get("m", "") - if response == "SHUTDOWN_OK": - break + with suppress(TimeoutError, asyncio.TimeoutError): + await asyncio.wait_for(socket.send(message), 0.5) + with suppress(TimeoutError, asyncio.TimeoutError): + while 1: + response = await asyncio.wait_for(socket.recv(), 0.5) + response = pickle.loads(response) + if isinstance(response, dict): + response = response.get("m", "") + if response == "SHUTDOWN_OK": + break def check_stop(self, message): if isinstance(message, dict) and len(message) == 1 and "_s" in message: @@ -280,7 +282,7 @@ def server_process(server_class, socket_path, **kwargs): else: asyncio.run(engine_server.worker()) except (asyncio.CancelledError, KeyboardInterrupt, CancelledError): - pass + return except Exception: import traceback @@ -306,9 +308,9 @@ async def new_socket(self): socket.close() async def shutdown(self): - self.log.debug(f"{self.name}: shutting down...") - if not self._shutdown: - self._shutdown = True + if not self._shutdown_status: + self._shutdown_status = True + self.log.hugewarning(f"{self.name}: shutting down...") # send shutdown signal await self.send_shutdown_message() # then terminate context @@ -446,6 +448,7 @@ def check_error(self, message): return True async def worker(self): + self.log.debug(f"{self.name}: starting worker") try: while 1: client_id, binary = await self.socket.recv_multipart() @@ -462,8 +465,8 @@ async def worker(self): # -1 == cancel task if cmd == -1: self.log.debug(f"{self.name} got cancel signal") - await self.cancel_task(client_id) await self.send_socket_multipart(client_id, {"m": "CANCEL_OK"}) + await self.cancel_task(client_id) continue # -99 == shutdown task @@ -500,24 +503,28 @@ async def worker(self): task = asyncio.create_task(coroutine) self.tasks[client_id] = task, command_fn, args, kwargs # self.log.debug(f"{self.name}: finished creating task for {command_name}() coroutine") - except Exception as e: - self.log.error(f"{self.name}: error in EngineServer worker: {e}") - self.log.trace(traceback.format_exc()) + except BaseException as e: + await self._shutdown() + if not in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError)): + self.log.error(f"{self.name}: error in EngineServer worker: {e}") + self.log.trace(traceback.format_exc()) finally: self.log.debug(f"{self.name}: finished worker()") async def _shutdown(self): - self.log.debug(f"{self.name}: shutting down...") - await self.cancel_all_tasks() - try: - self.context.destroy(linger=0) - except Exception: - self.log.trace(traceback.format_exc()) - try: - self.context.term() - except Exception: - self.log.trace(traceback.format_exc()) - self.log.debug(f"{self.name}: finished shutting down") + if not self._shutdown_status: + self.log.critical(f"{self.name}: shutting down...") + self._shutdown_status = True + await self.cancel_all_tasks() + try: + self.context.destroy(linger=0) + except Exception: + self.log.trace(traceback.format_exc()) + try: + self.context.term() + except Exception: + self.log.trace(traceback.format_exc()) + self.log.debug(f"{self.name}: finished shutting down") def new_child_task(self, client_id, coro): task = asyncio.create_task(coro) @@ -554,8 +561,9 @@ async def _cancel_task(self, task): await asyncio.wait_for(task, timeout=10) except (TimeoutError, asyncio.TimeoutError): self.log.debug(f"{self.name}: Timeout cancelling task") + return except (KeyboardInterrupt, asyncio.CancelledError): - pass + return except BaseException as e: self.log.error(f"Unhandled error in {task.get_coro().__name__}(): {e}") self.log.trace(traceback.format_exc()) diff --git a/bbot/scanner/scanner.py b/bbot/scanner/scanner.py index d90b8c329..0fe4191bf 100644 --- a/bbot/scanner/scanner.py +++ b/bbot/scanner/scanner.py @@ -353,6 +353,8 @@ async def async_start(self): events, finish = await self.modules["python"]._events_waiting(batch_size=-1) for e in events: yield e + if events: + continue # break if initialization finished and the scan is no longer active if self._finished_init and self.modules_finished: @@ -386,7 +388,7 @@ async def async_start(self): for task in tasks: # self.debug(f"Awaiting {task}") with contextlib.suppress(BaseException): - await task + await asyncio.wait_for(task, timeout=0.1) self.debug(f"Awaited {len(tasks):,} tasks") await self._report() await self._cleanup()