Merge pull request #1340 from blacklanternsecurity/http-engine
HTTP Engine - Offload Web Requests to Dedicated Process
TheTechromancer authored May 16, 2024
2 parents 01bce76 + 41b8cdc commit d73198c
Showing 40 changed files with 1,333 additions and 675 deletions.
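At the heart of the change is the generic client/server engine in bbot/core/engine.py: an EngineClient in the scan process pickles commands onto a ZMQ IPC socket, and a dedicated server process runs them and streams results back. A rough sketch of the wire protocol implied by the diff below (the command id and arguments are illustrative, not an actual BBOT command):

import pickle

# Client -> server: command id plus optional positional and keyword arguments
# (the "a"/"k" keys introduced in this PR). Command id 0 and the arguments
# below are made up for illustration.
request = pickle.dumps({"c": 0, "a": ("example.com",), "k": {"type": "A"}})

# Client -> server: the special id -1 asks the server to cancel this client's task.
cancel = pickle.dumps({"c": -1})

# Server -> client: errors and end-of-stream are single-key sentinel dicts;
# everything else is just the pickled return value.
error_reply = pickle.dumps({"_e": ("error message", "traceback text")})
stop_reply = pickle.dumps({"_s": None})

for name, blob in [("request", request), ("cancel", cancel), ("error", error_reply), ("stop", stop_reply)]:
    print(name, pickle.loads(blob))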
2 changes: 0 additions & 2 deletions bbot/cli.py
@@ -44,8 +44,6 @@ async def _main():

# start by creating a default scan preset
preset = Preset(_log=True, name="bbot_cli_main")
# populate preset symlinks
preset.all_presets
# parse command line arguments and merge into preset
try:
preset.parse_args()
16 changes: 8 additions & 8 deletions bbot/core/config/logger.py
@@ -172,6 +172,13 @@ def include_logger(self, logger):
for handler in self.log_handlers.values():
self.add_log_handler(handler)

def stderr_filter(self, record):
if record.levelno == logging.TRACE and self.log_level > logging.DEBUG:
return False
if record.levelno < self.log_level:
return False
return True

@property
def log_handlers(self):
if self._log_handlers is None:
@@ -189,16 +196,9 @@ def log_handlers(self):
f"{log_dir}/bbot.debug.log", when="d", interval=1, backupCount=14
)

def stderr_filter(record):
if record.levelno == logging.TRACE and self.log_level > logging.DEBUG:
return False
if record.levelno < self.log_level:
return False
return True

# Log to stderr
stderr_handler = logging.StreamHandler(sys.stderr)
stderr_handler.addFilter(stderr_filter)
stderr_handler.addFilter(self.stderr_filter)
# log to files
debug_handler.addFilter(lambda x: x.levelno == logging.TRACE or (x.levelno < logging.VERBOSE))
main_handler.addFilter(lambda x: x.levelno != logging.TRACE and x.levelno >= logging.VERBOSE)
12 changes: 10 additions & 2 deletions bbot/core/core.py
@@ -1,3 +1,4 @@
import os
import logging
from copy import copy
from pathlib import Path
@@ -143,9 +144,16 @@ def files_config(self):
return self._files_config

def create_process(self, *args, **kwargs):
from .helpers.process import BBOTProcess
if os.environ.get("BBOT_TESTING", "") == "True":
import threading

process = BBOTProcess(*args, **kwargs)
kwargs.pop("custom_name", None)
process = threading.Thread(*args, **kwargs)
else:
from .helpers.process import BBOTProcess

process = BBOTProcess(*args, **kwargs)
process.daemon = True
return process

@property
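The create_process change above means that under the test suite (BBOT_TESTING=True) workers run as daemon threads instead of separate processes, while normal scans still get a real BBOTProcess. A minimal standalone sketch of the same toggle, using the stock multiprocessing.Process in place of BBOTProcess:

import os
import threading
import multiprocessing

def create_worker(*args, **kwargs):
    """Return a daemonized thread when testing, otherwise a daemonized process."""
    if os.environ.get("BBOT_TESTING", "") == "True":
        kwargs.pop("custom_name", None)  # threading.Thread() doesn't accept this kwarg
        worker = threading.Thread(*args, **kwargs)
    else:
        worker = multiprocessing.Process(*args, **kwargs)
    worker.daemon = True
    return worker

if __name__ == "__main__":
    w = create_worker(target=print, args=("hello from the worker",))
    w.start()
    w.join()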
210 changes: 144 additions & 66 deletions bbot/core/engine.py
@@ -7,10 +7,12 @@
import tempfile
import traceback
import zmq.asyncio
import multiprocessing
from pathlib import Path
from contextlib import asynccontextmanager, suppress

from bbot.core import CORE
from bbot.errors import BBOTEngineError
from bbot.core.helpers.misc import rand_string

CMD_EXIT = 1000
@@ -20,6 +22,9 @@


class EngineBase:

ERROR_CLASS = BBOTEngineError

def __init__(self):
self.log = logging.getLogger(f"bbot.core.{self.__class__.__name__.lower()}")

@@ -40,16 +45,6 @@ def unpickle(self, binary):
self.log.trace(traceback.format_exc())
return error_sentinel

def check_error(self, message):
if message is error_sentinel:
return True
if isinstance(message, dict) and len(message) == 1 and "_e" in message:
error, trace = message["_e"]
self.log.error(error)
self.log.trace(trace)
return True
return False


class EngineClient(EngineBase):

@@ -58,6 +53,8 @@ class EngineClient(EngineBase):
def __init__(self, **kwargs):
super().__init__()
self.name = f"EngineClient {self.__class__.__name__}"
self.process = None
self.process_name = multiprocessing.current_process().name
if self.SERVER_CLASS is None:
raise ValueError(f"Must set EngineClient SERVER_CLASS, {self.SERVER_CLASS}")
self.CMDS = dict(self.SERVER_CLASS.CMDS)
@@ -70,13 +67,27 @@ def __init__(self, **kwargs):
self.context = zmq.asyncio.Context()
atexit.register(self.cleanup)

async def run_and_return(self, command, **kwargs):
def check_error(self, message):
if isinstance(message, dict) and len(message) == 1 and "_e" in message:
error, trace = message["_e"]
error = self.ERROR_CLASS(error)
error.engine_traceback = trace
raise error
return False

async def run_and_return(self, command, *args, **kwargs):
async with self.new_socket() as socket:
message = self.make_message(command, args=kwargs)
if message is error_sentinel:
return
await socket.send(message)
binary = await socket.recv()
try:
message = self.make_message(command, args=args, kwargs=kwargs)
if message is error_sentinel:
return
await socket.send(message)
binary = await socket.recv()
except BaseException:
# -1 == special "cancel" signal
cancel_message = pickle.dumps({"c": -1})
await socket.send(cancel_message)
raise
# self.log.debug(f"{self.name}.{command}({kwargs}) got binary: {binary}")
message = self.unpickle(binary)
self.log.debug(f"{self.name}.{command}({kwargs}) got message: {message}")
@@ -85,50 +96,64 @@ async def run_and_return(self, command, **kwargs):
return
return message

async def run_and_yield(self, command, **kwargs):
message = self.make_message(command, args=kwargs)
async def run_and_yield(self, command, *args, **kwargs):
message = self.make_message(command, args=args, kwargs=kwargs)
if message is error_sentinel:
return
async with self.new_socket() as socket:
await socket.send(message)
while 1:
binary = await socket.recv()
# self.log.debug(f"{self.name}.{command}({kwargs}) got binary: {binary}")
message = self.unpickle(binary)
self.log.debug(f"{self.name}.{command}({kwargs}) got message: {message}")
# error handling
if self.check_error(message) or self.check_stop(message):
break
yield message
try:
binary = await socket.recv()
# self.log.debug(f"{self.name}.{command}({kwargs}) got binary: {binary}")
message = self.unpickle(binary)
self.log.debug(f"{self.name}.{command}({kwargs}) got message: {message}")
# error handling
if self.check_error(message) or self.check_stop(message):
break
yield message
except GeneratorExit:
# -1 == special "cancel" signal
cancel_message = pickle.dumps({"c": -1})
await socket.send(cancel_message)
raise

def check_stop(self, message):
if isinstance(message, dict) and len(message) == 1 and "_s" in message:
return True
return False

def make_message(self, command, args):
def make_message(self, command, args=None, kwargs=None):
try:
cmd_id = self.CMDS[command]
except KeyError:
raise KeyError(f'Command "{command}" not found. Available commands: {",".join(self.available_commands)}')
return pickle.dumps(dict(c=cmd_id, a=args))
message = {"c": cmd_id}
if args:
message["a"] = args
if kwargs:
message["k"] = kwargs
return pickle.dumps(message)

@property
def available_commands(self):
return [s for s in self.CMDS if isinstance(s, str)]

def start_server(self):
process = CORE.create_process(
target=self.server_process,
args=(
self.SERVER_CLASS,
self.socket_path,
),
kwargs=self.server_kwargs,
custom_name="bbot dnshelper",
)
process.start()
return process
if self.process_name == "MainProcess":
self.process = CORE.create_process(
target=self.server_process,
args=(
self.SERVER_CLASS,
self.socket_path,
),
kwargs=self.server_kwargs,
custom_name="bbot dnshelper",
)
self.process.start()
return self.process
else:
raise BBOTEngineError(f"Tried to start server from process {self.process_name}")

@staticmethod
def server_process(server_class, socket_path, **kwargs):
@@ -176,36 +201,67 @@ def __init__(self, socket_path):
self.socket = self.context.socket(zmq.ROUTER)
# create socket file
self.socket.bind(f"ipc://{socket_path}")
# task <--> client id mapping
self.tasks = dict()

async def run_and_return(self, client_id, command_fn, **kwargs):
self.log.debug(f"{self.name} run-and-return {command_fn.__name__}({kwargs})")
async def run_and_return(self, client_id, command_fn, *args, **kwargs):
try:
result = await command_fn(**kwargs)
except Exception as e:
error = f"Unhandled error in {self.name}.{command_fn.__name__}({kwargs}): {e}"
trace = traceback.format_exc()
result = {"_e": (error, trace)}
await self.send_socket_multipart([client_id, pickle.dumps(result)])

async def run_and_yield(self, client_id, command_fn, **kwargs):
self.log.debug(f"{self.name} run-and-yield {command_fn.__name__}({kwargs})")
self.log.debug(f"{self.name} run-and-return {command_fn.__name__}({args}, {kwargs})")
try:
result = await command_fn(*args, **kwargs)
except (asyncio.CancelledError, KeyboardInterrupt):
return
except BaseException as e:
error = f"Error in {self.name}.{command_fn.__name__}({args}, {kwargs}): {e}"
trace = traceback.format_exc()
self.log.error(error)
self.log.trace(trace)
result = {"_e": (error, trace)}
finally:
self.tasks.pop(client_id, None)
await self.send_socket_multipart(client_id, result)
except BaseException as e:
self.log.critical(
f"Unhandled exception in {self.name}.run_and_return({client_id}, {command_fn}, {args}, {kwargs}): {e}"
)
self.log.critical(traceback.format_exc())

async def run_and_yield(self, client_id, command_fn, *args, **kwargs):
try:
async for _ in command_fn(**kwargs):
await self.send_socket_multipart([client_id, pickle.dumps(_)])
await self.send_socket_multipart([client_id, pickle.dumps({"_s": None})])
except Exception as e:
error = f"Unhandled error in {self.name}.{command_fn.__name__}({kwargs}): {e}"
trace = traceback.format_exc()
result = {"_e": (error, trace)}
await self.send_socket_multipart([client_id, pickle.dumps(result)])

async def send_socket_multipart(self, *args, **kwargs):
self.log.debug(f"{self.name} run-and-yield {command_fn.__name__}({args}, {kwargs})")
try:
async for _ in command_fn(*args, **kwargs):
await self.send_socket_multipart(client_id, _)
await self.send_socket_multipart(client_id, {"_s": None})
except (asyncio.CancelledError, KeyboardInterrupt):
return
except BaseException as e:
error = f"Error in {self.name}.{command_fn.__name__}({args}, {kwargs}): {e}"
trace = traceback.format_exc()
self.log.error(error)
self.log.trace(trace)
result = {"_e": (error, trace)}
await self.send_socket_multipart(client_id, result)
finally:
self.tasks.pop(client_id, None)
except BaseException as e:
self.log.critical(
f"Unhandled exception in {self.name}.run_and_yield({client_id}, {command_fn}, {args}, {kwargs}): {e}"
)
self.log.critical(traceback.format_exc())

async def send_socket_multipart(self, client_id, message):
try:
await self.socket.send_multipart(*args, **kwargs)
message = pickle.dumps(message)
await self.socket.send_multipart([client_id, message])
except Exception as e:
self.log.warning(f"Error sending ZMQ message: {e}")
self.log.trace(traceback.format_exc())

def check_error(self, message):
if message is error_sentinel:
return True

async def worker(self):
try:
while 1:
@@ -220,9 +276,30 @@ async def worker(self):
self.log.warning(f"No command sent in message: {message}")
continue

kwargs = message.get("a", {})
if cmd == -1:
task = self.tasks.get(client_id, None)
if task is None:
continue
task, _cmd, _args, _kwargs = task
self.log.debug(f"Cancelling client id {client_id} (task: {task})")
task.cancel()
try:
await task
except (KeyboardInterrupt, asyncio.CancelledError):
pass
except BaseException as e:
self.log.error(f"Unhandled error in {_cmd}({_args}, {_kwargs}): {e}")
self.log.trace(traceback.format_exc())
self.tasks.pop(client_id, None)
continue

args = message.get("a", ())
if not isinstance(args, tuple):
self.log.warning(f"{self.name}: received invalid args of type {type(args)}, should be tuple")
continue
kwargs = message.get("k", {})
if not isinstance(kwargs, dict):
self.log.warning(f"{self.name}: received invalid message of type {type(kwargs)}, should be dict")
self.log.warning(f"{self.name}: received invalid kwargs of type {type(kwargs)}, should be dict")
continue

command_name = self.CMDS[cmd]
@@ -233,11 +310,12 @@
continue

if inspect.isasyncgenfunction(command_fn):
coroutine = self.run_and_yield(client_id, command_fn, **kwargs)
coroutine = self.run_and_yield(client_id, command_fn, *args, **kwargs)
else:
coroutine = self.run_and_return(client_id, command_fn, **kwargs)
coroutine = self.run_and_return(client_id, command_fn, *args, **kwargs)

asyncio.create_task(coroutine)
task = asyncio.create_task(coroutine)
self.tasks[client_id] = task, command_fn, args, kwargs
except Exception as e:
self.log.error(f"Error in EngineServer worker: {e}")
self.log.trace(traceback.format_exc())
2 changes: 2 additions & 0 deletions bbot/core/helpers/dns/dns.py
@@ -4,6 +4,7 @@
import dns.asyncresolver
from radixtarget import RadixTarget

from bbot.errors import DNSError
from bbot.core.engine import EngineClient
from ..misc import clean_dns_record, is_ip, is_domain, is_dns_name

@@ -15,6 +16,7 @@
class DNSHelper(EngineClient):

SERVER_CLASS = DNSEngine
ERROR_CLASS = DNSError

"""Helper class for DNS-related operations within BBOT.
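The dns.py hunk above shows the pattern every engine now follows: an EngineClient subclass names its SERVER_CLASS and the ERROR_CLASS to raise when the server reports a failure. A hypothetical sketch of wiring a new engine pair on top of these base classes (the class and command names are illustrative, not the web engine added by this PR):

from bbot.core.engine import EngineClient, EngineServer
from bbot.errors import BBOTEngineError


class MyEngine(EngineServer):
    # Integer command ids sent over the wire, mapped to coroutine method names.
    CMDS = {0: "do_work"}

    async def do_work(self, target, retries=1):
        # Runs inside the dedicated engine process.
        return f"processed {target} ({retries} retries)"


class MyHelper(EngineClient):
    SERVER_CLASS = MyEngine
    ERROR_CLASS = BBOTEngineError  # re-raised client-side from the "_e" reply

    async def do_work(self, target, **kwargs):
        # Forwarded over the ZMQ socket; positional args are supported as of this PR.
        return await self.run_and_return("do_work", target, **kwargs)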
