Update and Optimize Neo4j #1582

Merged · 16 commits · Jul 30, 2024
74 changes: 41 additions & 33 deletions bbot/core/engine.py
@@ -16,8 +16,8 @@

from bbot.core import CORE
from bbot.errors import BBOTEngineError
from bbot.core.helpers.misc import rand_string
from bbot.core.helpers.async_helpers import get_event_loop
from bbot.core.helpers.misc import rand_string, in_exception_chain


error_sentinel = object()
@@ -41,6 +41,7 @@ class EngineBase:
ERROR_CLASS = BBOTEngineError

def __init__(self):
self._shutdown_status = False
self.log = logging.getLogger(f"bbot.core.{self.__class__.__name__.lower()}")

def pickle(self, obj):
@@ -62,7 +63,7 @@ def unpickle(self, binary):

async def _infinite_retry(self, callback, *args, **kwargs):
interval = kwargs.pop("_interval", 10)
while 1:
while not self._shutdown_status:
try:
return await asyncio.wait_for(callback(*args, **kwargs), timeout=interval)
except (TimeoutError, asyncio.TimeoutError):
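The retry loop in `_infinite_retry` is now bounded by the new `_shutdown_status` flag, so a client that has begun shutting down stops retrying instead of spinning forever. A minimal sketch of the pattern in isolation (simplified; the real method also logs each timeout, and `shutdown_flag` here is a stand-in for `self._shutdown_status`):

```python
import asyncio

async def infinite_retry(callback, *args, shutdown_flag=lambda: False, interval=10, **kwargs):
    # keep retrying until shutdown begins; each attempt gets its own timeout
    while not shutdown_flag():
        try:
            return await asyncio.wait_for(callback(*args, **kwargs), timeout=interval)
        except (TimeoutError, asyncio.TimeoutError):
            continue  # a timeout just means "try again", unless we're shutting down
```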
@@ -107,7 +108,6 @@ class EngineClient(EngineBase):
SERVER_CLASS = None

def __init__(self, **kwargs):
self._shutdown = False
super().__init__()
self.name = f"EngineClient {self.__class__.__name__}"
self.process = None
@@ -135,7 +135,7 @@ def check_error(self, message):
async def run_and_return(self, command, *args, **kwargs):
fn_str = f"{command}({args}, {kwargs})"
self.log.debug(f"{self.name}: executing run-and-return {fn_str}")
if self._shutdown and not command == "_shutdown":
if self._shutdown_status and not command == "_shutdown":
self.log.verbose(f"{self.name} has been shut down and is not accepting new tasks")
return
async with self.new_socket() as socket:
@@ -163,7 +163,7 @@ async def run_and_return(self, command, *args, **kwargs):
async def run_and_yield(self, command, *args, **kwargs):
fn_str = f"{command}({args}, {kwargs})"
self.log.debug(f"{self.name}: executing run-and-yield {fn_str}")
if self._shutdown:
if self._shutdown_status:
self.log.verbose("Engine has been shut down and is not accepting new tasks")
return
message = self.make_message(command, args=args, kwargs=kwargs)
@@ -213,14 +213,16 @@ async def send_shutdown_message(self):
async with self.new_socket() as socket:
# -99 == special shutdown message
message = pickle.dumps({"c": -99})
await self._infinite_retry(socket.send, message)
while 1:
response = await self._infinite_retry(socket.recv)
response = pickle.loads(response)
if isinstance(response, dict):
response = response.get("m", "")
if response == "SHUTDOWN_OK":
break
with suppress(TimeoutError, asyncio.TimeoutError):
await asyncio.wait_for(socket.send(message), 0.5)
with suppress(TimeoutError, asyncio.TimeoutError):
while 1:
response = await asyncio.wait_for(socket.recv(), 0.5)
response = pickle.loads(response)
if isinstance(response, dict):
response = response.get("m", "")
if response == "SHUTDOWN_OK":
break

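The shutdown handshake previously went through `_infinite_retry` for both send and receive, which could hang indefinitely if the server process was already gone. It now caps each step at 0.5 seconds and simply gives up on timeout. The same pattern in isolation (a sketch; `socket` is assumed to be a ZMQ asyncio socket as produced by `new_socket()`):

```python
import asyncio
import pickle
from contextlib import suppress

async def send_shutdown(socket):
    message = pickle.dumps({"c": -99})  # -99 == special shutdown message
    with suppress(TimeoutError, asyncio.TimeoutError):
        await asyncio.wait_for(socket.send(message), 0.5)
    # wait briefly for SHUTDOWN_OK; a dead server just times out
    with suppress(TimeoutError, asyncio.TimeoutError):
        while 1:
            response = pickle.loads(await asyncio.wait_for(socket.recv(), 0.5))
            if isinstance(response, dict) and response.get("m", "") == "SHUTDOWN_OK":
                return True
    return False  # no acknowledgment; proceed with teardown anyway
```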
def check_stop(self, message):
if isinstance(message, dict) and len(message) == 1 and "_s" in message:
@@ -280,7 +282,7 @@ def server_process(server_class, socket_path, **kwargs):
else:
asyncio.run(engine_server.worker())
except (asyncio.CancelledError, KeyboardInterrupt, CancelledError):
pass
return
except Exception:
import traceback

@@ -306,9 +308,9 @@ async def new_socket(self):
socket.close()

async def shutdown(self):
self.log.debug(f"{self.name}: shutting down...")
if not self._shutdown:
self._shutdown = True
if not self._shutdown_status:
self._shutdown_status = True
self.log.verbose(f"{self.name}: shutting down...")
# send shutdown signal
await self.send_shutdown_message()
# then terminate context
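`shutdown()` is now guarded by `_shutdown_status`, which moved up into `EngineBase.__init__` so client and server share the same flag; repeated calls become no-ops. The guard pattern, sketched on its own:

```python
class Engine:
    def __init__(self):
        self._shutdown_status = False  # set once, in the shared base class

    async def shutdown(self):
        # only the first caller performs teardown; later calls return immediately
        if self._shutdown_status:
            return
        self._shutdown_status = True
        # ... send shutdown signal, then terminate the ZMQ context ...
```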
@@ -446,6 +448,7 @@ def check_error(self, message):
return True

async def worker(self):
self.log.debug(f"{self.name}: starting worker")
try:
while 1:
client_id, binary = await self.socket.recv_multipart()
@@ -462,8 +465,8 @@ async def worker(self):
# -1 == cancel task
if cmd == -1:
self.log.debug(f"{self.name} got cancel signal")
await self.cancel_task(client_id)
await self.send_socket_multipart(client_id, {"m": "CANCEL_OK"})
await self.cancel_task(client_id)
continue

# -99 == shutdown task
@@ -500,24 +503,28 @@
task = asyncio.create_task(coroutine)
self.tasks[client_id] = task, command_fn, args, kwargs
# self.log.debug(f"{self.name}: finished creating task for {command_name}() coroutine")
except Exception as e:
self.log.error(f"{self.name}: error in EngineServer worker: {e}")
self.log.trace(traceback.format_exc())
except BaseException as e:
await self._shutdown()
if not in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError)):
self.log.error(f"{self.name}: error in EngineServer worker: {e}")
self.log.trace(traceback.format_exc())
finally:
self.log.debug(f"{self.name}: finished worker()")

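Catching `BaseException` here is what lets the worker shut itself down cleanly on `KeyboardInterrupt` or `CancelledError` without logging them as errors; `in_exception_chain` is what distinguishes those from genuine failures. A hypothetical sketch of what such a helper does (assumed behavior, inferred from the call site rather than from the actual `bbot.core.helpers.misc` source):

```python
def in_exception_chain(e, exc_types):
    # walk the __cause__/__context__ chain, checking each link's type
    seen = set()
    while e is not None and id(e) not in seen:
        seen.add(id(e))
        if isinstance(e, exc_types):
            return True
        e = e.__cause__ or e.__context__
    return False
```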
async def _shutdown(self):
self.log.debug(f"{self.name}: shutting down...")
await self.cancel_all_tasks()
try:
self.context.destroy(linger=0)
except Exception:
self.log.trace(traceback.format_exc())
try:
self.context.term()
except Exception:
self.log.trace(traceback.format_exc())
self.log.debug(f"{self.name}: finished shutting down")
if not self._shutdown_status:
self.log.verbose(f"{self.name}: shutting down...")
self._shutdown_status = True
await self.cancel_all_tasks()
try:
self.context.destroy(linger=0)
except Exception:
self.log.trace(traceback.format_exc())
try:
self.context.term()
except Exception:
self.log.trace(traceback.format_exc())
self.log.debug(f"{self.name}: finished shutting down")

def new_child_task(self, client_id, coro):
task = asyncio.create_task(coro)
@@ -554,8 +561,9 @@ async def _cancel_task(self, task):
await asyncio.wait_for(task, timeout=10)
except (TimeoutError, asyncio.TimeoutError):
self.log.debug(f"{self.name}: Timeout cancelling task")
return
except (KeyboardInterrupt, asyncio.CancelledError):
pass
return
except BaseException as e:
self.log.error(f"Unhandled error in {task.get_coro().__name__}(): {e}")
self.log.trace(traceback.format_exc())
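The added `return` statements make `_cancel_task` bail out as soon as the cancellation times out or is itself cancelled, rather than falling through to the rest of the cleanup. The cancel-then-await pattern as a standalone sketch:

```python
import asyncio

async def cancel_task(task, timeout=10):
    task.cancel()
    try:
        # give the task a bounded window to unwind its cancellation
        await asyncio.wait_for(task, timeout=timeout)
    except (TimeoutError, asyncio.TimeoutError):
        return  # task is hung; don't block shutdown waiting on it
    except asyncio.CancelledError:
        return  # the normal outcome of cancelling a task
```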
2 changes: 1 addition & 1 deletion bbot/modules/base.py
@@ -1478,7 +1478,7 @@ async def _worker(self):
self.scan.stats.event_consumed(event, self)
self.debug(f"Intercepting {event}")
async with self.scan._acatch(context), self._task_counter.count(context):
forward_event = await self.handle_event(event, kwargs)
forward_event = await self.handle_event(event, **kwargs)
with suppress(ValueError, TypeError):
forward_event, forward_event_reason = forward_event

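This is the matching half of the signature change in `cloudcheck` and `dnsresolve` below: intercept modules' `handle_event` previously received the kwargs dict as a single positional parameter, and both sides now use standard `**kwargs` unpacking. The difference in miniature (hypothetical standalone functions, not the module API):

```python
import asyncio

# before: the kwargs dict rode along as one positional argument
async def handle_event_old(event, kwargs):
    return kwargs.get("reason")

# after: a standard **kwargs signature, called with unpacking
async def handle_event_new(event, **kwargs):
    return kwargs.get("reason")

async def main():
    assert await handle_event_old("EVT", {"reason": "test"}) == "test"
    assert await handle_event_new("EVT", reason="test") == "test"

asyncio.run(main())
```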
2 changes: 1 addition & 1 deletion bbot/modules/internal/cloudcheck.py
@@ -21,7 +21,7 @@ async def filter_event(self, event):
return False, "event does not have host attribute"
return True

async def handle_event(self, event, kwargs):
async def handle_event(self, event, **kwargs):
# don't hold up the event loop loading cloud IPs etc.
if self.dummy_modules is None:
self.make_dummy_modules()
2 changes: 1 addition & 1 deletion bbot/modules/internal/dnsresolve.py
@@ -63,7 +63,7 @@ async def filter_event(self, event):
return False, "event does not have host attribute"
return True

async def handle_event(self, event, kwargs):
async def handle_event(self, event, **kwargs):
dns_tags = set()
dns_children = dict()
event_whitelisted = False
16 changes: 8 additions & 8 deletions bbot/modules/output/base.py
@@ -19,6 +19,7 @@ def human_event_str(self, event):
return event_str

def _event_precheck(self, event):
reason = "precheck succeeded"
# special signal event types
if event.type in ("FINISHED",):
return True, "its type is FINISHED"
@@ -42,24 +43,23 @@ def _event_precheck(self, event):
if event.type.startswith("URL") and self.name != "httpx" and "httpx-only" in event.tags:
return False, (f"Omitting {event} from output because it's marked as httpx-only")

if event._omit:
return False, "_omit is True"

# omit certain event types
if event.type in self.scan.omitted_event_types:
if event._omit:
if "target" in event.tags:
self.debug(f"Allowing omitted event: {event} because it's a target")
reason = "it's a target"
self.debug(f"Allowing omitted event: {event} because {reason}")
elif event.type in self.get_watched_events():
self.debug(f"Allowing omitted event: {event} because its type is explicitly in watched_events")
reason = "its type is explicitly in watched_events"
self.debug(f"Allowing omitted event: {event} because {reason}")
else:
return False, "its type is omitted in the config"
return False, "_omit is True"

# internal events like those from speculate, ipneighbor
# or events that are over our report distance
if event._internal:
return False, "_internal is True"

return True, "precheck succeeded"
return True, reason

async def _event_postcheck(self, event):
acceptable, reason = await super()._event_postcheck(event)
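The omit handling is also subtly rebalanced: `_omit` is now checked before `omitted_event_types` rather than after, and an omitted event still survives if it is target-tagged or its type is explicitly watched. Condensed into a sketch (a hypothetical standalone function, ignoring the later `_internal` check):

```python
def check_omit(event, watched_events):
    """Return (acceptable, reason) for the omit portion of the precheck."""
    if not event._omit:
        return True, "precheck succeeded"
    if "target" in event.tags:
        return True, "it's a target"
    if event.type in watched_events:
        return True, "its type is explicitly in watched_events"
    return False, "_omit is True"
```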
110 changes: 86 additions & 24 deletions bbot/modules/output/neo4j.py
@@ -34,6 +34,7 @@ class neo4j(BaseOutputModule):
"password": "Neo4j password",
}
deps_pip = ["neo4j"]
_batch_size = 500
_preserve_graph = True

async def setup(self):
@@ -51,32 +52,93 @@ async def setup(self):
return False, f"Error setting up Neo4j: {e}"
return True

async def handle_event(self, event):
# create events
src_id = await self.merge_event(event.get_parent(), id_only=True)
dst_id = await self.merge_event(event)
# create relationship
cypher = f"""
MATCH (a) WHERE id(a) = $src_id
MATCH (b) WHERE id(b) = $dst_id
MERGE (a)-[_:{event.module}]->(b)
SET _.timestamp = $timestamp"""
await self.session.run(cypher, src_id=src_id, dst_id=dst_id, timestamp=event.timestamp)

async def merge_event(self, event, id_only=False):
async def handle_batch(self, *all_events):
await self.helpers.sleep(5)
# group events by type, since cypher doesn't allow dynamic labels
events_by_type = {}
parents_by_type = {}
relationships = []
for event in all_events:
parent = event.get_parent()
try:
events_by_type[event.type].append(event)
except KeyError:
events_by_type[event.type] = [event]
try:
parents_by_type[parent.type].append(parent)
except KeyError:
parents_by_type[parent.type] = [parent]

module = str(event.module)
timestamp = event.timestamp
relationships.append((parent, module, timestamp, event))

all_ids = {}
for event_type, events in events_by_type.items():
self.debug(f"{len(events):,} events of type {event_type}")
all_ids.update(await self.merge_events(events, event_type))
for event_type, parents in parents_by_type.items():
self.debug(f"{len(parents):,} parents of type {event_type}")
all_ids.update(await self.merge_events(parents, event_type, id_only=True))

rel_ids = []
for parent, module, timestamp, event in relationships:
try:
src_id = all_ids[parent.id]
dst_id = all_ids[event.id]
except KeyError as e:
self.critical(f'Error "{e}" correlating {parent.id}:{parent.data} --> {event.id}:{event.data}')
continue
rel_ids.append((src_id, module, timestamp, dst_id))

await self.merge_relationships(rel_ids)

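The grouping is necessary because Cypher doesn't allow node labels (or relationship types) to be parameterized, so the module has to issue one `UNWIND` query per label. The try/except-KeyError grouping above is equivalent to `dict.setdefault` or a `defaultdict`; for comparison, a sketch of the same grouping:

```python
from collections import defaultdict

def group_by_type(all_events):
    # one bucket per event type, i.e. one Cypher query per node label
    events_by_type = defaultdict(list)
    for event in all_events:
        events_by_type[event.type].append(event)
    return dict(events_by_type)
```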
async def merge_events(self, events, event_type, id_only=False):
if id_only:
eventdata = {"type": event.type, "id": event.id}
insert_data = [{"data": str(e.data), "type": e.type, "id": e.id} for e in events]
else:
eventdata = event.json(mode="graph")
# we pop the timestamp because it belongs on the relationship
eventdata.pop("timestamp")
cypher = f"""MERGE (_:{event.type} {{ id: $eventdata['id'] }})
SET _ += $eventdata
RETURN id(_)"""
# insert event
result = await self.session.run(cypher, eventdata=eventdata)
# get Neo4j id
return (await result.single()).get("id(_)")
insert_data = []
for e in events:
event_json = e.json(mode="graph")
# we pop the timestamp because it belongs on the relationship
event_json.pop("timestamp")
# nested data types aren't supported in neo4j
event_json.pop("dns_children", None)
insert_data.append(event_json)

cypher = f"""UNWIND $events AS event
MERGE (_:{event_type} {{ id: event.id }})
SET _ += event
RETURN event.data as event_data, event.id as event_id, elementId(_) as neo4j_id"""
# insert events
results = await self.session.run(cypher, events=insert_data)
# get Neo4j ids
neo4j_ids = {}
for result in await results.data():
event_id = result["event_id"]
neo4j_id = result["neo4j_id"]
neo4j_ids[event_id] = neo4j_id
return neo4j_ids

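For anyone testing this outside the module: the core of `merge_events` is a single `UNWIND ... MERGE ... SET n += row` round trip that returns a mapping of BBOT event id to Neo4j `elementId` (the switch from `id(_)` to `elementId(_)` matters because numeric ids are deprecated in Neo4j 5). A runnable sketch against a local instance, with placeholder URI and credentials:

```python
import asyncio
from neo4j import AsyncGraphDatabase

async def batch_merge(label, rows, uri="bolt://localhost:7687", auth=("neo4j", "password")):
    driver = AsyncGraphDatabase.driver(uri, auth=auth)
    async with driver.session() as session:
        # label can't be a query parameter, hence the f-string
        cypher = f"""UNWIND $rows AS row
        MERGE (n:{label} {{ id: row.id }})
        SET n += row
        RETURN row.id AS event_id, elementId(n) AS neo4j_id"""
        result = await session.run(cypher, rows=rows)
        ids = {r["event_id"]: r["neo4j_id"] async for r in result}
    await driver.close()
    return ids

# e.g. asyncio.run(batch_merge("DNS_NAME", [{"id": "DNS_NAME:abc123", "data": "evilcorp.com"}]))
```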
async def merge_relationships(self, relationships):
rels_by_module = {}
# group by module
for src_id, module, timestamp, dst_id in relationships:
data = {"src_id": src_id, "timestamp": timestamp, "dst_id": dst_id}
try:
rels_by_module[module].append(data)
except KeyError:
rels_by_module[module] = [data]

for module, rels in rels_by_module.items():
self.debug(f"{len(rels):,} relationships of type {module}")
cypher = f"""
UNWIND $rels AS rel
MATCH (a) WHERE elementId(a) = rel.src_id
MATCH (b) WHERE elementId(b) = rel.dst_id
MERGE (a)-[_:{module}]->(b)
SET _.timestamp = rel.timestamp"""
await self.session.run(cypher, rels=rels)

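Relationship types can't be parameterized either, hence the same group-then-UNWIND treatment keyed by module name, with the timestamp stored on the relationship rather than on either node. A matching sketch, under the same assumptions as the merge example above (`session` is an open async Neo4j session):

```python
async def batch_link(session, rel_type, rels):
    # rels: list of {"src_id": ..., "dst_id": ..., "timestamp": ...} dicts,
    # where src_id/dst_id are elementId values from a prior batch_merge
    cypher = f"""
    UNWIND $rels AS rel
    MATCH (a) WHERE elementId(a) = rel.src_id
    MATCH (b) WHERE elementId(b) = rel.dst_id
    MERGE (a)-[r:{rel_type}]->(b)
    SET r.timestamp = rel.timestamp"""
    await session.run(cypher, rels=rels)
```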
async def cleanup(self):
with suppress(Exception):