Merge pull request #1231 from blacklanternsecurity/dns-engine
DNS Engine - Offload DNS to Dedicated Process
Showing 90 changed files with 2,664 additions and 1,970 deletions.
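For orientation: the new engine module added below offloads work to a dedicated process and talks to it over a ZeroMQ DEALER/ROUTER pair bound to a Unix ipc:// socket. The following is a minimal, self-contained sketch of that messaging pattern only (hypothetical names and socket path, not code from this PR); in the PR itself the ROUTER end lives in a separate process spawned via CORE.create_process(), whereas this sketch keeps both ends in one process purely to show the socket mechanics.

import asyncio
import tempfile
from pathlib import Path

import zmq
import zmq.asyncio

SOCKET_PATH = Path(tempfile.gettempdir()) / "engine_demo.sock"  # hypothetical path


async def server(ctx):
    # ROUTER prefixes every message with the sender's identity,
    # so one server can multiplex many concurrent clients.
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(f"ipc://{SOCKET_PATH}")
    ident, request = await sock.recv_multipart()
    await sock.send_multipart([ident, b"reply to: " + request])
    sock.close()


async def client(ctx):
    sock = ctx.socket(zmq.DEALER)
    sock.connect(f"ipc://{SOCKET_PATH}")
    await sock.send(b"hello")
    reply = await sock.recv()
    sock.close()
    return reply


async def main():
    ctx = zmq.asyncio.Context()
    server_task = asyncio.create_task(server(ctx))
    print(await client(ctx))  # b'reply to: hello'
    await server_task
    ctx.term()


asyncio.run(main())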
@@ -0,0 +1,12 @@
import sys
import multiprocessing as mp

try:
    mp.set_start_method("spawn")
except Exception:
    start_method = mp.get_start_method()
    if start_method != "spawn":
        print(
            f"[WARN] Multiprocessing spawn method is set to {start_method}. This may negatively affect performance.",
            file=sys.stderr,
        )
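A note on the try/except above: multiprocessing allows the start method to be set only once per process, so if something else has already fixed a different method, set_start_method() raises and the code falls back to a warning. A hedged illustration of that path (hypothetical module name, Unix-only since "fork" does not exist on Windows):

import multiprocessing as mp

if __name__ == "__main__":
    mp.set_start_method("fork")  # some other component picked "fork" first
    # Importing the setup module shown above now takes the except-branch:
    # its set_start_method("spawn") raises RuntimeError, get_start_method()
    # returns "fork", and the [WARN] message is printed to stderr.
    import bbot_mp_setup  # hypothetical name; the actual file path is not shown in this diff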
@@ -0,0 +1,212 @@
import zmq
import atexit
import pickle
import asyncio
import inspect
import logging
import tempfile
import traceback
import zmq.asyncio
from pathlib import Path
from contextlib import asynccontextmanager, suppress

from bbot.core import CORE
from bbot.core.helpers.misc import rand_string

CMD_EXIT = 1000


class EngineClient:

    SERVER_CLASS = None

    def __init__(self, **kwargs):
        self.name = f"EngineClient {self.__class__.__name__}"
        if self.SERVER_CLASS is None:
            raise ValueError(f"Must set EngineClient SERVER_CLASS, {self.SERVER_CLASS}")
        self.CMDS = dict(self.SERVER_CLASS.CMDS)
        for k, v in list(self.CMDS.items()):
            self.CMDS[v] = k
        self.log = logging.getLogger(f"bbot.core.{self.__class__.__name__.lower()}")
        self.socket_address = f"zmq_{rand_string(8)}.sock"
        self.socket_path = Path(tempfile.gettempdir()) / self.socket_address
        self.server_kwargs = kwargs.pop("server_kwargs", {})
        self._server_process = None
        self.context = zmq.asyncio.Context()
        atexit.register(self.cleanup)

    async def run_and_return(self, command, **kwargs):
        async with self.new_socket() as socket:
            message = self.make_message(command, args=kwargs)
            await socket.send(message)
            binary = await socket.recv()
            # self.log.debug(f"{self.name}.{command}({kwargs}) got binary: {binary}")
            message = pickle.loads(binary)
            self.log.debug(f"{self.name}.{command}({kwargs}) got message: {message}")
            # error handling
            if self.check_error(message):
                return
            return message

    async def run_and_yield(self, command, **kwargs):
        message = self.make_message(command, args=kwargs)
        async with self.new_socket() as socket:
            await socket.send(message)
            while 1:
                binary = await socket.recv()
                # self.log.debug(f"{self.name}.{command}({kwargs}) got binary: {binary}")
                message = pickle.loads(binary)
                self.log.debug(f"{self.name}.{command}({kwargs}) got message: {message}")
                # error handling
                if self.check_error(message) or self.check_stop(message):
                    break
                yield message

    def check_error(self, message):
        if isinstance(message, dict) and len(message) == 1 and "_e" in message:
            error, trace = message["_e"]
            self.log.error(error)
            self.log.trace(trace)
            return True
        return False

    def check_stop(self, message):
        if isinstance(message, dict) and len(message) == 1 and "_s" in message:
            return True
        return False

    def make_message(self, command, args):
        try:
            cmd_id = self.CMDS[command]
        except KeyError:
            raise KeyError(f'Command "{command}" not found. Available commands: {",".join(self.available_commands)}')
        return pickle.dumps(dict(c=cmd_id, a=args))

    @property
    def available_commands(self):
        return [s for s in self.CMDS if isinstance(s, str)]

    def start_server(self):
        process = CORE.create_process(
            target=self.server_process,
            args=(
                self.SERVER_CLASS,
                self.socket_path,
            ),
            kwargs=self.server_kwargs,
        )
        process.start()
        return process

    @staticmethod
    def server_process(server_class, socket_path, **kwargs):
        try:
            engine_server = server_class(socket_path, **kwargs)
            asyncio.run(engine_server.worker())
        except (asyncio.CancelledError, KeyboardInterrupt):
            pass
        except Exception:
            import traceback

            log = logging.getLogger("bbot.core.engine.server")
            log.critical(f"Unhandled error in {server_class.__name__} server process: {traceback.format_exc()}")

    @asynccontextmanager
    async def new_socket(self):
        if self._server_process is None:
            self._server_process = self.start_server()
            while not self.socket_path.exists():
                await asyncio.sleep(0.1)
        socket = self.context.socket(zmq.DEALER)
        socket.connect(f"ipc://{self.socket_path}")
        try:
            yield socket
        finally:
            with suppress(Exception):
                socket.close()

    def cleanup(self):
        # delete socket file on exit
        self.socket_path.unlink(missing_ok=True)


class EngineServer:

    CMDS = {}

    def __init__(self, socket_path):
        self.log = logging.getLogger(f"bbot.core.{self.__class__.__name__.lower()}")
        self.name = f"EngineServer {self.__class__.__name__}"
        if socket_path is not None:
            # create ZeroMQ context
            self.context = zmq.asyncio.Context()
            # ROUTER socket can handle multiple concurrent requests
            self.socket = self.context.socket(zmq.ROUTER)
            # create socket file
            self.socket.bind(f"ipc://{socket_path}")

    async def run_and_return(self, client_id, command_fn, **kwargs):
        self.log.debug(f"{self.name} run-and-return {command_fn.__name__}({kwargs})")
        try:
            result = await command_fn(**kwargs)
        except Exception as e:
            error = f"Unhandled error in {self.name}.{command_fn.__name__}({kwargs}): {e}"
            trace = traceback.format_exc()
            result = {"_e": (error, trace)}
        await self.send_socket_multipart([client_id, pickle.dumps(result)])

    async def run_and_yield(self, client_id, command_fn, **kwargs):
        self.log.debug(f"{self.name} run-and-yield {command_fn.__name__}({kwargs})")
        try:
            async for _ in command_fn(**kwargs):
                await self.send_socket_multipart([client_id, pickle.dumps(_)])
            await self.send_socket_multipart([client_id, pickle.dumps({"_s": None})])
        except Exception as e:
            error = f"Unhandled error in {self.name}.{command_fn.__name__}({kwargs}): {e}"
            trace = traceback.format_exc()
            result = {"_e": (error, trace)}
            await self.send_socket_multipart([client_id, pickle.dumps(result)])

    async def send_socket_multipart(self, *args, **kwargs):
        try:
            await self.socket.send_multipart(*args, **kwargs)
        except Exception as e:
            self.log.warning(f"Error sending ZMQ message: {e}")
            self.log.trace(traceback.format_exc())

    async def worker(self):
        try:
            while 1:
                client_id, binary = await self.socket.recv_multipart()
                message = pickle.loads(binary)
                self.log.debug(f"{self.name} got message: {message}")

                cmd = message.get("c", None)
                if not isinstance(cmd, int):
                    self.log.warning(f"No command sent in message: {message}")
                    continue

                kwargs = message.get("a", {})
                if not isinstance(kwargs, dict):
                    self.log.warning(f"{self.name}: received invalid message of type {type(kwargs)}, should be dict")
                    continue

                command_name = self.CMDS[cmd]
                command_fn = getattr(self, command_name, None)

                if command_fn is None:
                    self.log.warning(f'{self.name} has no function named "{command_name}"')
                    continue

                if inspect.isasyncgenfunction(command_fn):
                    coroutine = self.run_and_yield(client_id, command_fn, **kwargs)
                else:
                    coroutine = self.run_and_return(client_id, command_fn, **kwargs)

                asyncio.create_task(coroutine)
        except Exception as e:
            self.log.error(f"Error in EngineServer worker: {e}")
            self.log.trace(traceback.format_exc())
        finally:
            with suppress(Exception):
                self.socket.close()
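To show how these base classes are meant to be paired, here is a minimal sketch (hypothetical class names and import path, not code from this PR): the server subclass declares its command table in CMDS, the client subclass points SERVER_CLASS at it, and each client call is pickled, sent over the ipc:// socket, and dispatched to the matching coroutine or async generator on the server.

import asyncio

from bbot.core.engine import EngineClient, EngineServer  # assumed import path


class DemoEngineServer(EngineServer):
    CMDS = {
        0: "lookup",        # coroutine -> handled by run_and_return
        1: "lookup_batch",  # async generator -> handled by run_and_yield
    }

    async def lookup(self, name):
        return {name: ["127.0.0.1"]}  # placeholder for real work (e.g. DNS resolution)

    async def lookup_batch(self, names):
        for name in names:
            yield {name: ["127.0.0.1"]}


class DemoEngine(EngineClient):
    SERVER_CLASS = DemoEngineServer

    async def lookup(self, name):
        return await self.run_and_return("lookup", name=name)

    async def lookup_batch(self, names):
        async for result in self.run_and_yield("lookup_batch", names=names):
            yield result


async def main():
    engine = DemoEngine()  # the server process is spawned lazily on first use
    print(await engine.lookup("example.com"))
    async for result in engine.lookup_batch(["one.example", "two.example"]):
        print(result)


if __name__ == "__main__":
    asyncio.run(main())

Note that CMD_EXIT = 1000 is defined but not yet wired to a shutdown command in the code shown here, so the sketch leaves the spawned server process to be cleaned up externally.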