Faster Regexes #1278

Merged Apr 26, 2024 · 21 commits

Changes from 10 commits
50 changes: 50 additions & 0 deletions bbot/core/helpers/helper.py
@@ -1,12 +1,17 @@
import os
import asyncio
import logging
from pathlib import Path
import multiprocessing as mp
from functools import partial
from cloudcheck import cloud_providers
from concurrent.futures import ProcessPoolExecutor

from . import misc
from .dns import DNSHelper
from .web import WebHelper
from .diff import HttpCompare
from .regex import RegexHelper
from .wordcloud import WordCloud
from .interactsh import Interactsh
from ...scanner.target import Target
@@ -65,8 +70,21 @@ def __init__(self, preset):
self.mkdir(self.tools_dir)
self.mkdir(self.lib_dir)

self._loop = None

# process pool for CPU-intensive tasks
start_method = mp.get_start_method()
if start_method != "spawn":
self.warning(f"Multiprocessing start method is set to {start_method}.")

# we spawn 1 fewer process than the number of CPU cores
# this helps avoid locking up the system or competing with the main python process for cpu time
num_processes = max(1, mp.cpu_count() - 1)
self.process_pool = ProcessPoolExecutor(max_workers=num_processes)

self.cloud = cloud_providers

self.re = RegexHelper(self)
self.dns = DNSHelper(self)
self.web = WebHelper(self)
self.depsinstaller = DepsInstaller(self)
@@ -103,6 +121,38 @@ def config(self):
def scan(self):
return self.preset.scan

@property
def loop(self):
"""
Get the current event loop
"""
if self._loop is None:
self._loop = asyncio.get_running_loop()
return self._loop

def run_in_executor(self, callback, *args, **kwargs):
"""
Run a synchronous task in the event loop's default thread pool executor

Examples:
Execute callback:
>>> result = await self.helpers.run_in_executor(callback_fn, arg1, arg2)
"""
callback = partial(callback, **kwargs)
return self.loop.run_in_executor(None, callback, *args)

def run_in_executor_mp(self, callback, *args, **kwargs):
"""
Same as run_in_executor() except with a process pool executor.
Use only in cases where the callback is CPU-bound.

Examples:
Execute callback:
>>> result = await self.helpers.run_in_executor_mp(callback_fn, arg1, arg2)
"""
callback = partial(callback, **kwargs)
return self.loop.run_in_executor(self.process_pool, callback, *args)

@property
def in_tests(self):
return os.environ.get("BBOT_TESTING", "") == "True"
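The two executor helpers above are standard asyncio plumbing. Below is a minimal, standalone sketch of the same pattern (independent of BBOT; `cpu_heavy` is a hypothetical stand-in for a CPU-bound callback). It shows why the helpers bind keyword arguments with `functools.partial`: `loop.run_in_executor()` only forwards positional arguments.

```python
import asyncio
from functools import partial
from concurrent.futures import ProcessPoolExecutor


def cpu_heavy(data, repeat=1):
    # hypothetical stand-in for a CPU-bound callback (e.g. a complex regex search)
    return sum(ord(c) for c in data) * repeat


async def main():
    loop = asyncio.get_running_loop()
    # run_in_executor() only forwards positional args,
    # so keyword args are bound ahead of time with partial()
    callback = partial(cpu_heavy, repeat=3)
    # None = the event loop's default thread pool
    result = await loop.run_in_executor(None, callback, "some text")
    # a process pool sidesteps the GIL entirely for CPU-bound work
    with ProcessPoolExecutor(max_workers=2) as pool:
        result_mp = await loop.run_in_executor(pool, callback, "some text")
    assert result == result_mp


if __name__ == "__main__":
    asyncio.run(main())
```

The `if __name__ == "__main__":` guard matters here: under the spawn start method, worker processes re-import the main module, and unguarded module-level code (including `asyncio.run()`) would re-execute in every worker.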
2 changes: 1 addition & 1 deletion bbot/core/helpers/misc.py
@@ -1,12 +1,12 @@
import os
import re
import sys
import json
import random
import string
import asyncio
import logging
import ipaddress
import regex as re
import subprocess as sp
from pathlib import Path
from contextlib import suppress
72 changes: 72 additions & 0 deletions bbot/core/helpers/regex.py
@@ -0,0 +1,72 @@
import regex as re
from . import misc


class RegexHelper:
"""
Class for misc CPU-intensive regex operations

Offloads regex processing to other CPU cores via the `regex` module's GIL release + the event loop's thread pool

For quick, one-off regexes, you don't need to use this helper.
Only use this helper if you're searching large bodies of text
or if your regex is CPU-intensive
"""

def __init__(self, parent_helper):
self.parent_helper = parent_helper

def ensure_compiled_regex(self, r):
"""
Make sure a regex has been compiled
"""
if not isinstance(r, re.Pattern):
raise ValueError("Regex must be compiled first!")

def compile(self, *args, **kwargs):
return re.compile(*args, **kwargs)

async def search(self, compiled_regex, *args, **kwargs):
self.ensure_compiled_regex(compiled_regex)
return await self.parent_helper.run_in_executor(compiled_regex.search, *args, **kwargs)

async def findall(self, compiled_regex, *args, **kwargs):
self.ensure_compiled_regex(compiled_regex)
return await self.parent_helper.run_in_executor(compiled_regex.findall, *args, **kwargs)

async def finditer(self, compiled_regex, *args, **kwargs):
self.ensure_compiled_regex(compiled_regex)
return await self.parent_helper.run_in_executor(self._finditer, compiled_regex, *args, **kwargs)

async def finditer_multi(self, compiled_regexes, *args, **kwargs):
"""
Same as finditer() but with multiple regexes
"""
for r in compiled_regexes:
self.ensure_compiled_regex(r)
return await self.parent_helper.run_in_executor(self._finditer_multi, compiled_regexes, *args, **kwargs)

def _finditer_multi(self, compiled_regexes, *args, **kwargs):
matches = []
for r in compiled_regexes:
for m in r.finditer(*args, **kwargs):
matches.append(m)
return matches

def _finditer(self, compiled_regex, *args, **kwargs):
return list(compiled_regex.finditer(*args, **kwargs))

async def extract_params_html(self, *args, **kwargs):
return await self.parent_helper.run_in_executor(misc.extract_params_html, *args, **kwargs)

async def extract_emails(self, *args, **kwargs):
return await self.parent_helper.run_in_executor(misc.extract_emails, *args, **kwargs)

async def search_dict_values(self, *args, **kwargs):
def _search_dict_values(*_args, **_kwargs):
return list(misc.search_dict_values(*_args, **_kwargs))

return await self.parent_helper.run_in_executor(_search_dict_values, *args, **kwargs)

async def recursive_decode(self, *args, **kwargs):
return await self.parent_helper.run_in_executor(misc.recursive_decode, *args, **kwargs)
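A sketch of the resulting call pattern (a standalone approximation, not BBOT's exact code): a blocking `compiled_regex.findall(text)` becomes an awaitable that runs on a worker thread, so a slow regex over a large body no longer stalls the event loop. This works because the third-party `regex` package releases the GIL while matching, letting the worker thread actually use another core.

```python
import asyncio
import regex as re  # third-party `regex` package, not the stdlib `re`

# illustrative pattern only -- not BBOT's actual dns_name_regex
dns_name = re.compile(r"[a-z0-9][a-z0-9\-\.]+\.[a-z]{2,}", re.I)


async def findall_offloaded(compiled_regex, text):
    # same idea as RegexHelper.findall(): push the blocking call
    # onto the default thread pool and await the result
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, compiled_regex.findall, text)


async def main():
    body = "see https://www.example.com and mail.evilcorp.com for details"
    # the event loop stays responsive while a worker thread does the matching
    print(await findall_offloaded(dns_name, body))
    # ['www.example.com', 'mail.evilcorp.com']


asyncio.run(main())
```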
6 changes: 5 additions & 1 deletion bbot/core/helpers/regexes.py
@@ -1,4 +1,4 @@
import re
import regex as re
from collections import OrderedDict

# for extracting words from strings
@@ -104,3 +104,7 @@

_extract_host_regex = r"(?:[a-z0-9]{1,20}://)?(?:[^?]*@)?(" + valid_netloc + ")"
extract_host_regex = re.compile(_extract_host_regex, re.I)

# for use in recursive_decode()
encoded_regex = re.compile(r"%[0-9a-fA-F]{2}|\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8}|\\[ntrbv]")
backslash_regex = re.compile(r"(?P<slashes>\\+)(?P<char>[ntrvb])")
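These two patterns support `recursive_decode()`, whose implementation isn't part of this diff: `encoded_regex` finds percent-, `\u`/`\U`-, and escape-sequence chunks, and `backslash_regex` captures how many backslashes precede an escape character. A quick illustration of what each one matches:

```python
import regex as re

encoded_regex = re.compile(r"%[0-9a-fA-F]{2}|\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8}|\\[ntrbv]")
backslash_regex = re.compile(r"(?P<slashes>\\+)(?P<char>[ntrvb])")

sample = r"hello%20world\u0021 line1\nline2 double\\n"

print(encoded_regex.findall(sample))
# ['%20', '\\u0021', '\\n', '\\n'] -- one hit per encoded chunk
print(backslash_regex.findall(sample))
# [('\\', 'n'), ('\\\\', 'n')] -- the slashes group counts preceding backslashes
```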
4 changes: 2 additions & 2 deletions bbot/modules/ajaxpro.py
@@ -1,4 +1,4 @@
import re
import regex as re
from bbot.modules.base import BaseModule


@@ -38,7 +38,7 @@ async def handle_event(self, event):
elif event.type == "HTTP_RESPONSE":
resp_body = event.data.get("body", None)
if resp_body:
ajaxpro_regex_result = self.ajaxpro_regex.search(resp_body)
ajaxpro_regex_result = await self.helpers.re.search(self.ajaxpro_regex, resp_body)
if ajaxpro_regex_result:
ajax_pro_path = ajaxpro_regex_result.group(0)
await self.emit_event(
6 changes: 3 additions & 3 deletions bbot/modules/azure_tenant.py
@@ -1,4 +1,4 @@
import re
import regex as re
from contextlib import suppress

from bbot.modules.base import BaseModule
@@ -25,7 +25,7 @@ async def handle_event(self, event):

tenant_id = None
authorization_endpoint = openid_config.get("authorization_endpoint", "")
matches = self.helpers.regexes.uuid_regex.findall(authorization_endpoint)
matches = await self.helpers.re.findall(self.helpers.regexes.uuid_regex, authorization_endpoint)
if matches:
tenant_id = matches[0]

@@ -86,7 +86,7 @@ async def query(self, domain):
if status_code not in (200, 421):
self.verbose(f'Error retrieving azure_tenant domains for "{domain}" (status code: {status_code})')
return set(), dict()
found_domains = list(set(self.d_xml_regex.findall(r.text)))
found_domains = list(set(await self.helpers.re.findall(self.d_xml_regex, r.text)))
domains = set()

for d in found_domains:
2 changes: 1 addition & 1 deletion bbot/modules/badsecrets.py
@@ -33,7 +33,7 @@ async def handle_event(self, event):
resp_cookies[c2[0]] = c2[1]
if resp_body or resp_cookies:
try:
r_list = await self.scan.run_in_executor_mp(
r_list = await self.helpers.run_in_executor_mp(
carve_all_modules,
body=resp_body,
headers=resp_headers,
2 changes: 1 addition & 1 deletion bbot/modules/bevigil.py
@@ -34,7 +34,7 @@ async def handle_event(self, event):
if self.urls:
urls = await self.query(query, request_fn=self.request_urls, parse_fn=self.parse_urls)
if urls:
for parsed_url in await self.scan.run_in_executor_mp(self.helpers.validators.collapse_urls, urls):
for parsed_url in await self.helpers.run_in_executor_mp(self.helpers.validators.collapse_urls, urls):
await self.emit_event(parsed_url.geturl(), "URL_UNVERIFIED", source=event)

async def request_subdomains(self, query):
2 changes: 1 addition & 1 deletion bbot/modules/dehashed.py
@@ -33,7 +33,7 @@ async def handle_event(self, event):
for entry in entries:
# we have to clean up the email field because dehashed does a poor job of it
email_str = entry.get("email", "").replace("\\", "")
found_emails = list(self.helpers.extract_emails(email_str))
found_emails = list(await self.helpers.re.extract_emails(email_str))
if not found_emails:
self.debug(f"Invalid email from dehashed.com: {email_str}")
continue
2 changes: 1 addition & 1 deletion bbot/modules/emailformat.py
@@ -17,6 +17,6 @@ async def handle_event(self, event):
r = await self.request_with_fail_count(url)
if not r:
return
for email in self.helpers.extract_emails(r.text):
for email in await self.helpers.re.extract_emails(r.text):
if email.endswith(query):
await self.emit_event(email, "EMAIL_ADDRESS", source=event)
3 changes: 1 addition & 2 deletions bbot/modules/hunt.py
@@ -1,7 +1,6 @@
# adapted from https://github.com/bugcrowd/HUNT

from bbot.modules.base import BaseModule
from bbot.core.helpers.misc import extract_params_html

hunt_param_dict = {
"Command Injection": [
@@ -281,7 +280,7 @@ class hunt(BaseModule):

async def handle_event(self, event):
body = event.data.get("body", "")
for p in extract_params_html(body):
for p in await self.helpers.re.extract_params_html(body):
for k in hunt_param_dict.keys():
if p.lower() in hunt_param_dict[k]:
description = f"Found potential {k.upper()} parameter [{p}]"
5 changes: 3 additions & 2 deletions bbot/modules/internal/cloud.py
@@ -9,8 +9,9 @@ class cloud(InterceptModule):

async def setup(self):
self.dummy_modules = {}
for provider_name in self.helpers.cloud.providers:
for provider_name, provider in self.helpers.cloud.providers.items():
self.dummy_modules[provider_name] = self.scan._make_dummy_module(f"cloud_{provider_name}", _type="scan")

return True

async def filter_event(self, event):
@@ -43,7 +44,7 @@ async def handle_event(self, event, kwargs):
for sig in sigs:
matches = []
if event.type == "HTTP_RESPONSE":
matches = sig.findall(event.data.get("body", ""))
matches = await self.helpers.re.findall(sig, event.data.get("body", ""))
elif event.type.startswith("DNS_NAME"):
for host in hosts_to_check:
match = sig.match(host)
17 changes: 9 additions & 8 deletions bbot/modules/internal/excavate.py
@@ -1,7 +1,7 @@
import re
import html
import base64
import jwt as j
import regex as re
from urllib.parse import urljoin

from bbot.core.helpers.regexes import _email_regex, dns_name_regex
@@ -14,6 +14,7 @@ class BaseExtractor:

def __init__(self, excavate):
self.excavate = excavate
self.helpers = excavate.helpers
self.compiled_regexes = {}
for rname, r in self.regexes.items():
self.compiled_regexes[rname] = re.compile(r)
@@ -29,7 +30,7 @@ async def _search(self, content, event, **kwargs):
for name, regex in self.compiled_regexes.items():
# yield to event loop
await self.excavate.helpers.sleep(0)
for result in regex.findall(content):
for result in await self.helpers.re.findall(regex, content):
yield result, name

async def report(self, result, name, event):
@@ -39,14 +40,14 @@ async def report(self, result, name, event):
class CSPExtractor(BaseExtractor):
regexes = {"CSP": r"(?i)(?m)Content-Security-Policy:.+$"}

def extract_domains(self, csp):
domains = dns_name_regex.findall(csp)
async def extract_domains(self, csp):
domains = await self.helpers.re.findall(dns_name_regex, csp)
unique_domains = set(domains)
return unique_domains

async def search(self, content, event, **kwargs):
async for csp, name in self._search(content, event, **kwargs):
extracted_domains = self.extract_domains(csp)
extracted_domains = await self.extract_domains(csp)
for domain in extracted_domains:
await self.report(domain, event, **kwargs)

@@ -125,7 +126,7 @@ async def _search(self, content, event, **kwargs):
for name, regex in self.compiled_regexes.items():
# yield to event loop
await self.excavate.helpers.sleep(0)
for result in regex.findall(content):
for result in await self.helpers.re.findall(regex, content):
if name.startswith("full"):
protocol, other = result
result = f"{protocol}://{other}"
@@ -386,7 +387,7 @@ async def handle_event(self, event):
else:
self.verbose(f"Exceeded max HTTP redirects ({self.max_redirects}): {location}")

body = self.helpers.recursive_decode(event.data.get("body", ""))
body = await self.helpers.re.recursive_decode(event.data.get("body", ""))

await self.search(
body,
@@ -404,7 +405,7 @@ async def handle_event(self, event):
consider_spider_danger=True,
)

headers = self.helpers.recursive_decode(event.data.get("raw_header", ""))
headers = await self.helpers.re.recursive_decode(event.data.get("raw_header", ""))
await self.search(
headers,
[self.hostname, self.url, self.email, self.error_extractor, self.jwt, self.serialization, self.csp],
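The extractor pattern above is the heart of excavate: each extractor owns a dict of compiled regexes, `_search()` awaits `sleep(0)` between patterns to yield control to the event loop, and the matching itself is awaited off-thread via the new regex helper. A minimal sketch with one hypothetical extractor (simplified; the real `BaseExtractor` also handles event reporting):

```python
import asyncio
import regex as re


class MiniExtractor:
    # maps a result-type name to its pattern, mirroring BaseExtractor.regexes
    regexes = {"EMAIL_ADDRESS": r"[a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,}"}

    def __init__(self):
        self.compiled_regexes = {name: re.compile(r, re.I) for name, r in self.regexes.items()}

    async def search(self, content):
        loop = asyncio.get_running_loop()
        for name, compiled in self.compiled_regexes.items():
            # yield to the event loop between patterns
            await asyncio.sleep(0)
            # offload the actual matching to a worker thread
            for result in await loop.run_in_executor(None, compiled.findall, content):
                yield result, name


async def main():
    extractor = MiniExtractor()
    async for result, name in extractor.search("contact admin@evilcorp.com today"):
        print(name, result)  # EMAIL_ADDRESS admin@evilcorp.com


asyncio.run(main())
```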
2 changes: 1 addition & 1 deletion bbot/modules/massdns.py
@@ -1,7 +1,7 @@
import re
import json
import random
import subprocess
import regex as re

from bbot.modules.templates.subdomain_enum import subdomain_enum

2 changes: 1 addition & 1 deletion bbot/modules/oauth.py
@@ -119,7 +119,7 @@ async def getoidc(self, url):
return url, token_endpoint, results
if json and isinstance(json, dict):
token_endpoint = json.get("token_endpoint", "")
for found in self.helpers.search_dict_values(json, *self.regexes):
for found in await self.helpers.re.search_dict_values(json, *self.regexes):
results.add(found)
results -= {token_endpoint}
return url, token_endpoint, results