diff --git a/bbot/core/engine.py b/bbot/core/engine.py index fdb4f06e1a..ccd3493c78 100644 --- a/bbot/core/engine.py +++ b/bbot/core/engine.py @@ -551,49 +551,106 @@ async def _shutdown(self): self.log.verbose(f"{self.name}: shutting down...") self._shutdown_status = True await self.cancel_all_tasks() - try: - self.context.destroy(linger=0) - except Exception: - self.log.trace(traceback.format_exc()) - try: - self.context.term() - except Exception: - self.log.trace(traceback.format_exc()) + context = getattr(self, "context", None) + if context is not None: + try: + context.destroy(linger=0) + except Exception: + self.log.trace(traceback.format_exc()) + try: + context.term() + except Exception: + self.log.trace(traceback.format_exc()) self.log.verbose(f"{self.name}: finished shutting down") - def new_child_task(self, client_id, coro): + async def task_pool(self, fn, args_kwargs, threads=10, timeout=300, global_kwargs=None): + if global_kwargs is None: + global_kwargs = {} + + tasks = {} + args_kwargs = list(args_kwargs) + + def new_task(): + if args_kwargs: + kwargs = {} + tracker = None + args = args_kwargs.pop(0) + if isinstance(args, (list, tuple)): + # you can specify a custom tracker value if you want + # this helps with correlating results + with suppress(ValueError): + args, kwargs, tracker = args + # or you can just specify args/kwargs + with suppress(ValueError): + args, kwargs = args + + if not isinstance(kwargs, dict): + raise ValueError(f"kwargs must be dict (got: {kwargs})") + if not isinstance(args, (list, tuple)): + args = [args] + + task = self.new_child_task(fn(*args, **kwargs, **global_kwargs)) + tasks[task] = (args, kwargs, tracker) + + for _ in range(threads): # Start initial batch of tasks + new_task() + + while tasks: # While there are tasks pending + # Wait for the first task to complete + finished = await self.finished_tasks(tasks, timeout=timeout) + for task in finished: + result = task.result() + (args, kwargs, tracker) = tasks.pop(task) + yield (args, kwargs, tracker), result + new_task() + + def new_child_task(self, coro): + """ + Create a new asyncio task, making sure to track it based on the client id. + + This allows the task to be automatically cancelled if its parent is cancelled. + """ + client_id = self.client_id_var.get() task = asyncio.create_task(coro) - def remove_task(t): - tasks = self.child_tasks.get(client_id, set()) - tasks.discard(t) - if not tasks: - self.child_tasks.pop(client_id, None) + if client_id: + + def remove_task(t): + tasks = self.child_tasks.get(client_id, set()) + tasks.discard(t) + if not tasks: + self.child_tasks.pop(client_id, None) + + task.add_done_callback(remove_task) + + try: + self.child_tasks[client_id].add(task) + except KeyError: + self.child_tasks[client_id] = {task} - task.add_done_callback(remove_task) - try: - self.child_tasks[client_id].add(task) - except KeyError: - self.child_tasks[client_id] = {task} return task - async def finished_tasks(self, client_id, timeout=None): - child_tasks = self.child_tasks.get(client_id, set()) - try: - done, pending = await asyncio.wait(child_tasks, return_when=asyncio.FIRST_COMPLETED, timeout=timeout) - except BaseException as e: - if isinstance(e, (TimeoutError, asyncio.exceptions.TimeoutError)): - done = set() - self.log.warning(f"{self.name}: Timeout after {timeout:,} seconds in finished_tasks({child_tasks})") - for task in child_tasks: - task.cancel() - else: - if not in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError)): - self.log.error(f"{self.name}: Unhandled exception in finished_tasks({child_tasks}): {e}") - self.log.trace(traceback.format_exc()) - raise - self.child_tasks[client_id] = pending - return done + async def finished_tasks(self, tasks, timeout=None): + """ + Given a list of asyncio tasks, return the ones that are finished with an optional timeout + """ + if tasks: + try: + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED, timeout=timeout) + return done + except BaseException as e: + if isinstance(e, (TimeoutError, asyncio.exceptions.TimeoutError)): + self.log.warning( + f"{self.name}: Timeout after {timeout:,} seconds in finished_tasks({tasks})" + ) + for task in tasks: + task.cancel() + else: + if not in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError)): + self.log.error(f"{self.name}: Unhandled exception in finished_tasks({tasks}): {e}") + self.log.trace(traceback.format_exc()) + raise + return set() async def cancel_task(self, client_id): parent_task = self.tasks.pop(client_id, None) diff --git a/bbot/core/helpers/dns/engine.py b/bbot/core/helpers/dns/engine.py index d24c1f766b..8a41c7c8ea 100644 --- a/bbot/core/helpers/dns/engine.py +++ b/bbot/core/helpers/dns/engine.py @@ -349,57 +349,20 @@ async def resolve_batch(self, queries, threads=10, **kwargs): ('www.evilcorp.com', {'1.1.1.1'}) ('evilcorp.com', {'2.2.2.2'}) """ - tasks = {} - client_id = self.client_id_var.get() - - def new_task(query): - task = self.new_child_task(client_id, self.resolve(query, **kwargs)) - tasks[task] = query - - queries = list(queries) - for _ in range(threads): # Start initial batch of tasks - if queries: # Ensure there are args to process - new_task(queries.pop(0)) - - while tasks: # While there are tasks pending - # Wait for the first task to complete - finished = await self.finished_tasks(client_id, timeout=120) - - for task in finished: - results = task.result() - query = tasks.pop(task) - - if results: - yield (query, results) - - if queries: # Start a new task for each one completed, if URLs remain - new_task(queries.pop(0)) + async for (args, _, _), responses in self.task_pool( + self.resolve, args_kwargs=queries, threads=threads, global_kwargs=kwargs + ): + yield args[0], responses async def resolve_raw_batch(self, queries, threads=10, **kwargs): - tasks = {} - client_id = self.client_id_var.get() - - def new_task(query, rdtype): - task = self.new_child_task(client_id, self.resolve_raw(query, type=rdtype, **kwargs)) - tasks[task] = (query, rdtype) - - queries = list(queries) - for _ in range(threads): # Start initial batch of tasks - if queries: # Ensure there are args to process - new_task(*queries.pop(0)) - - while tasks: # While there are tasks pending - # Wait for the first task to complete - finished = await self.finished_tasks(client_id, timeout=120) - - for task in finished: - answers, errors = task.result() - query, rdtype = tasks.pop(task) - for answer in answers: - yield ((query, rdtype), (answer, errors)) - - if queries: # Start a new task for each one completed, if URLs remain - new_task(*queries.pop(0)) + queries_kwargs = [[q[0], {"type": q[1]}] for q in queries] + async for (args, kwargs, _), (answers, errors) in self.task_pool( + self.resolve_raw, args_kwargs=queries_kwargs, threads=threads, global_kwargs=kwargs + ): + query = args[0] + rdtype = kwargs["type"] + for answer in answers: + yield ((query, rdtype), (answer, errors)) async def _catch(self, callback, *args, **kwargs): """ diff --git a/bbot/core/helpers/web/engine.py b/bbot/core/helpers/web/engine.py index 2311378999..7ec79e925a 100644 --- a/bbot/core/helpers/web/engine.py +++ b/bbot/core/helpers/web/engine.py @@ -92,54 +92,17 @@ async def request(self, *args, **kwargs): ) return response - async def request_batch(self, urls, *args, threads=10, **kwargs): - tasks = {} - client_id = self.client_id_var.get() - - urls = list(urls) - - def new_task(): - if urls: - url = urls.pop(0) - task = self.new_child_task(client_id, self.request(url, *args, **kwargs)) - tasks[task] = url - - for _ in range(threads): # Start initial batch of tasks - new_task() - - while tasks: # While there are tasks pending - # Wait for the first task to complete - finished = await self.finished_tasks(client_id, timeout=120) - - for task in finished: - response = task.result() - url = tasks.pop(task) - yield (url, response) - new_task() - - async def request_custom_batch(self, urls_and_kwargs, threads=10): - tasks = {} - client_id = self.client_id_var.get() - urls_and_kwargs = list(urls_and_kwargs) - - def new_task(): - if urls_and_kwargs: # Ensure there are args to process - url, kwargs, custom_tracker = urls_and_kwargs.pop(0) - task = self.new_child_task(client_id, self.request(url, **kwargs)) - tasks[task] = (url, kwargs, custom_tracker) - - for _ in range(threads): # Start initial batch of tasks - new_task() - - while tasks: # While there are tasks pending - # Wait for the first task to complete - done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - - for task in done: - response = task.result() - url, kwargs, custom_tracker = tasks.pop(task) - yield (url, kwargs, custom_tracker, response) - new_task() + async def request_batch(self, urls, threads=10, **kwargs): + async for (args, _, _), response in self.task_pool( + self.request, args_kwargs=urls, threads=threads, global_kwargs=kwargs + ): + yield args[0], response + + async def request_custom_batch(self, urls_and_kwargs, threads=10, **kwargs): + async for (args, kwargs, tracker), response in self.task_pool( + self.request, args_kwargs=urls_and_kwargs, threads=threads, global_kwargs=kwargs + ): + yield args[0], kwargs, tracker, response async def download(self, url, **kwargs): warn = kwargs.pop("warn", True)