From 9ccef3981182e58685349eb629e275a58cd7a323 Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Wed, 24 Jul 2024 13:28:48 -0400 Subject: [PATCH 01/13] neo4j update --- bbot/modules/output/neo4j.py | 110 +++++++++++++++++++++++++++-------- 1 file changed, 86 insertions(+), 24 deletions(-) diff --git a/bbot/modules/output/neo4j.py b/bbot/modules/output/neo4j.py index 87220d26d..0fd6477d1 100644 --- a/bbot/modules/output/neo4j.py +++ b/bbot/modules/output/neo4j.py @@ -34,6 +34,7 @@ class neo4j(BaseOutputModule): "password": "Neo4j password", } deps_pip = ["neo4j"] + _batch_size = 500 _preserve_graph = True async def setup(self): @@ -51,32 +52,93 @@ async def setup(self): return False, f"Error setting up Neo4j: {e}" return True - async def handle_event(self, event): - # create events - src_id = await self.merge_event(event.get_parent(), id_only=True) - dst_id = await self.merge_event(event) - # create relationship - cypher = f""" - MATCH (a) WHERE id(a) = $src_id - MATCH (b) WHERE id(b) = $dst_id - MERGE (a)-[_:{event.module}]->(b) - SET _.timestamp = $timestamp""" - await self.session.run(cypher, src_id=src_id, dst_id=dst_id, timestamp=event.timestamp) - - async def merge_event(self, event, id_only=False): + async def handle_batch(self, *all_events): + await self.helpers.sleep(5) + # group events by type, since cypher doesn't allow dynamic labels + events_by_type = {} + parents_by_type = {} + relationships = [] + for event in all_events: + parent = event.get_parent() + try: + events_by_type[event.type].append(event) + except KeyError: + events_by_type[event.type] = [event] + try: + parents_by_type[parent.type].append(parent) + except KeyError: + parents_by_type[parent.type] = [parent] + + module = str(event.module) + timestamp = event.timestamp + relationships.append((parent, module, timestamp, event)) + + all_ids = {} + for event_type, events in events_by_type.items(): + self.debug(f"{len(events):,} events of type {event_type}") + all_ids.update(await self.merge_events(events, event_type)) + for event_type, parents in parents_by_type.items(): + self.debug(f"{len(parents):,} parents of type {event_type}") + all_ids.update(await self.merge_events(parents, event_type, id_only=True)) + + rel_ids = [] + for parent, module, timestamp, event in relationships: + try: + src_id = all_ids[parent.id] + dst_id = all_ids[event.id] + except KeyError as e: + self.critical(f'Error "{e}" correlating {parent.id}:{parent.data} --> {event.id}:{event.data}') + continue + rel_ids.append((src_id, module, timestamp, dst_id)) + + await self.merge_relationships(rel_ids) + + async def merge_events(self, events, event_type, id_only=False): if id_only: - eventdata = {"type": event.type, "id": event.id} + insert_data = [{"data": str(e.data), "type": e.type, "id": e.id} for e in events] else: - eventdata = event.json(mode="graph") - # we pop the timestamp because it belongs on the relationship - eventdata.pop("timestamp") - cypher = f"""MERGE (_:{event.type} {{ id: $eventdata['id'] }}) - SET _ += $eventdata - RETURN id(_)""" - # insert event - result = await self.session.run(cypher, eventdata=eventdata) - # get Neo4j id - return (await result.single()).get("id(_)") + insert_data = [] + for e in events: + event_json = e.json(mode="graph") + # we pop the timestamp because it belongs on the relationship + event_json.pop("timestamp") + # nested data types aren't supported in neo4j + event_json.pop("dns_children", None) + insert_data.append(event_json) + + cypher = f"""UNWIND $events AS event + MERGE (_:{event_type} {{ id: event.id }}) + SET _ += event + RETURN event.data as event_data, event.id as event_id, elementId(_) as neo4j_id""" + # insert events + results = await self.session.run(cypher, events=insert_data) + # get Neo4j ids + neo4j_ids = {} + for result in await results.data(): + event_id = result["event_id"] + neo4j_id = result["neo4j_id"] + neo4j_ids[event_id] = neo4j_id + return neo4j_ids + + async def merge_relationships(self, relationships): + rels_by_module = {} + # group by module + for src_id, module, timestamp, dst_id in relationships: + data = {"src_id": src_id, "timestamp": timestamp, "dst_id": dst_id} + try: + rels_by_module[module].append(data) + except KeyError: + rels_by_module[module] = [data] + + for module, rels in rels_by_module.items(): + self.debug(f"{len(rels):,} relationships of type {module}") + cypher = f""" + UNWIND $rels AS rel + MATCH (a) WHERE elementId(a) = rel.src_id + MATCH (b) WHERE elementId(b) = rel.dst_id + MERGE (a)-[_:{module}]->(b) + SET _.timestamp = rel.timestamp""" + await self.session.run(cypher, rels=rels) async def cleanup(self): with suppress(Exception): From 384453685c37c4f5064a2d2a9b8172e97066204e Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Thu, 25 Jul 2024 17:01:17 -0400 Subject: [PATCH 02/13] internal --> _omit --- bbot/scanner/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bbot/scanner/manager.py b/bbot/scanner/manager.py index 931c51b00..58199177c 100644 --- a/bbot/scanner/manager.py +++ b/bbot/scanner/manager.py @@ -196,7 +196,7 @@ async def handle_event(self, event, kwargs): if "target" in event.tags: self.debug(f"Allowing omitted event: {event} because it's a target") else: - event.internal = True + event._omit = True # make event internal if it's above our configured report distance event_in_report_distance = event.scope_distance <= self.scan.scope_report_distance From b08d6bebd82709966d49eb1aaad38a620e03c131 Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sat, 27 Jul 2024 22:07:17 -0400 Subject: [PATCH 03/13] reorganize _omit --- bbot/modules/output/base.py | 5 +---- bbot/scanner/manager.py | 4 ++++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/bbot/modules/output/base.py b/bbot/modules/output/base.py index 8a6eba9eb..b9e15f720 100644 --- a/bbot/modules/output/base.py +++ b/bbot/modules/output/base.py @@ -42,11 +42,8 @@ def _event_precheck(self, event): if event.type.startswith("URL") and self.name != "httpx" and "httpx-only" in event.tags: return False, (f"Omitting {event} from output because it's marked as httpx-only") - if event._omit: - return False, "_omit is True" - # omit certain event types - if event.type in self.scan.omitted_event_types: + if event._omit: if "target" in event.tags: self.debug(f"Allowing omitted event: {event} because it's a target") elif event.type in self.get_watched_events(): diff --git a/bbot/scanner/manager.py b/bbot/scanner/manager.py index 58199177c..978f79784 100644 --- a/bbot/scanner/manager.py +++ b/bbot/scanner/manager.py @@ -208,6 +208,10 @@ async def handle_event(self, event, kwargs): ) event.internal = True + if event.type in self.scan.omitted_event_types: + log.debug(f"Omitting {event} because its type is omitted in the config") + event._omit = True + # if we discovered something interesting from an internal event, # make sure we preserve its chain of parents parent = event.parent From 21301e8014d55a0efe71da52214f64096f9b0b91 Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sat, 27 Jul 2024 23:04:15 -0400 Subject: [PATCH 04/13] update _omit tests --- bbot/modules/base.py | 2 +- bbot/modules/output/base.py | 11 ++++--- bbot/scanner/manager.py | 4 +-- bbot/test/test_step_1/test_modules_basic.py | 36 ++++++++++++++------- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/bbot/modules/base.py b/bbot/modules/base.py index 52e23044f..9b43b1d2f 100644 --- a/bbot/modules/base.py +++ b/bbot/modules/base.py @@ -1478,7 +1478,7 @@ async def _worker(self): self.scan.stats.event_consumed(event, self) self.debug(f"Intercepting {event}") async with self.scan._acatch(context), self._task_counter.count(context): - forward_event = await self.handle_event(event, kwargs) + forward_event = await self.handle_event(event, **kwargs) with suppress(ValueError, TypeError): forward_event, forward_event_reason = forward_event diff --git a/bbot/modules/output/base.py b/bbot/modules/output/base.py index b9e15f720..0f6e7ac78 100644 --- a/bbot/modules/output/base.py +++ b/bbot/modules/output/base.py @@ -19,6 +19,7 @@ def human_event_str(self, event): return event_str def _event_precheck(self, event): + reason = "precheck succeeded" # special signal event types if event.type in ("FINISHED",): return True, "its type is FINISHED" @@ -45,18 +46,20 @@ def _event_precheck(self, event): # omit certain event types if event._omit: if "target" in event.tags: - self.debug(f"Allowing omitted event: {event} because it's a target") + reason = "it's a target" + self.debug(f"Allowing omitted event: {event} because {reason}") elif event.type in self.get_watched_events(): - self.debug(f"Allowing omitted event: {event} because its type is explicitly in watched_events") + reason = "its type is explicitly in watched_events" + self.debug(f"Allowing omitted event: {event} because {reason}") else: - return False, "its type is omitted in the config" + return False, "_omit is True" # internal events like those from speculate, ipneighbor # or events that are over our report distance if event._internal: return False, "_internal is True" - return True, "precheck succeeded" + return True, reason async def _event_postcheck(self, event): acceptable, reason = await super()._event_postcheck(event) diff --git a/bbot/scanner/manager.py b/bbot/scanner/manager.py index 978f79784..074d6f6a1 100644 --- a/bbot/scanner/manager.py +++ b/bbot/scanner/manager.py @@ -62,7 +62,7 @@ async def init_events(self, events=None): await asyncio.sleep(0.1) self.scan._finished_init = True - async def handle_event(self, event, kwargs): + async def handle_event(self, event, **kwargs): # don't accept dummy events if event._dummy: return False, "cannot emit dummy event" @@ -187,7 +187,7 @@ def priority(self): # we are the lowest priority return 99 - async def handle_event(self, event, kwargs): + async def handle_event(self, event, **kwargs): abort_if = kwargs.pop("abort_if", None) on_success_callback = kwargs.pop("on_success_callback", None) diff --git a/bbot/test/test_step_1/test_modules_basic.py b/bbot/test/test_step_1/test_modules_basic.py index 87657e05f..b3b30f2ae 100644 --- a/bbot/test/test_step_1/test_modules_basic.py +++ b/bbot/test/test_step_1/test_modules_basic.py @@ -18,6 +18,8 @@ async def test_modules_basic_checks(events, httpx_mock): scan = Scanner(config={"omit_event_types": ["URL_UNVERIFIED"]}) assert "URL_UNVERIFIED" in scan.omitted_event_types + await scan.load_modules() + # output module specific event filtering tests base_output_module_1 = BaseOutputModule(scan) base_output_module_1.watched_events = ["IP_ADDRESS", "URL_UNVERIFIED"] @@ -35,21 +37,17 @@ async def test_modules_basic_checks(events, httpx_mock): result, reason = base_output_module_1._event_precheck(localhost) assert result == True assert reason == "precheck succeeded" - # omitted events should be rejected - localhost._omit = True - result, reason = base_output_module_1._event_precheck(localhost) - assert result == False - assert reason == "_omit is True" - # unwatched event types should be rejected - dns_name = scan.make_event("evilcorp.com", "DNS_NAME", parent=scan.root_event) + # unwatched events should be rejected + dns_name = scan.make_event("evilcorp.com", parent=scan.root_event) result, reason = base_output_module_1._event_precheck(dns_name) assert result == False assert reason == "its type is not in watched_events" - # omitted event types matching watched events should be accepted + # omitted events matching watched types should be accepted url_unverified = scan.make_event("http://127.0.0.1", "URL_UNVERIFIED", parent=scan.root_event) + url_unverified._omit = True result, reason = base_output_module_1._event_precheck(url_unverified) assert result == True - assert reason == "precheck succeeded" + assert reason == "its type is explicitly in watched_events" base_output_module_2 = BaseOutputModule(scan) base_output_module_2.watched_events = ["*"] @@ -72,11 +70,27 @@ async def test_modules_basic_checks(events, httpx_mock): result, reason = base_output_module_2._event_precheck(localhost) assert result == False assert reason == "_omit is True" - # omitted event types should be rejected + # normal event should be accepted url_unverified = scan.make_event("http://127.0.0.1", "URL_UNVERIFIED", parent=scan.root_event) result, reason = base_output_module_2._event_precheck(url_unverified) + assert result == True + assert reason == "precheck succeeded" + # omitted event types should be marked during scan egress + await scan.egress_module.handle_event(url_unverified) + result, reason = base_output_module_2._event_precheck(url_unverified) + assert result == False + assert reason == "_omit is True" + # omitted events that are targets should be accepted + dns_name = scan.make_event("evilcorp.com", "DNS_NAME", parent=scan.root_event) + dns_name._omit = True + result, reason = base_output_module_2._event_precheck(dns_name) assert result == False - assert reason == "its type is omitted in the config" + assert reason == "_omit is True" + # omitted results that are targets should be accepted + dns_name.add_tag("target") + result, reason = base_output_module_2._event_precheck(dns_name) + assert result == True + assert reason == "it's a target" # common event filtering tests for module_class in (BaseModule, BaseOutputModule, BaseReportModule, BaseInternalModule): From 8e5ec9e08df0c7ee00ac89dbb7d86effbee52c2c Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sat, 27 Jul 2024 23:18:52 -0400 Subject: [PATCH 05/13] update kwargs --- bbot/modules/internal/cloudcheck.py | 2 +- bbot/modules/internal/dnsresolve.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bbot/modules/internal/cloudcheck.py b/bbot/modules/internal/cloudcheck.py index 45011c509..15d9bf364 100644 --- a/bbot/modules/internal/cloudcheck.py +++ b/bbot/modules/internal/cloudcheck.py @@ -21,7 +21,7 @@ async def filter_event(self, event): return False, "event does not have host attribute" return True - async def handle_event(self, event, kwargs): + async def handle_event(self, event, **kwargs): # don't hold up the event loop loading cloud IPs etc. if self.dummy_modules is None: self.make_dummy_modules() diff --git a/bbot/modules/internal/dnsresolve.py b/bbot/modules/internal/dnsresolve.py index 45307d50d..6efe4ff1f 100644 --- a/bbot/modules/internal/dnsresolve.py +++ b/bbot/modules/internal/dnsresolve.py @@ -63,7 +63,7 @@ async def filter_event(self, event): return False, "event does not have host attribute" return True - async def handle_event(self, event, kwargs): + async def handle_event(self, event, **kwargs): dns_tags = set() dns_children = dict() event_whitelisted = False From e7cc6b78d6e9f9ce7b5d71c5dc40a94a84cf3a26 Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sun, 28 Jul 2024 00:22:24 -0400 Subject: [PATCH 06/13] fix handle_event in tests --- bbot/test/test_step_1/test_dns.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bbot/test/test_step_1/test_dns.py b/bbot/test/test_step_1/test_dns.py index 0b7e03362..e38db4fe5 100644 --- a/bbot/test/test_step_1/test_dns.py +++ b/bbot/test/test_step_1/test_dns.py @@ -150,10 +150,10 @@ async def test_dns_resolution(bbot_scanner): dnsresolve = scan.modules["dnsresolve"] assert hash(resolved_hosts_event1.host) not in dnsresolve._event_cache assert hash(resolved_hosts_event2.host) not in dnsresolve._event_cache - await dnsresolve.handle_event(resolved_hosts_event1, {}) + await dnsresolve.handle_event(resolved_hosts_event1) assert hash(resolved_hosts_event1.host) in dnsresolve._event_cache assert hash(resolved_hosts_event2.host) in dnsresolve._event_cache - await dnsresolve.handle_event(resolved_hosts_event2, {}) + await dnsresolve.handle_event(resolved_hosts_event2) assert "1.1.1.1" in resolved_hosts_event2.resolved_hosts assert "1.1.1.1" in resolved_hosts_event2.dns_children["A"] assert resolved_hosts_event1.resolved_hosts == resolved_hosts_event2.resolved_hosts @@ -222,9 +222,9 @@ async def test_wildcards(bbot_scanner): # event resolution await scan._prep() dnsresolve = scan.modules["dnsresolve"] - await dnsresolve.handle_event(wildcard_event1, {}) - await dnsresolve.handle_event(wildcard_event2, {}) - await dnsresolve.handle_event(wildcard_event3, {}) + await dnsresolve.handle_event(wildcard_event1) + await dnsresolve.handle_event(wildcard_event2) + await dnsresolve.handle_event(wildcard_event3) assert "wildcard" in wildcard_event1.tags assert "a-wildcard" in wildcard_event1.tags assert "srv-wildcard" not in wildcard_event1.tags From c9ee0e9262a2e92e0d0e187064c294f461ee8cf4 Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sun, 28 Jul 2024 02:52:30 -0400 Subject: [PATCH 07/13] fixing tests --- bbot/test/test_step_1/test_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bbot/test/test_step_1/test_events.py b/bbot/test/test_step_1/test_events.py index 7286d74b7..9d070b781 100644 --- a/bbot/test/test_step_1/test_events.py +++ b/bbot/test/test_step_1/test_events.py @@ -166,7 +166,7 @@ async def test_events(events, helpers): javascript_event = scan.make_event("http://evilcorp.com/asdf/a.js?b=c#d", "URL_UNVERIFIED", parent=scan.root_event) assert "extension-js" in javascript_event.tags - await scan.ingress_module.handle_event(javascript_event, {}) + await scan.ingress_module.handle_event(javascript_event) assert "httpx-only" in javascript_event.tags # scope distance From b855d7019b68c1772e850a960e352164a49ac008 Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sun, 28 Jul 2024 03:06:55 -0400 Subject: [PATCH 08/13] more test fixing --- bbot/test/test_step_1/test_scan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bbot/test/test_step_1/test_scan.py b/bbot/test/test_step_1/test_scan.py index 058b625d7..d907025c0 100644 --- a/bbot/test/test_step_1/test_scan.py +++ b/bbot/test/test_step_1/test_scan.py @@ -90,13 +90,13 @@ async def test_url_extension_handling(bbot_scanner): httpx_event = scan.make_event("https://evilcorp.com/a.js", "URL", tags=["status-200"], parent=scan.root_event) assert "blacklisted" not in bad_event.tags assert "httpx-only" not in httpx_event.tags - result = await scan.ingress_module.handle_event(good_event, {}) + result = await scan.ingress_module.handle_event(good_event) assert result == None - result, reason = await scan.ingress_module.handle_event(bad_event, {}) + result, reason = await scan.ingress_module.handle_event(bad_event) assert result == False assert reason == "event is blacklisted" assert "blacklisted" in bad_event.tags - result = await scan.ingress_module.handle_event(httpx_event, {}) + result = await scan.ingress_module.handle_event(httpx_event) assert result == None assert "httpx-only" in httpx_event.tags From e9ebdf4da38ec8c90b3fb5180b0f8078e84b4e7c Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sun, 28 Jul 2024 12:24:03 -0400 Subject: [PATCH 09/13] fix cloudcheck tests --- .../test/test_step_2/module_tests/test_module_cloudcheck.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bbot/test/test_step_2/module_tests/test_module_cloudcheck.py b/bbot/test/test_step_2/module_tests/test_module_cloudcheck.py index 6438b03e1..b95e7455d 100644 --- a/bbot/test/test_step_2/module_tests/test_module_cloudcheck.py +++ b/bbot/test/test_step_2/module_tests/test_module_cloudcheck.py @@ -48,11 +48,11 @@ async def setup_after_prep(self, module_test): other_event3._resolved_hosts = {"asdf.amazonaws.com"} for event in (ip_event, aws_event1, aws_event2, aws_event4, other_event2, other_event3): - await module.handle_event(event, {}) + await module.handle_event(event) assert "cloud-amazon" in event.tags, f"{event} was not properly cloud-tagged" for event in (aws_event3, other_event1): - await module.handle_event(event, {}) + await module.handle_event(event) assert "cloud-amazon" not in event.tags, f"{event} was improperly cloud-tagged" assert not any( t for t in event.tags if t.startswith("cloud-") or t.startswith("cdn-") @@ -64,7 +64,7 @@ async def setup_after_prep(self, module_test): google_event3._resolved_hosts = {"asdf.storage.googleapis.com"} for event in (google_event1, google_event2, google_event3): - await module.handle_event(event, {}) + await module.handle_event(event) assert "cloud-google" in event.tags, f"{event} was not properly cloud-tagged" assert "cloud-storage-bucket" in google_event3.tags From ec7e664d8e89568ed1b0d52b72e8e76203fd9e9f Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sun, 28 Jul 2024 14:26:17 -0400 Subject: [PATCH 10/13] fix neo4j tests --- bbot/test/test_step_2/module_tests/test_module_neo4j.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bbot/test/test_step_2/module_tests/test_module_neo4j.py b/bbot/test/test_step_2/module_tests/test_module_neo4j.py index fcb21b94b..9ff63f9a5 100644 --- a/bbot/test/test_step_2/module_tests/test_module_neo4j.py +++ b/bbot/test/test_step_2/module_tests/test_module_neo4j.py @@ -10,9 +10,12 @@ async def setup_before_prep(self, module_test): self.neo4j_used = False class MockResult: - async def single(s): + async def data(s): self.neo4j_used = True - return {"id(_)": 1} + return [{ + "neo4j_id": "4:ee79a477-5f5b-445a-9def-7c051b2a533c:115", + "event_id": "DNS_NAME:c8fab50640cb87f8712d1998ecc78caf92b90f71", + }] class MockSession: async def run(s, *args, **kwargs): From 8e45cb30a7fa4829b1a1a96d769214ee30975570 Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sun, 28 Jul 2024 14:49:27 -0400 Subject: [PATCH 11/13] blacked --- .../test/test_step_2/module_tests/test_module_neo4j.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bbot/test/test_step_2/module_tests/test_module_neo4j.py b/bbot/test/test_step_2/module_tests/test_module_neo4j.py index 9ff63f9a5..9db35cff7 100644 --- a/bbot/test/test_step_2/module_tests/test_module_neo4j.py +++ b/bbot/test/test_step_2/module_tests/test_module_neo4j.py @@ -12,10 +12,12 @@ async def setup_before_prep(self, module_test): class MockResult: async def data(s): self.neo4j_used = True - return [{ - "neo4j_id": "4:ee79a477-5f5b-445a-9def-7c051b2a533c:115", - "event_id": "DNS_NAME:c8fab50640cb87f8712d1998ecc78caf92b90f71", - }] + return [ + { + "neo4j_id": "4:ee79a477-5f5b-445a-9def-7c051b2a533c:115", + "event_id": "DNS_NAME:c8fab50640cb87f8712d1998ecc78caf92b90f71", + } + ] class MockSession: async def run(s, *args, **kwargs): From 2c04d806fc284e1be7404cf3b0ed7356cd0f1ee6 Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sun, 28 Jul 2024 18:14:11 -0400 Subject: [PATCH 12/13] ctrl+c improvement --- bbot/core/engine.py | 74 +++++++++++++++++++++++------------------ bbot/scanner/scanner.py | 4 ++- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/bbot/core/engine.py b/bbot/core/engine.py index 20ef59a4a..c3897cbef 100644 --- a/bbot/core/engine.py +++ b/bbot/core/engine.py @@ -16,8 +16,8 @@ from bbot.core import CORE from bbot.errors import BBOTEngineError -from bbot.core.helpers.misc import rand_string from bbot.core.helpers.async_helpers import get_event_loop +from bbot.core.helpers.misc import rand_string, in_exception_chain error_sentinel = object() @@ -41,6 +41,7 @@ class EngineBase: ERROR_CLASS = BBOTEngineError def __init__(self): + self._shutdown_status = False self.log = logging.getLogger(f"bbot.core.{self.__class__.__name__.lower()}") def pickle(self, obj): @@ -62,7 +63,7 @@ def unpickle(self, binary): async def _infinite_retry(self, callback, *args, **kwargs): interval = kwargs.pop("_interval", 10) - while 1: + while not self._shutdown_status: try: return await asyncio.wait_for(callback(*args, **kwargs), timeout=interval) except (TimeoutError, asyncio.TimeoutError): @@ -107,7 +108,6 @@ class EngineClient(EngineBase): SERVER_CLASS = None def __init__(self, **kwargs): - self._shutdown = False super().__init__() self.name = f"EngineClient {self.__class__.__name__}" self.process = None @@ -135,7 +135,7 @@ def check_error(self, message): async def run_and_return(self, command, *args, **kwargs): fn_str = f"{command}({args}, {kwargs})" self.log.debug(f"{self.name}: executing run-and-return {fn_str}") - if self._shutdown and not command == "_shutdown": + if self._shutdown_status and not command == "_shutdown": self.log.verbose(f"{self.name} has been shut down and is not accepting new tasks") return async with self.new_socket() as socket: @@ -163,7 +163,7 @@ async def run_and_return(self, command, *args, **kwargs): async def run_and_yield(self, command, *args, **kwargs): fn_str = f"{command}({args}, {kwargs})" self.log.debug(f"{self.name}: executing run-and-yield {fn_str}") - if self._shutdown: + if self._shutdown_status: self.log.verbose("Engine has been shut down and is not accepting new tasks") return message = self.make_message(command, args=args, kwargs=kwargs) @@ -213,14 +213,16 @@ async def send_shutdown_message(self): async with self.new_socket() as socket: # -99 == special shutdown message message = pickle.dumps({"c": -99}) - await self._infinite_retry(socket.send, message) - while 1: - response = await self._infinite_retry(socket.recv) - response = pickle.loads(response) - if isinstance(response, dict): - response = response.get("m", "") - if response == "SHUTDOWN_OK": - break + with suppress(TimeoutError, asyncio.TimeoutError): + await asyncio.wait_for(socket.send(message), 0.5) + with suppress(TimeoutError, asyncio.TimeoutError): + while 1: + response = await asyncio.wait_for(socket.recv(), 0.5) + response = pickle.loads(response) + if isinstance(response, dict): + response = response.get("m", "") + if response == "SHUTDOWN_OK": + break def check_stop(self, message): if isinstance(message, dict) and len(message) == 1 and "_s" in message: @@ -280,7 +282,7 @@ def server_process(server_class, socket_path, **kwargs): else: asyncio.run(engine_server.worker()) except (asyncio.CancelledError, KeyboardInterrupt, CancelledError): - pass + return except Exception: import traceback @@ -306,9 +308,9 @@ async def new_socket(self): socket.close() async def shutdown(self): - self.log.debug(f"{self.name}: shutting down...") - if not self._shutdown: - self._shutdown = True + if not self._shutdown_status: + self._shutdown_status = True + self.log.hugewarning(f"{self.name}: shutting down...") # send shutdown signal await self.send_shutdown_message() # then terminate context @@ -446,6 +448,7 @@ def check_error(self, message): return True async def worker(self): + self.log.debug(f"{self.name}: starting worker") try: while 1: client_id, binary = await self.socket.recv_multipart() @@ -462,8 +465,8 @@ async def worker(self): # -1 == cancel task if cmd == -1: self.log.debug(f"{self.name} got cancel signal") - await self.cancel_task(client_id) await self.send_socket_multipart(client_id, {"m": "CANCEL_OK"}) + await self.cancel_task(client_id) continue # -99 == shutdown task @@ -500,24 +503,28 @@ async def worker(self): task = asyncio.create_task(coroutine) self.tasks[client_id] = task, command_fn, args, kwargs # self.log.debug(f"{self.name}: finished creating task for {command_name}() coroutine") - except Exception as e: - self.log.error(f"{self.name}: error in EngineServer worker: {e}") - self.log.trace(traceback.format_exc()) + except BaseException as e: + await self._shutdown() + if not in_exception_chain(e, (KeyboardInterrupt, asyncio.CancelledError)): + self.log.error(f"{self.name}: error in EngineServer worker: {e}") + self.log.trace(traceback.format_exc()) finally: self.log.debug(f"{self.name}: finished worker()") async def _shutdown(self): - self.log.debug(f"{self.name}: shutting down...") - await self.cancel_all_tasks() - try: - self.context.destroy(linger=0) - except Exception: - self.log.trace(traceback.format_exc()) - try: - self.context.term() - except Exception: - self.log.trace(traceback.format_exc()) - self.log.debug(f"{self.name}: finished shutting down") + if not self._shutdown_status: + self.log.critical(f"{self.name}: shutting down...") + self._shutdown_status = True + await self.cancel_all_tasks() + try: + self.context.destroy(linger=0) + except Exception: + self.log.trace(traceback.format_exc()) + try: + self.context.term() + except Exception: + self.log.trace(traceback.format_exc()) + self.log.debug(f"{self.name}: finished shutting down") def new_child_task(self, client_id, coro): task = asyncio.create_task(coro) @@ -554,8 +561,9 @@ async def _cancel_task(self, task): await asyncio.wait_for(task, timeout=10) except (TimeoutError, asyncio.TimeoutError): self.log.debug(f"{self.name}: Timeout cancelling task") + return except (KeyboardInterrupt, asyncio.CancelledError): - pass + return except BaseException as e: self.log.error(f"Unhandled error in {task.get_coro().__name__}(): {e}") self.log.trace(traceback.format_exc()) diff --git a/bbot/scanner/scanner.py b/bbot/scanner/scanner.py index d90b8c329..0fe4191bf 100644 --- a/bbot/scanner/scanner.py +++ b/bbot/scanner/scanner.py @@ -353,6 +353,8 @@ async def async_start(self): events, finish = await self.modules["python"]._events_waiting(batch_size=-1) for e in events: yield e + if events: + continue # break if initialization finished and the scan is no longer active if self._finished_init and self.modules_finished: @@ -386,7 +388,7 @@ async def async_start(self): for task in tasks: # self.debug(f"Awaiting {task}") with contextlib.suppress(BaseException): - await task + await asyncio.wait_for(task, timeout=0.1) self.debug(f"Awaited {len(tasks):,} tasks") await self._report() await self._cleanup() From ef368e2192883160bb0cd2fc1f23f3994772a74c Mon Sep 17 00:00:00 2001 From: TheTechromancer Date: Sun, 28 Jul 2024 18:31:09 -0400 Subject: [PATCH 13/13] prevent graph orphans --- bbot/core/engine.py | 4 ++-- bbot/scanner/manager.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/bbot/core/engine.py b/bbot/core/engine.py index c3897cbef..d8dd1af28 100644 --- a/bbot/core/engine.py +++ b/bbot/core/engine.py @@ -310,7 +310,7 @@ async def new_socket(self): async def shutdown(self): if not self._shutdown_status: self._shutdown_status = True - self.log.hugewarning(f"{self.name}: shutting down...") + self.log.verbose(f"{self.name}: shutting down...") # send shutdown signal await self.send_shutdown_message() # then terminate context @@ -513,7 +513,7 @@ async def worker(self): async def _shutdown(self): if not self._shutdown_status: - self.log.critical(f"{self.name}: shutting down...") + self.log.verbose(f"{self.name}: shutting down...") self._shutdown_status = True await self.cancel_all_tasks() try: diff --git a/bbot/scanner/manager.py b/bbot/scanner/manager.py index 074d6f6a1..ad722f4fc 100644 --- a/bbot/scanner/manager.py +++ b/bbot/scanner/manager.py @@ -215,13 +215,15 @@ async def handle_event(self, event, **kwargs): # if we discovered something interesting from an internal event, # make sure we preserve its chain of parents parent = event.parent - if parent.internal and ((not event.internal) or event._graph_important): + event_is_graph_worthy = (not event.internal) or event._graph_important + parent_is_graph_worthy = (not parent.internal) or parent._graph_important + if event_is_graph_worthy and not parent_is_graph_worthy: parent_in_report_distance = parent.scope_distance <= self.scan.scope_report_distance if parent_in_report_distance: parent.internal = False if not parent._graph_important: parent._graph_important = True - log.debug(f"Re-queuing internal event {parent} with parent {event}") + log.debug(f"Re-queuing internal event {parent} with parent {event} to prevent graph orphan") await self.emit_event(parent) abort_result = False