diff --git a/bbot/core/helpers/web/engine.py b/bbot/core/helpers/web/engine.py index 9e30dbb8b..3a13bbb2d 100644 --- a/bbot/core/helpers/web/engine.py +++ b/bbot/core/helpers/web/engine.py @@ -1,9 +1,17 @@ import ssl +import anyio import httpx +import asyncio +import logging +import traceback from httpx._models import Cookies +from socksio.exceptions import SOCKSError from contextlib import asynccontextmanager from bbot.core.engine import EngineServer +from bbot.core.helpers.misc import bytes_to_human, human_to_bytes, get_exception_chain + +log = logging.getLogger("bbot.core.helpers.web.engine") class DummyCookies(Cookies): @@ -31,7 +39,6 @@ class BBOTAsyncClient(httpx.AsyncClient): def __init__(self, *args, **kwargs): self._config = kwargs.pop("_config") - web_requests_per_second = self._config.get("web_requests_per_second", 100) http_debug = self._config.get("http_debug", None) if http_debug: @@ -84,6 +91,7 @@ class HTTPEngine(EngineServer): 0: "request", 1: "request_batch", 2: "request_custom_batch", + 3: "download", 99: "_mock", } @@ -145,14 +153,80 @@ async def request(self, *args, **kwargs): async with self._acatch(url, raise_error): if self.http_debug: logstr = f"Web request: {str(args)}, {str(kwargs)}" - log.trace(logstr) + self.log.trace(logstr) response = await client.request(*args, **kwargs) if self.http_debug: - log.trace( + self.log.trace( f"Web response from {url}: {response} (Length: {len(response.content)}) headers: {response.headers}" ) return response + async def request_batch(self, urls, *args, threads=10, **kwargs): + tasks = {} + + def new_task(url): + task = asyncio.create_task(self.request(url, *args, **kwargs)) + tasks[task] = url + + urls = list(urls) + for _ in range(threads): # Start initial batch of tasks + if urls: # Ensure there are args to process + new_task(urls.pop(0)) + + while tasks: # While there are tasks pending + # Wait for the first task to complete + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + for task in done: + results = task.result() + url = tasks.pop(task) + + if results: + yield (url, results) + + if urls: # Start a new task for each one completed, if URLs remain + new_task(urls.pop(0)) + + async def download(self, url, **kwargs): + follow_redirects = kwargs.pop("follow_redirects", True) + filename = kwargs.pop("filename") + max_size = kwargs.pop("max_size", None) + warn = kwargs.pop("warn", True) + raise_error = kwargs.pop("raise_error", False) + if max_size is not None: + max_size = human_to_bytes(max_size) + kwargs["follow_redirects"] = follow_redirects + if not "method" in kwargs: + kwargs["method"] = "GET" + try: + total_size = 0 + chunk_size = 8192 + + async with self._acatch(url, raise_error=True), self.web_client.stream(url=url, **kwargs) as response: + status_code = getattr(response, "status_code", 0) + self.log.debug(f"Download result: HTTP {status_code}") + if status_code != 0: + response.raise_for_status() + with open(filename, "wb") as f: + agen = response.aiter_bytes(chunk_size=chunk_size) + async for chunk in agen: + if max_size is not None and total_size + chunk_size > max_size: + self.log.verbose( + f"Filesize of {url} exceeds {bytes_to_human(max_size)}, file will be truncated" + ) + agen.aclose() + break + total_size += chunk_size + f.write(chunk) + return True + except httpx.HTTPError as e: + log_fn = self.log.verbose + if warn: + log_fn = self.log.warning + log_fn(f"Failed to download {url}: {e}") + if raise_error: + raise + def ssl_context_noverify(self): if self._ssl_context_noverify is None: ssl_context = ssl.create_default_context() @@ -217,9 +291,7 @@ async def _acatch(self, url, raise_error): log.trace(traceback.format_exc()) except BaseException as e: # don't log if the error is the result of an intentional cancellation - if not any( - isinstance(_e, asyncio.exceptions.CancelledError) for _e in self.parent_helper.get_exception_chain(e) - ): + if not any(isinstance(_e, asyncio.exceptions.CancelledError) for _e in get_exception_chain(e)): log.trace(f"Unhandled exception with request to URL: {url}: {e}") log.trace(traceback.format_exc()) raise diff --git a/bbot/core/helpers/web/web.py b/bbot/core/helpers/web/web.py index 6d44cca61..09bc3b581 100644 --- a/bbot/core/helpers/web/web.py +++ b/bbot/core/helpers/web/web.py @@ -1,18 +1,12 @@ import re -import anyio -import httpx -import asyncio import logging import warnings import traceback from pathlib import Path from bs4 import BeautifulSoup -from socksio.exceptions import SOCKSError - from bbot.core.engine import EngineClient from bbot.errors import WordlistError, CurlError -from bbot.core.helpers.ratelimiter import RateLimiter from bs4 import MarkupResemblesLocatorWarning from bs4.builder import XMLParsedAsHTMLWarning @@ -101,9 +95,16 @@ async def request(self, *args, **kwargs): Note: If the web request fails, it will return None unless `raise_error` is `True`. """ - self.log.critical(f"CLIENT {args} / {kwargs}") return await self.run_and_return("request", *args, **kwargs) + async def request_batch(self, urls, *args, **kwargs): + async for _ in self.run_and_yield("request_batch", urls, *args, **kwargs): + yield _ + + async def request_custom_batch(self, urls_and_args): + async for _ in self.run_and_yield("request_custom_batch", urls_and_args): + yield _ + async def download(self, url, **kwargs): """ Asynchronous function for downloading files from a given URL. Supports caching with an optional @@ -129,56 +130,21 @@ async def download(self, url, **kwargs): """ success = False filename = kwargs.pop("filename", self.parent_helper.cache_filename(url)) - follow_redirects = kwargs.pop("follow_redirects", True) + filename = Path(filename).resolve() + kwargs["filename"] = filename max_size = kwargs.pop("max_size", None) - warn = kwargs.pop("warn", True) - raise_error = kwargs.pop("raise_error", False) if max_size is not None: max_size = self.parent_helper.human_to_bytes(max_size) + kwargs["max_size"] = max_size cache_hrs = float(kwargs.pop("cache_hrs", -1)) - total_size = 0 - chunk_size = 8192 - log.debug(f"Downloading file from {url} with cache_hrs={cache_hrs}") if cache_hrs > 0 and self.parent_helper.is_cached(url): log.debug(f"{url} is cached at {self.parent_helper.cache_filename(url)}") success = True else: - # kwargs["raise_error"] = True - # kwargs["stream"] = True - kwargs["follow_redirects"] = follow_redirects - if not "method" in kwargs: - kwargs["method"] = "GET" - try: - async with self._acatch(url, raise_error=True), self.AsyncClient().stream( - url=url, **kwargs - ) as response: - status_code = getattr(response, "status_code", 0) - log.debug(f"Download result: HTTP {status_code}") - if status_code != 0: - response.raise_for_status() - with open(filename, "wb") as f: - agen = response.aiter_bytes(chunk_size=chunk_size) - async for chunk in agen: - if max_size is not None and total_size + chunk_size > max_size: - log.verbose( - f"Filesize of {url} exceeds {self.parent_helper.bytes_to_human(max_size)}, file will be truncated" - ) - agen.aclose() - break - total_size += chunk_size - f.write(chunk) - success = True - except httpx.HTTPError as e: - log_fn = log.verbose - if warn: - log_fn = log.warning - log_fn(f"Failed to download {url}: {e}") - if raise_error: - raise - return + success = await self.run_and_return("download", url, **kwargs) if success: - return filename.resolve() + return filename async def wordlist(self, path, lines=None, **kwargs): """ diff --git a/bbot/test/test_step_1/test_web.py b/bbot/test/test_step_1/test_web.py index dc9116e0f..aeac2ba2f 100644 --- a/bbot/test/test_step_1/test_web.py +++ b/bbot/test/test_step_1/test_web.py @@ -4,10 +4,28 @@ @pytest.mark.asyncio -async def test_web_engine(bbot_scanner): +async def test_web_engine(bbot_scanner, bbot_httpserver): + + url = bbot_httpserver.url_for("/test") + bbot_httpserver.expect_request(uri="/test").respond_with_data("hello_there") + scan = bbot_scanner() - response = await scan.helpers.request("http://example.com") - log.critical(response) + + # request + response = await scan.helpers.request(url) + assert response.status_code > 0 + assert response.text == "hello_there" + + # request_batch + responses = [r async for r in scan.helpers.request_batch([url] * 100)] + assert len(responses) == 100 + assert all([r[0] == url for r in responses]) + assert all([r[1].status_code > 0 and r[1].text == "hello_there" for r in responses]) + + # download + filename = await scan.helpers.download(url) + file_content = open(filename).read() + assert file_content == "hello_there" @pytest.mark.asyncio