Skip to content

Commit

Permalink
delete some code
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions committed Aug 23, 2024
1 parent cf34725 commit 49c3c7b
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 133 deletions.
129 changes: 93 additions & 36 deletions bbot/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,49 +551,106 @@ async def _shutdown(self):
self.log.verbose(f"{self.name}: shutting down...")
self._shutdown_status = True
await self.cancel_all_tasks()
try:
self.context.destroy(linger=0)
except Exception:
self.log.trace(traceback.format_exc())
try:
self.context.term()
except Exception:
self.log.trace(traceback.format_exc())
context = getattr(self, "context", None)
if context is not None:
try:
context.destroy(linger=0)
except Exception:
self.log.trace(traceback.format_exc())
try:
context.term()
except Exception:
self.log.trace(traceback.format_exc())
self.log.verbose(f"{self.name}: finished shutting down")

def new_child_task(self, client_id, coro):
async def task_pool(self, fn, args_kwargs, threads=10, timeout=300, global_kwargs=None):
    """
    Run fn(*args, **kwargs, **global_kwargs) over a batch of work items with bounded concurrency.

    Args:
        fn: async callable invoked once per work item
        args_kwargs: iterable of work items; each item may be a bare argument, an
            (args, kwargs) pair, or an (args, kwargs, tracker) triple (the optional
            tracker value helps correlate results with inputs)
        threads: maximum number of tasks in flight at once
        timeout: timeout (seconds) passed to each finished_tasks() wait
        global_kwargs: keyword arguments applied to every call

    Yields:
        ((args, kwargs, tracker), result) tuples, in completion order (not input order).
    """
    if global_kwargs is None:
        global_kwargs = {}

    tasks = {}
    args_kwargs = list(args_kwargs)

    def new_task():
        if args_kwargs:
            kwargs = {}
            tracker = None
            args = args_kwargs.pop(0)
            if isinstance(args, (list, tuple)):
                # Unpack by length rather than by chained suppress(ValueError) attempts:
                # with the suppress() approach, after a successful
                # (args, kwargs, tracker) 3-way unpack, the follow-up 2-way unpack
                # would wrongly re-unpack a 2-element args sequence into args/kwargs.
                if len(args) == 3:
                    # item carries a custom tracker value
                    args, kwargs, tracker = args
                elif len(args) == 2:
                    # item is a plain (args, kwargs) pair
                    args, kwargs = args

            if not isinstance(kwargs, dict):
                raise ValueError(f"kwargs must be dict (got: {kwargs})")
            if not isinstance(args, (list, tuple)):
                # a bare argument becomes a single positional arg
                args = [args]

            task = self.new_child_task(fn(*args, **kwargs, **global_kwargs))
            tasks[task] = (args, kwargs, tracker)

    for _ in range(threads):  # start the initial batch of tasks
        new_task()

    while tasks:  # while there are tasks pending
        # wait for the first task(s) to complete, then backfill the pool
        finished = await self.finished_tasks(tasks, timeout=timeout)
        for task in finished:
            result = task.result()
            (args, kwargs, tracker) = tasks.pop(task)
            yield (args, kwargs, tracker), result
            new_task()

def new_child_task(self, coro):
    """
    Create a new asyncio task, making sure to track it based on the client id.

    This allows the task to be automatically cancelled if its parent is cancelled.
    If there is no client id in the current context, the task is created but not
    tracked (and therefore not subject to parent cancellation).

    Args:
        coro: the coroutine to schedule

    Returns:
        asyncio.Task: the newly created task
    """
    client_id = self.client_id_var.get()
    task = asyncio.create_task(coro)

    if client_id:

        def remove_task(t):
            # prune the finished task; drop the whole client entry once empty
            tasks = self.child_tasks.get(client_id, set())
            tasks.discard(t)
            if not tasks:
                self.child_tasks.pop(client_id, None)

        task.add_done_callback(remove_task)

        try:
            self.child_tasks[client_id].add(task)
        except KeyError:
            self.child_tasks[client_id] = {task}

    return task

async def finished_tasks(self, client_id, timeout=None):
    """
    Wait for at least one of this client's child tasks to complete, then return the
    completed set and remember the still-pending ones.

    Args:
        client_id: key into self.child_tasks identifying whose tasks to wait on
        timeout: optional timeout (seconds) passed to asyncio.wait()

    Returns:
        set: the tasks that completed during this wait
    """
    child_tasks = self.child_tasks.get(client_id, set())
    # NOTE(review): if child_tasks is empty, asyncio.wait() raises ValueError — callers
    # appear to guarantee at least one task; confirm.
    try:
        done, pending = await asyncio.wait(child_tasks, return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
    except BaseException as e:
        if isinstance(e, (TimeoutError, asyncio.exceptions.TimeoutError)):
            # NOTE(review): asyncio.wait() does not raise TimeoutError on timeout — it
            # returns (done, pending) — so this branch looks unreachable for plain
            # timeouts. If it ever did run, `pending` would be unbound at the
            # assignment below (NameError).
            done = set()
            self.log.warning(f"{self.name}: Timeout after {timeout:,} seconds in finished_tasks({child_tasks})")
            for task in child_tasks:
                task.cancel()
        else:
            # unexpected failure: log it unless it's part of a normal
            # shutdown/cancellation chain, then re-raise either way
            if not in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError)):
                self.log.error(f"{self.name}: Unhandled exception in finished_tasks({child_tasks}): {e}")
                self.log.trace(traceback.format_exc())
            raise
    # remember the still-running tasks for the next call
    self.child_tasks[client_id] = pending
    return done
async def finished_tasks(self, tasks, timeout=None):
    """
    Given a list of asyncio tasks, return the ones that are finished, with an optional timeout.

    On timeout, the remaining (pending) tasks are cancelled, a warning is logged, and an
    empty set is returned. Note: asyncio.wait() never raises TimeoutError — it signals a
    timeout by returning an empty "done" set — so we detect that case explicitly (the
    previous except-TimeoutError branch was dead code and the warn/cancel never fired).

    Args:
        tasks: iterable of asyncio tasks to wait on (may be empty)
        timeout: optional timeout in seconds

    Returns:
        set: the tasks that completed during this wait (empty on timeout or empty input)
    """
    if tasks:
        try:
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
            if timeout is not None and not done and pending:
                # nothing finished within the timeout window
                self.log.warning(f"{self.name}: Timeout after {timeout:,} seconds in finished_tasks({tasks})")
                for task in pending:
                    task.cancel()
            return done
        except BaseException as e:
            # unexpected failure: log it unless it's part of a normal
            # shutdown/cancellation chain, then re-raise either way
            if not in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError)):
                self.log.error(f"{self.name}: Unhandled exception in finished_tasks({tasks}): {e}")
                self.log.trace(traceback.format_exc())
            raise
    return set()

async def cancel_task(self, client_id):
parent_task = self.tasks.pop(client_id, None)
Expand Down
61 changes: 12 additions & 49 deletions bbot/core/helpers/dns/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,57 +349,20 @@ async def resolve_batch(self, queries, threads=10, **kwargs):
('www.evilcorp.com', {'1.1.1.1'})
('evilcorp.com', {'2.2.2.2'})
"""
tasks = {}
client_id = self.client_id_var.get()

def new_task(query):
task = self.new_child_task(client_id, self.resolve(query, **kwargs))
tasks[task] = query

queries = list(queries)
for _ in range(threads): # Start initial batch of tasks
if queries: # Ensure there are args to process
new_task(queries.pop(0))

while tasks: # While there are tasks pending
# Wait for the first task to complete
finished = await self.finished_tasks(client_id, timeout=120)

for task in finished:
results = task.result()
query = tasks.pop(task)

if results:
yield (query, results)

if queries: # Start a new task for each one completed, if URLs remain
new_task(queries.pop(0))
async for (args, _, _), responses in self.task_pool(
self.resolve, args_kwargs=queries, threads=threads, global_kwargs=kwargs
):
yield args[0], responses

async def resolve_raw_batch(self, queries, threads=10, **kwargs):
    """
    Resolve a batch of (query, rdtype) pairs concurrently via task_pool.

    Args:
        queries: iterable of (query, rdtype) pairs
        threads: maximum number of concurrent resolutions
        **kwargs: passed through to every resolve_raw() call

    Yields:
        ((query, rdtype), (answer, errors)) for each answer of each query
    """
    # repack each (query, rdtype) pair into task_pool's (args, kwargs) item format
    queries_kwargs = [[q[0], {"type": q[1]}] for q in queries]
    # per-task kwargs are named task_kwargs so they don't shadow this function's **kwargs
    async for (args, task_kwargs, _), (answers, errors) in self.task_pool(
        self.resolve_raw, args_kwargs=queries_kwargs, threads=threads, global_kwargs=kwargs
    ):
        query = args[0]
        rdtype = task_kwargs["type"]
        for answer in answers:
            yield ((query, rdtype), (answer, errors))

async def _catch(self, callback, *args, **kwargs):
"""
Expand Down
59 changes: 11 additions & 48 deletions bbot/core/helpers/web/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,54 +92,17 @@ async def request(self, *args, **kwargs):
)
return response

async def request_batch(self, urls, *args, threads=10, **kwargs):
    """
    Fetch a batch of URLs concurrently, capped at `threads` requests in flight.

    Yields (url, response) tuples in completion order (not input order).
    """
    client_id = self.client_id_var.get()
    remaining = list(urls)
    in_flight = {}

    def _spawn_one():
        # submit the next queued URL, if any remain
        if not remaining:
            return
        next_url = remaining.pop(0)
        child = self.new_child_task(client_id, self.request(next_url, *args, **kwargs))
        in_flight[child] = next_url

    # prime the pool with up to `threads` requests
    for _ in range(threads):
        _spawn_one()

    # drain: every completion frees a slot for the next URL
    while in_flight:
        for child in await self.finished_tasks(client_id, timeout=120):
            completed_url = in_flight.pop(child)
            yield (completed_url, child.result())
            _spawn_one()

async def request_custom_batch(self, urls_and_kwargs, threads=10):
    """
    Fire off a batch of requests, each with its own per-request kwargs and an
    opaque tracker value, keeping at most `threads` requests in flight.

    Yields (url, kwargs, tracker, response) tuples in completion order.
    """
    client_id = self.client_id_var.get()
    work = list(urls_and_kwargs)
    in_flight = {}
    next_idx = 0

    def _launch_next():
        nonlocal next_idx
        # submit the next work item, if any remain
        if next_idx < len(work):
            url, req_kwargs, tracker = work[next_idx]
            next_idx += 1
            child = self.new_child_task(client_id, self.request(url, **req_kwargs))
            in_flight[child] = (url, req_kwargs, tracker)

    # prime the pool with up to `threads` requests
    for _ in range(threads):
        _launch_next()

    while in_flight:
        done, _ = await asyncio.wait(in_flight, return_when=asyncio.FIRST_COMPLETED)
        for child in done:
            entry = in_flight.pop(child)
            yield entry + (child.result(),)
            _launch_next()
async def request_batch(self, urls, threads=10, **kwargs):
    """
    Request a batch of URLs concurrently via task_pool.

    Yields (url, response) tuples as each request finishes.
    """
    pool = self.task_pool(self.request, args_kwargs=urls, threads=threads, global_kwargs=kwargs)
    async for (task_args, _, _), response in pool:
        yield task_args[0], response

async def request_custom_batch(self, urls_and_kwargs, threads=10, **kwargs):
    """
    Request a batch of URLs, each with its own kwargs and tracker value, via task_pool.

    Yields (url, kwargs, tracker, response) tuples as each request finishes.
    """
    pool = self.task_pool(self.request, args_kwargs=urls_and_kwargs, threads=threads, global_kwargs=kwargs)
    async for (task_args, task_kwargs, tracker), response in pool:
        yield task_args[0], task_kwargs, tracker, response

async def download(self, url, **kwargs):
warn = kwargs.pop("warn", True)
Expand Down

0 comments on commit 49c3c7b

Please sign in to comment.