-
Notifications
You must be signed in to change notification settings - Fork 562
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
github-actions
committed
Nov 7, 2024
1 parent
3aaccda
commit 2e0c1f2
Showing
6 changed files
with
285 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# This file contains SQLModel (Pydantic + SQLAlchemy) models for BBOT events, scans, and targets. | ||
# Used by the SQL output modules, but portable for outside use. | ||
|
||
import json | ||
import logging | ||
from datetime import datetime | ||
from pydantic import ConfigDict | ||
from typing import List, Optional | ||
from typing_extensions import Annotated | ||
from pydantic.functional_validators import AfterValidator | ||
from sqlmodel import inspect, Column, Field, SQLModel, JSON, String, DateTime as SQLADateTime | ||
|
||
|
||
log = logging.getLogger("bbot_server.models") | ||
|
||
|
||
def naive_datetime_validator(d: datetime):
    """
    Converts all dates into UTC, then drops timezone information.

    This is needed to prevent inconsistencies in sqlite, because it is timezone-naive.

    Args:
        d: The datetime to normalize. Timezone-aware values are shifted to UTC
            first; naive values carry no offset and are returned unchanged.

    Returns:
        A timezone-naive datetime.
    """
    # local import: the surrounding module only imports the `datetime` class
    from datetime import timezone

    # utcoffset() is None for naive datetimes, which have nothing to convert
    if d.utcoffset() is not None:
        d = d.astimezone(timezone.utc)
    # drop timezone info
    return d.replace(tzinfo=None)
|
||
|
||
NaiveUTC = Annotated[datetime, AfterValidator(naive_datetime_validator)] | ||
|
||
|
||
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also serializes ``datetime`` objects (as ISO 8601 strings)."""

    def default(self, obj):
        """Return ``obj.isoformat()`` for datetimes; defer everything else to the base encoder."""
        if not isinstance(obj, datetime):
            return super().default(obj)
        return obj.isoformat()
|
||
|
||
class BBOTBaseModel(SQLModel):
    """
    Shared base class for the BBOT SQL models.

    Provides a lazily cached validated copy of the instance, deterministic JSON
    serialization, and hashing/equality derived from that JSON.
    """

    model_config = ConfigDict(extra="ignore")

    def __init__(self, *args, **kwargs):
        # initialise the validation cache before delegating to SQLModel
        self._validated = None
        super().__init__(*args, **kwargs)

    @property
    def validated(self):
        """
        Return the result of ``model_validate`` on this instance, cached after the first call.

        Falls back to returning ``self`` if an AttributeError is raised along the way.
        """
        try:
            if self._validated is None:
                self._validated = self.__class__.model_validate(self)
            return self._validated
        except AttributeError:
            return self

    def to_json(self, **kwargs):
        """
        Serialize the validated model to a JSON string with sorted keys.

        Extra keyword arguments are forwarded to ``json.dumps``.
        """
        return json.dumps(self.validated.model_dump(), sort_keys=True, cls=CustomJSONEncoder, **kwargs)

    @classmethod
    def _pk_column_names(cls):
        """Return the names of this model's primary key columns."""
        return [column.name for column in inspect(cls).primary_key]

    def __hash__(self):
        # the hash is derived from the sorted-key JSON representation
        return hash(self.to_json())

    def __eq__(self, other):
        """
        Compare by hash.

        Unhashable objects (e.g. dicts or lists) compare unequal instead of raising TypeError.
        """
        own_hash = hash(self)
        try:
            other_hash = hash(other)
        except TypeError:
            # `other` is unhashable, so it cannot be one of our models
            return False
        return own_hash == other_hash
|
||
|
||
### EVENT ### | ||
|
||
|
||
class Event(BBOTBaseModel, table=True):
    # Table model for a single BBOT event.

    uuid: str = Field(
        primary_key=True,
        index=True,
        nullable=False,
    )
    id: str = Field(index=True)
    type: str = Field(index=True)
    scope_description: str
    data: dict = Field(sa_type=JSON)
    host: Optional[str]
    port: Optional[int]
    netloc: Optional[str]
    # store the host in reversed form for efficient lookups by domain
    reverse_host: Optional[str] = Field(default="", exclude=True, index=True)
    resolved_hosts: List = Field(default=[], sa_type=JSON)
    dns_children: dict = Field(default={}, sa_type=JSON)
    web_spider_distance: int = 10
    scope_distance: int = Field(default=10, index=True)
    scan: str = Field(index=True)
    timestamp: NaiveUTC = Field(index=True)
    parent: str = Field(index=True)
    tags: List = Field(default=[], sa_type=JSON)
    module: str = Field(index=True)
    module_sequence: str
    discovery_context: str = ""
    discovery_path: List[str] = Field(default=[], sa_type=JSON)
    parent_chain: List[str] = Field(default=[], sa_type=JSON)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # normalise the payload so it is always stored as {<event type>: <payload>}
        payload = self._get_data(self.data, self.type)
        self.data = {self.type: payload}
        if self.host:
            self.reverse_host = self.host[::-1]

    def get_data(self):
        """Return the event payload without the {type: payload} wrapper."""
        return self._get_data(self.data, self.type)

    @staticmethod
    def _get_data(data, type):
        """Unwrap ``data`` when it is a dict whose only key is ``type``; otherwise return it as-is."""
        # handle SIEM-friendly format
        is_wrapped = isinstance(data, dict) and list(data) == [type]
        return data[type] if is_wrapped else data
|
||
|
||
### SCAN ### | ||
|
||
|
||
class Scan(BBOTBaseModel, table=True):
    # Table model for a scan, keyed by the scan id.
    id: str = Field(primary_key=True)
    name: str
    status: str
    started_at: NaiveUTC = Field(index=True)
    # declared with an explicit nullable, indexed DateTime column
    finished_at: Optional[NaiveUTC] = Field(
        default=None,
        sa_column=Column(SQLADateTime, nullable=True, index=True),
    )
    duration_seconds: Optional[float] = None
    duration: Optional[str] = None
    target: dict = Field(sa_type=JSON)
    preset: dict = Field(sa_type=JSON)
|
||
|
||
### TARGET ### | ||
|
||
|
||
class Target(BBOTBaseModel, table=True):
    # Table model for a scan target, keyed by its hash.
    name: str = "Default Target"
    strict_scope: bool = False
    seeds: List = Field(default=[], sa_type=JSON)
    # NOTE(review): annotated as List but defaults to None -- confirm whether Optional[List] is intended
    whitelist: List = Field(default=None, sa_type=JSON)
    blacklist: List = Field(default=[], sa_type=JSON)
    hash: str = Field(sa_column=Column("hash", String, unique=True, primary_key=True, index=True))
    scope_hash: str = Field(sa_column=Column("scope_hash", String, index=True))
    # column name now matches the attribute (was misspelled "seed_hashhash")
    seed_hash: str = Field(sa_column=Column("seed_hash", String, index=True))
    whitelist_hash: str = Field(sa_column=Column("whitelist_hash", String, index=True))
    blacklist_hash: str = Field(sa_column=Column("blacklist_hash", String, index=True))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from pathlib import Path | ||
|
||
from bbot.modules.templates.sql import SQLTemplate | ||
|
||
|
||
class SQLite(SQLTemplate):
    """Output module that writes to a SQLite database file."""

    watched_events = ["*"]
    meta = {"description": "sqlite"}
    deps_pip = ["sqlmodel", "sqlalchemy-utils", "aiosqlite"]
    options = {
        "database": "",
    }
    options_desc = {
        "database": "The path to the sqlite database file",
    }

    async def setup(self):
        """Resolve the database file path, create its parent directory, then run the template setup."""
        configured = self.config.get("database", "")
        # an empty setting falls back to "output.sqlite" inside the scan's home directory
        target = Path(configured or self.scan.home / "output.sqlite")
        # relative paths are anchored to the scan's home directory
        self.db_file = target if target.is_absolute() else self.scan.home / target
        self.db_file.parent.mkdir(parents=True, exist_ok=True)
        return await super().setup()

    def connection_string(self, mask_password=False):
        """Return the aiosqlite connection URL for the resolved database file."""
        return f"sqlite+aiosqlite:///{self.db_file}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from sqlmodel import SQLModel | ||
from sqlalchemy.orm import sessionmaker | ||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession | ||
from sqlalchemy_utils.functions import database_exists, create_database | ||
|
||
from bbot.db.sql.models import Event, Scan, Target | ||
from bbot.modules.output.base import BaseOutputModule | ||
|
||
|
||
class SQLTemplate(BaseOutputModule):
    """
    Template for SQL output modules.

    Builds a connection string from the module options, opens an async
    SQLAlchemy engine, and writes every event (plus scan/target rows for SCAN
    events) to the database.
    """

    meta = {"description": "SQL output module template"}
    options = {
        "protocol": "",
        "database": "bbot",
        "username": "",
        "password": "",
        "host": "127.0.0.1",
        "port": 0,
    }
    options_desc = {
        "protocol": "The protocol to use to connect to the database",
        "database": "The database to use",
        "username": "The username to use to connect to the database",
        "password": "The password to use to connect to the database",
        "host": "The host to use to connect to the database",
        "port": "The port to use to connect to the database",
    }

    async def setup(self):
        """
        Read the connection options, create the engine and session factory, and initialise the database.

        Returns:
            True once the database has been initialised.
        """
        # connection_string() reads self.protocol, so it must be populated here;
        # fall back to a class-level `protocol` attribute if a subclass defines one
        self.protocol = self.config.get("protocol", "") or getattr(self, "protocol", "")
        self.database = self.config.get("database", "bbot")
        self.username = self.config.get("username", "")
        self.password = self.config.get("password", "")
        self.host = self.config.get("host", "127.0.0.1")
        self.port = self.config.get("port", 0)

        # never log the real password
        self.log.info(f"Connecting to {self.connection_string(mask_password=True)}")

        self.engine = create_async_engine(self.connection_string())
        # Create a session factory bound to the engine
        self.async_session = sessionmaker(self.engine, expire_on_commit=False, class_=AsyncSession)
        await self.init_database()
        return True

    async def handle_event(self, event):
        """
        Persist one event; for SCAN events also insert or update the scan and its target.

        Raises:
            ValueError: If a SCAN event's data, or its "target" entry, is not a dict.
        """
        event_obj = Event(**event.json()).validated

        async with self.async_session() as session:
            async with session.begin():
                # insert event
                session.add(event_obj)

                # if it's a SCAN event, create/update the scan and target
                if event_obj.type == "SCAN":
                    event_data = event_obj.get_data()
                    if not isinstance(event_data, dict):
                        raise ValueError(f"Invalid data for SCAN event: {event_data}")
                    scan = Scan(**event_data).validated
                    await session.merge(scan)  # Insert or update scan

                    target_data = event_data.get("target", {})
                    if not isinstance(target_data, dict):
                        raise ValueError(f"Invalid target for SCAN event: {target_data}")
                    target = Target(**target_data).validated
                    await session.merge(target)  # Insert or update target

                await session.commit()

    async def init_database(self):
        """Create the database if it does not exist yet, then create all SQLModel tables."""
        async with self.engine.begin() as conn:
            # Check if the database exists using the connection's engine URL
            if not await conn.run_sync(lambda sync_conn: database_exists(sync_conn.engine.url)):
                await conn.run_sync(lambda sync_conn: create_database(sync_conn.engine.url))
            # Create all tables
            await conn.run_sync(SQLModel.metadata.create_all)

    def connection_string(self, mask_password=False):
        """
        Assemble the connection URL from protocol, credentials, host, port and database.

        Args:
            mask_password: When True, the password is replaced with "****".

        Returns:
            The connection string; empty components are omitted.
        """
        connection_string = f"{self.protocol}://"
        if self.username:
            password = self.password
            if mask_password:
                password = "****"
            connection_string += f"{self.username}:{password}"
        if self.host:
            connection_string += f"@{self.host}"
            # the port is only meaningful together with a host
            if self.port:
                connection_string += f":{self.port}"
        if self.database:
            connection_string += f"/{self.database}"
        return connection_string
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import sqlite3 | ||
from .base import ModuleTestBase | ||
|
||
|
||
class TestSQLite(ModuleTestBase):
    """Checks that the sqlite output module writes events, scans and targets to its database file."""

    targets = ["evilcorp.com"]

    def check(self, module_test, events):
        """Assert that the output file exists and that each of the three tables has at least one row."""
        from contextlib import closing

        sqlite_output_file = module_test.scan.home / "output.sqlite"
        assert sqlite_output_file.exists(), "SQLite output file not found"
        # sqlite3's own context manager only handles the transaction, so use
        # closing() to make sure the connection itself is released
        with closing(sqlite3.connect(sqlite_output_file)) as db:
            cursor = db.cursor()
            cursor.execute("SELECT * FROM event")
            assert len(cursor.fetchall()) > 0, "No events found in SQLite database"
            cursor.execute("SELECT * FROM scan")
            assert len(cursor.fetchall()) > 0, "No scans found in SQLite database"
            cursor.execute("SELECT * FROM target")
            assert len(cursor.fetchall()) > 0, "No targets found in SQLite database"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters