Skip to content

Commit

Permalink
Statement reload with DB clearing
Browse files Browse the repository at this point in the history
  • Loading branch information
raulikak committed Apr 30, 2024
1 parent a9eeb3b commit e71a8e3
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 3 deletions.
11 changes: 11 additions & 0 deletions tcsfw/client_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def api_post(self, request: APIRequest, data: Optional[BinaryIO]) -> Dict:
self.system_reset(param.get("evidence", {}), include_all=param.get("include_all", False))
if param.get("dump_all", False):
r = {"events": list(self.api_iterate_all(request.change_path(".")))}
elif path == "reload":
# reload is actually exit
self.api_exit(request, data)
elif path.startswith("event/"):
e_name = path[6:]
e_type = EventMap.get_event_class(e_name)
Expand All @@ -155,6 +158,14 @@ def api_post(self, request: APIRequest, data: Optional[BinaryIO]) -> Dict:
raise FileNotFoundError("Unknown API endpoint")
return r

def api_exit(self, _request: APIRequest, data: bytes) -> Dict:
"""Reload model"""
param = json.loads(data) if data else {}
clear_db = bool(param.get("clear_db", False))
if clear_db:
self.registry.clear_database()
return {}

def api_post_file(self, request: APIRequest, data_file: pathlib.Path) -> Dict:
"""Post API data in ZIP file"""
path = request.path
Expand Down
3 changes: 3 additions & 0 deletions tcsfw/entity_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def put_event(self, event: Event):
"""Store an event"""
raise NotImplementedError()

def clear_database(self):
"""Clear the database, from the disk"""


class InMemoryDatabase(EntityDatabase):
"""Store and retrieve events, later entities, etc."""
Expand Down
23 changes: 20 additions & 3 deletions tcsfw/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import pathlib
import sys
import tempfile
import traceback
from typing import BinaryIO, Dict, Optional, Tuple, List
Expand Down Expand Up @@ -76,6 +77,7 @@ async def start_server(self):
web.get('/api1/ping', self.handle_ping), # ping for health check
web.get('/api1/proxy/{tail:.+}', self.handle_login), # query proxy configuration
web.get('/api1/{tail:.+}', self.handle_http),
web.post('/api1/reload/{tail:.+}', self.handle_reload), # reload, kill the process
web.post('/api1/{tail:.+}', self.handle_http),
])
rr = web.AppRunner(app)
Expand Down Expand Up @@ -115,7 +117,7 @@ async def handle_ping(self, _request: web.Request):
"""Handle ping request"""
return web.Response(text="{}")

async def handle_http(self, request):
async def handle_http(self, request: web.Request):
"""Handle normal HTTP GET or POST request"""
try:
self.check_permission(request)
Expand Down Expand Up @@ -200,7 +202,7 @@ async def api_post_zip(self, api_request: APIRequest, request):
res = self.api.api_post_file(api_request, pathlib.Path(temp_dir))
return res

async def handle_ws(self, request):
async def handle_ws(self, request: web.Request):
"""Handle websocket HTTP request"""
assert request.path_qs.startswith("/api1/ws/")
req = APIRequest.parse(request.path_qs[9:])
Expand Down Expand Up @@ -245,7 +247,7 @@ async def receive_loop():
self.channels.remove(channel)
return ws

async def handle_login(self, request):
async def handle_login(self, request: web.Request):
"""Handle login or proxy query, which is launcher job. This should only be used in development."""
req = APIRequest.parse(request.path_qs)
try:
Expand All @@ -258,6 +260,21 @@ async def handle_login(self, request):
traceback.print_exc()
return web.Response(status=500)

async def handle_reload(self, request: web.Request):
"""Handle reload request"""
self.check_permission(request)
req = APIRequest.parse(request.path_qs)
data = await request.content.read() if request.content else b""
res = self.api.api_exit(req, data)

# reload means exiting this process, delay it for response to be sent
def do_exit():
# return code 0 for successful exit
sys.exit(0) # pylint: disable=consider-using-sys-exit

self.loop.call_later(1, do_exit)
return web.Response(text=json.dumps(res))

def dump_model(self, channel: WebsocketChannel):
"""Dump the whole model into channel"""
if not channel.subscribed:
Expand Down
4 changes: 4 additions & 0 deletions tcsfw/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,12 @@ async def handle_login(self, request: web.Request):
if request.path.startswith("/login/statement/"):
app = request.path[17:]
elif request.path.startswith("/api1/proxy/statement/"):
# NOTE: Remove /statement/ parts at some point
app = request.path[22:]
use_api_key = True
elif request.path.startswith("/api1/proxy/"):
app = request.path[12:]
use_api_key = True
else:
raise FileNotFoundError("Unexpected statement path")

Expand Down
5 changes: 5 additions & 0 deletions tcsfw/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def reset(self, evidence_filter: Dict[EvidenceSource, bool] = None, enable_all=F
self.logging.reset()
return self

def clear_database(self) -> Self:
"""Clear the database, from the disk"""
self.database.clear_database()
return self

def apply_all_events(self) -> Self:
"""Apply all stored events, after reset"""
while True:
Expand Down
14 changes: 14 additions & 0 deletions tcsfw/sql_database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""SQL database by SQLAlchemy"""

import json
import os
import pathlib
from typing import Any, Iterator, List, Optional, Dict, Tuple, Set
from urllib.parse import urlparse

from sqlalchemy import Boolean, Column, Integer, String, create_engine, delete, select
from sqlalchemy.ext.declarative import declarative_base
Expand Down Expand Up @@ -50,6 +53,7 @@ class SQLDatabase(EntityDatabase, ModelListener):
"""Use SQL database for storage"""
def __init__(self, db_uri: str):
super().__init__()
self.db_uri = db_uri
self.engine = create_engine(db_uri)
Base.metadata.create_all(self.engine)
self.db_conn = self.engine.connect()
Expand All @@ -70,6 +74,16 @@ def __init__(self, db_uri: str):
self.pending_batch = []
self.pending_source_ids = set()

def clear_database(self):
# check if DB is a local file
self.engine.dispose()
u = urlparse(self.db_uri)
if u.scheme.startswith("sqlite") and u.path:
path = pathlib.Path(u.path[1:]) if u.path.startswith("/") else pathlib.Path(u.path)
self.logger.info("Deleting DB file %s if it exists", path)
if path.exists():
os.remove(path)

def _fill_cache(self):
"""Fill entity cache from database"""
with Session(self.engine) as ses:
Expand Down

0 comments on commit e71a8e3

Please sign in to comment.