Skip to content

Commit

Permalink
ctrl+c improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
TheTechromancer committed Jul 28, 2024
1 parent 8e45cb3 commit 2c04d80
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 34 deletions.
74 changes: 41 additions & 33 deletions bbot/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from bbot.core import CORE
from bbot.errors import BBOTEngineError
from bbot.core.helpers.misc import rand_string
from bbot.core.helpers.async_helpers import get_event_loop
from bbot.core.helpers.misc import rand_string, in_exception_chain


error_sentinel = object()
Expand All @@ -41,6 +41,7 @@ class EngineBase:
ERROR_CLASS = BBOTEngineError

def __init__(self):
self._shutdown_status = False
self.log = logging.getLogger(f"bbot.core.{self.__class__.__name__.lower()}")

def pickle(self, obj):
Expand All @@ -62,7 +63,7 @@ def unpickle(self, binary):

async def _infinite_retry(self, callback, *args, **kwargs):
interval = kwargs.pop("_interval", 10)
while 1:
while not self._shutdown_status:
try:
return await asyncio.wait_for(callback(*args, **kwargs), timeout=interval)
except (TimeoutError, asyncio.TimeoutError):
Expand Down Expand Up @@ -107,7 +108,6 @@ class EngineClient(EngineBase):
SERVER_CLASS = None

def __init__(self, **kwargs):
self._shutdown = False
super().__init__()
self.name = f"EngineClient {self.__class__.__name__}"
self.process = None
Expand Down Expand Up @@ -135,7 +135,7 @@ def check_error(self, message):
async def run_and_return(self, command, *args, **kwargs):
fn_str = f"{command}({args}, {kwargs})"
self.log.debug(f"{self.name}: executing run-and-return {fn_str}")
if self._shutdown and not command == "_shutdown":
if self._shutdown_status and not command == "_shutdown":
self.log.verbose(f"{self.name} has been shut down and is not accepting new tasks")
return
async with self.new_socket() as socket:
Expand Down Expand Up @@ -163,7 +163,7 @@ async def run_and_return(self, command, *args, **kwargs):
async def run_and_yield(self, command, *args, **kwargs):
fn_str = f"{command}({args}, {kwargs})"
self.log.debug(f"{self.name}: executing run-and-yield {fn_str}")
if self._shutdown:
if self._shutdown_status:
self.log.verbose("Engine has been shut down and is not accepting new tasks")
return
message = self.make_message(command, args=args, kwargs=kwargs)
Expand Down Expand Up @@ -213,14 +213,16 @@ async def send_shutdown_message(self):
async with self.new_socket() as socket:
# -99 == special shutdown message
message = pickle.dumps({"c": -99})
await self._infinite_retry(socket.send, message)
while 1:
response = await self._infinite_retry(socket.recv)
response = pickle.loads(response)
if isinstance(response, dict):
response = response.get("m", "")
if response == "SHUTDOWN_OK":
break
with suppress(TimeoutError, asyncio.TimeoutError):
await asyncio.wait_for(socket.send(message), 0.5)
with suppress(TimeoutError, asyncio.TimeoutError):
while 1:
response = await asyncio.wait_for(socket.recv(), 0.5)
response = pickle.loads(response)
if isinstance(response, dict):
response = response.get("m", "")
if response == "SHUTDOWN_OK":
break

def check_stop(self, message):
if isinstance(message, dict) and len(message) == 1 and "_s" in message:
Expand Down Expand Up @@ -280,7 +282,7 @@ def server_process(server_class, socket_path, **kwargs):
else:
asyncio.run(engine_server.worker())
except (asyncio.CancelledError, KeyboardInterrupt, CancelledError):
pass
return
except Exception:
import traceback

Expand All @@ -306,9 +308,9 @@ async def new_socket(self):
socket.close()

async def shutdown(self):
self.log.debug(f"{self.name}: shutting down...")
if not self._shutdown:
self._shutdown = True
if not self._shutdown_status:
self._shutdown_status = True
self.log.hugewarning(f"{self.name}: shutting down...")
# send shutdown signal
await self.send_shutdown_message()
# then terminate context
Expand Down Expand Up @@ -446,6 +448,7 @@ def check_error(self, message):
return True

async def worker(self):
self.log.debug(f"{self.name}: starting worker")
try:
while 1:
client_id, binary = await self.socket.recv_multipart()
Expand All @@ -462,8 +465,8 @@ async def worker(self):
# -1 == cancel task
if cmd == -1:
self.log.debug(f"{self.name} got cancel signal")
await self.cancel_task(client_id)
await self.send_socket_multipart(client_id, {"m": "CANCEL_OK"})
await self.cancel_task(client_id)
continue

# -99 == shutdown task
Expand Down Expand Up @@ -500,24 +503,28 @@ async def worker(self):
task = asyncio.create_task(coroutine)
self.tasks[client_id] = task, command_fn, args, kwargs
# self.log.debug(f"{self.name}: finished creating task for {command_name}() coroutine")
except Exception as e:
self.log.error(f"{self.name}: error in EngineServer worker: {e}")
self.log.trace(traceback.format_exc())
except BaseException as e:
await self._shutdown()
if not in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError)):
self.log.error(f"{self.name}: error in EngineServer worker: {e}")
self.log.trace(traceback.format_exc())
finally:
self.log.debug(f"{self.name}: finished worker()")

async def _shutdown(self):
self.log.debug(f"{self.name}: shutting down...")
await self.cancel_all_tasks()
try:
self.context.destroy(linger=0)
except Exception:
self.log.trace(traceback.format_exc())
try:
self.context.term()
except Exception:
self.log.trace(traceback.format_exc())
self.log.debug(f"{self.name}: finished shutting down")
if not self._shutdown_status:
self.log.critical(f"{self.name}: shutting down...")
self._shutdown_status = True
await self.cancel_all_tasks()
try:
self.context.destroy(linger=0)
except Exception:
self.log.trace(traceback.format_exc())
try:
self.context.term()
except Exception:
self.log.trace(traceback.format_exc())
self.log.debug(f"{self.name}: finished shutting down")

def new_child_task(self, client_id, coro):
task = asyncio.create_task(coro)
Expand Down Expand Up @@ -554,8 +561,9 @@ async def _cancel_task(self, task):
await asyncio.wait_for(task, timeout=10)
except (TimeoutError, asyncio.TimeoutError):
self.log.debug(f"{self.name}: Timeout cancelling task")
return
except (KeyboardInterrupt, asyncio.CancelledError):
pass
return
except BaseException as e:
self.log.error(f"Unhandled error in {task.get_coro().__name__}(): {e}")
self.log.trace(traceback.format_exc())
Expand Down
4 changes: 3 additions & 1 deletion bbot/scanner/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,8 @@ async def async_start(self):
events, finish = await self.modules["python"]._events_waiting(batch_size=-1)
for e in events:
yield e
if events:
continue

# break if initialization finished and the scan is no longer active
if self._finished_init and self.modules_finished:
Expand Down Expand Up @@ -386,7 +388,7 @@ async def async_start(self):
for task in tasks:
# self.debug(f"Awaiting {task}")
with contextlib.suppress(BaseException):
await task
await asyncio.wait_for(task, timeout=0.1)
self.debug(f"Awaited {len(tasks):,} tasks")
await self._report()
await self._cleanup()
Expand Down

0 comments on commit 2c04d80

Please sign in to comment.