Commit

more WIP web engine
TheTechromancer committed May 3, 2024
1 parent 886c189 commit 25e971c
Showing 3 changed files with 112 additions and 56 deletions.
84 changes: 78 additions & 6 deletions bbot/core/helpers/web/engine.py
@@ -1,9 +1,17 @@
import ssl
import anyio
import httpx
import asyncio
import logging
import traceback
from httpx._models import Cookies
from socksio.exceptions import SOCKSError
from contextlib import asynccontextmanager

from bbot.core.engine import EngineServer
from bbot.core.helpers.misc import bytes_to_human, human_to_bytes, get_exception_chain

log = logging.getLogger("bbot.core.helpers.web.engine")


class DummyCookies(Cookies):
@@ -31,7 +39,6 @@ class BBOTAsyncClient(httpx.AsyncClient):

def __init__(self, *args, **kwargs):
self._config = kwargs.pop("_config")
web_requests_per_second = self._config.get("web_requests_per_second", 100)

http_debug = self._config.get("http_debug", None)
if http_debug:
@@ -84,6 +91,7 @@ class HTTPEngine(EngineServer):
0: "request",
1: "request_batch",
2: "request_custom_batch",
3: "download",
99: "_mock",
}

@@ -145,14 +153,80 @@ async def request(self, *args, **kwargs):
async with self._acatch(url, raise_error):
if self.http_debug:
logstr = f"Web request: {str(args)}, {str(kwargs)}"
log.trace(logstr)
self.log.trace(logstr)
response = await client.request(*args, **kwargs)
if self.http_debug:
log.trace(
self.log.trace(
f"Web response from {url}: {response} (Length: {len(response.content)}) headers: {response.headers}"
)
return response

async def request_batch(self, urls, *args, threads=10, **kwargs):
tasks = {}

def new_task(url):
task = asyncio.create_task(self.request(url, *args, **kwargs))
tasks[task] = url

urls = list(urls)
for _ in range(threads): # Start initial batch of tasks
if urls: # Ensure there are args to process
new_task(urls.pop(0))

while tasks: # While there are tasks pending
# Wait for the first task to complete
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

for task in done:
results = task.result()
url = tasks.pop(task)

if results:
yield (url, results)

if urls: # Start a new task for each one completed, if URLs remain
new_task(urls.pop(0))

async def download(self, url, **kwargs):
follow_redirects = kwargs.pop("follow_redirects", True)
filename = kwargs.pop("filename")
max_size = kwargs.pop("max_size", None)
warn = kwargs.pop("warn", True)
raise_error = kwargs.pop("raise_error", False)
if max_size is not None:
max_size = human_to_bytes(max_size)
kwargs["follow_redirects"] = follow_redirects
if not "method" in kwargs:
kwargs["method"] = "GET"
try:
total_size = 0
chunk_size = 8192

async with self._acatch(url, raise_error=True), self.web_client.stream(url=url, **kwargs) as response:
status_code = getattr(response, "status_code", 0)
self.log.debug(f"Download result: HTTP {status_code}")
if status_code != 0:
response.raise_for_status()
with open(filename, "wb") as f:
agen = response.aiter_bytes(chunk_size=chunk_size)
async for chunk in agen:
if max_size is not None and total_size + chunk_size > max_size:
self.log.verbose(
f"Filesize of {url} exceeds {bytes_to_human(max_size)}, file will be truncated"
)
agen.aclose()
break
total_size += chunk_size
f.write(chunk)
return True
except httpx.HTTPError as e:
log_fn = self.log.verbose
if warn:
log_fn = self.log.warning
log_fn(f"Failed to download {url}: {e}")
if raise_error:
raise

def ssl_context_noverify(self):
if self._ssl_context_noverify is None:
ssl_context = ssl.create_default_context()
@@ -217,9 +291,7 @@ async def _acatch(self, url, raise_error):
log.trace(traceback.format_exc())
except BaseException as e:
# don't log if the error is the result of an intentional cancellation
if not any(
isinstance(_e, asyncio.exceptions.CancelledError) for _e in self.parent_helper.get_exception_chain(e)
):
if not any(isinstance(_e, asyncio.exceptions.CancelledError) for _e in get_exception_chain(e)):
log.trace(f"Unhandled exception with request to URL: {url}: {e}")
log.trace(traceback.format_exc())
raise
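
Note on the engine-side changes above: a new `download` command is registered alongside `request` and `request_batch`, and `request_batch` keeps at most `threads` requests in flight, starting a replacement task each time one completes. Below is a minimal standalone sketch of that sliding-window pattern; the names `batched` and `worker` are illustrative and not part of BBOT.

import asyncio

async def batched(urls, worker, threads=10):
    # Keep at most `threads` tasks in flight; as each finishes, yield its
    # (url, result) pair and start a replacement from the remaining queue.
    tasks = {}
    urls = list(urls)

    def new_task(url):
        tasks[asyncio.create_task(worker(url))] = url

    for _ in range(min(threads, len(urls))):
        new_task(urls.pop(0))

    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            url = tasks.pop(task)
            result = task.result()
            if result:
                yield url, result
            if urls:
                new_task(urls.pop(0))

async def main():
    async def worker(url):
        await asyncio.sleep(0.01)  # stand-in for an HTTP request
        return f"response for {url}"

    async for url, result in batched([f"url{i}" for i in range(25)], worker, threads=5):
        print(url, result)

asyncio.run(main())
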
60 changes: 13 additions & 47 deletions bbot/core/helpers/web/web.py
@@ -1,18 +1,12 @@
import re
import anyio
import httpx
import asyncio
import logging
import warnings
import traceback
from pathlib import Path
from bs4 import BeautifulSoup

from socksio.exceptions import SOCKSError

from bbot.core.engine import EngineClient
from bbot.errors import WordlistError, CurlError
from bbot.core.helpers.ratelimiter import RateLimiter

from bs4 import MarkupResemblesLocatorWarning
from bs4.builder import XMLParsedAsHTMLWarning
@@ -101,9 +95,16 @@ async def request(self, *args, **kwargs):
Note:
If the web request fails, it will return None unless `raise_error` is `True`.
"""
self.log.critical(f"CLIENT {args} / {kwargs}")
return await self.run_and_return("request", *args, **kwargs)

async def request_batch(self, urls, *args, **kwargs):
async for _ in self.run_and_yield("request_batch", urls, *args, **kwargs):
yield _

async def request_custom_batch(self, urls_and_args):
async for _ in self.run_and_yield("request_custom_batch", urls_and_args):
yield _

async def download(self, url, **kwargs):
"""
Asynchronous function for downloading files from a given URL. Supports caching with an optional
@@ -129,56 +130,21 @@ async def download(self, url, **kwargs):
"""
success = False
filename = kwargs.pop("filename", self.parent_helper.cache_filename(url))
follow_redirects = kwargs.pop("follow_redirects", True)
filename = Path(filename).resolve()
kwargs["filename"] = filename
max_size = kwargs.pop("max_size", None)
warn = kwargs.pop("warn", True)
raise_error = kwargs.pop("raise_error", False)
if max_size is not None:
max_size = self.parent_helper.human_to_bytes(max_size)
kwargs["max_size"] = max_size
cache_hrs = float(kwargs.pop("cache_hrs", -1))
total_size = 0
chunk_size = 8192
log.debug(f"Downloading file from {url} with cache_hrs={cache_hrs}")
if cache_hrs > 0 and self.parent_helper.is_cached(url):
log.debug(f"{url} is cached at {self.parent_helper.cache_filename(url)}")
success = True
else:
# kwargs["raise_error"] = True
# kwargs["stream"] = True
kwargs["follow_redirects"] = follow_redirects
if not "method" in kwargs:
kwargs["method"] = "GET"
try:
async with self._acatch(url, raise_error=True), self.AsyncClient().stream(
url=url, **kwargs
) as response:
status_code = getattr(response, "status_code", 0)
log.debug(f"Download result: HTTP {status_code}")
if status_code != 0:
response.raise_for_status()
with open(filename, "wb") as f:
agen = response.aiter_bytes(chunk_size=chunk_size)
async for chunk in agen:
if max_size is not None and total_size + chunk_size > max_size:
log.verbose(
f"Filesize of {url} exceeds {self.parent_helper.bytes_to_human(max_size)}, file will be truncated"
)
agen.aclose()
break
total_size += chunk_size
f.write(chunk)
success = True
except httpx.HTTPError as e:
log_fn = log.verbose
if warn:
log_fn = log.warning
log_fn(f"Failed to download {url}: {e}")
if raise_error:
raise
return
success = await self.run_and_return("download", url, **kwargs)

if success:
return filename.resolve()
return filename

async def wordlist(self, path, lines=None, **kwargs):
"""
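
On the client side, `download`, `request_batch`, and `request_custom_batch` now proxy to the engine process through `run_and_return` / `run_and_yield`, while cache handling (`cache_hrs`, the cache filename) stays in the helper. A hedged usage sketch follows, assuming a `helpers` object with the interface exercised in the test below; the URLs, size cap, and cache window are illustrative values.

async def demo(helpers):
    # Single request; returns None on failure unless raise_error=True
    response = await helpers.request("https://example.com/robots.txt")
    if response is not None:
        print(response.status_code, len(response.content))

    # Batched requests; yields (url, response) tuples as each completes
    urls = [f"https://example.com/page/{i}" for i in range(50)]
    async for url, resp in helpers.request_batch(urls, threads=10):
        print(url, resp.status_code)

    # Download to the helper's cache dir; the engine truncates past max_size
    filename = await helpers.download("https://example.com/big.bin", max_size="100MB", cache_hrs=24)
    if filename:
        print("saved to", filename)
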
24 changes: 21 additions & 3 deletions bbot/test/test_step_1/test_web.py
@@ -4,10 +4,28 @@


@pytest.mark.asyncio
async def test_web_engine(bbot_scanner):
async def test_web_engine(bbot_scanner, bbot_httpserver):

url = bbot_httpserver.url_for("/test")
bbot_httpserver.expect_request(uri="/test").respond_with_data("hello_there")

scan = bbot_scanner()
response = await scan.helpers.request("http://example.com")
log.critical(response)

# request
response = await scan.helpers.request(url)
assert response.status_code > 0
assert response.text == "hello_there"

# request_batch
responses = [r async for r in scan.helpers.request_batch([url] * 100)]
assert len(responses) == 100
assert all([r[0] == url for r in responses])
assert all([r[1].status_code > 0 and r[1].text == "hello_there" for r in responses])

# download
filename = await scan.helpers.download(url)
file_content = open(filename).read()
assert file_content == "hello_there"


@pytest.mark.asyncio
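
One detail of the new engine-side download worth noting: the size check runs before each write and counts whole chunks, so a capped download ends at the largest multiple of chunk_size that fits under max_size. A small standalone illustration of that accounting; the function name and numbers are illustrative, not part of the codebase.

def truncated_size(file_size, max_size, chunk_size=8192):
    # Mirror the pre-write check in HTTPEngine.download above: stop before
    # the chunk that would push the running total past max_size.
    total = 0
    while total < file_size:
        if max_size is not None and total + chunk_size > max_size:
            break
        total += chunk_size
    return min(total, file_size)

# A 100 KB cap with 8192-byte chunks keeps at most 12 chunks = 98304 bytes.
assert truncated_size(1_000_000, 100_000) == 12 * 8192
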
