diff --git a/backend/api/auth.py b/backend/api/auth.py index 0c136235..82445ca8 100644 --- a/backend/api/auth.py +++ b/backend/api/auth.py @@ -150,7 +150,10 @@ async def handle(self, authorization: str | None = Header(None)): class SetPassword(APIRequest): """Set a new password for the current (or a given) user. - In order to set a password for another user, the current user must be an admin. + Normal users can only change their own password. + + In order to set a password for another user, + the current user must be an admin, otherwise a 403 error is returned. """ name: str = "password" @@ -163,7 +166,7 @@ async def handle( ) -> Response: if request.login: if not user.is_admin: - raise nebula.UnauthorizedException( + raise nebula.ForbiddenException( "Only admin can change other user's password" ) query = "SELECT meta FROM users WHERE login = $1" diff --git a/backend/api/jobs/jobs.py b/backend/api/jobs/jobs.py index 27cb7d75..bd4408a4 100644 --- a/backend/api/jobs/jobs.py +++ b/backend/api/jobs/jobs.py @@ -1,7 +1,6 @@ import time from typing import Literal -from fastapi import Response from nxtools import slugify from pydantic import Field @@ -95,7 +94,7 @@ class JobsItemModel(RequestModel): class JobsResponseModel(ResponseModel): - jobs: list[JobsItemModel] = Field(default_factory=list) + jobs: list[JobsItemModel] | None = Field(default=None) async def can_user_control_job(user: nebula.User, id_job: int) -> bool: @@ -192,7 +191,7 @@ async def handle( self, request: JobsRequestModel, user: CurrentUser, - ) -> JobsResponseModel | Response: + ) -> JobsResponseModel: if request.abort: await abort_job(request.abort, user) @@ -227,7 +226,7 @@ async def handle( # failed conds.append("j.status IN (3)") elif request.view is None: - return Response(status_code=204) + return JobsResponseModel() if request.asset_ids is not None: ids = ",".join([str(id) for id in request.asset_ids]) diff --git a/backend/api/proxy.py b/backend/api/proxy.py index f71f75b6..0152781c 100644 --- 
a/backend/api/proxy.py +++ b/backend/api/proxy.py @@ -1,36 +1,26 @@ import os import aiofiles -from fastapi import Header, HTTPException, Request, Response, status -from fastapi.responses import StreamingResponse +from fastapi import HTTPException, Request, Response, status import nebula -from server.dependencies import CurrentUserInQuery +from server.dependencies import CurrentUser from server.request import APIRequest -async def send_bytes_range_requests(file_name: str, start: int, end: int): - """Send a file in chunks using Range Requests specification RFC7233 +class ProxyResponse(Response): + content_type = "video/mp4" - `start` and `end` parameters are inclusive due to specification - """ - CHUNK_SIZE = 1024 * 8 - sent_bytes = 0 - try: - async with aiofiles.open(file_name, mode="rb") as f: - await f.seek(start) - pos = start - while pos < end: - read_size = min(CHUNK_SIZE, end - pos + 1) - data = await f.read(read_size) - yield data - pos += len(data) - sent_bytes += len(data) - finally: - nebula.log.trace( - f"Finished sending file {start}-{end}. Sent {sent_bytes} bytes. 
Expected {end-start+1} bytes" ) +async def get_bytes_range(file_name: str, start: int, end: int) -> bytes: + """Get a range of bytes from a file""" + + async with aiofiles.open(file_name, mode="rb") as f: + await f.seek(start) + pos = start + # read_size = min(CHUNK_SIZE, end - pos + 1) + read_size = end - pos + 1 + return await f.read(read_size) def _get_range_header(range_header: str, file_size: int) -> tuple[int, int]: @@ -52,10 +42,13 @@ def _invalid_range(): return start, end -async def range_requests_response(request: Request, file_path: str, content_type: str): +async def range_requests_response( + request: Request, file_path: str, content_type: str +) -> ProxyResponse: """Returns StreamingResponse using Range Requests of a given file""" file_size = os.stat(file_path).st_size + max_chunk_size = 1024 * 1024 # 1MB range_header = request.headers.get("range") headers = { @@ -74,22 +67,21 @@ if range_header is not None: start, end = _get_range_header(range_header, file_size) + end = min(end, start + max_chunk_size - 1) size = end - start + 1 headers["content-length"] = str(size) headers["content-range"] = f"bytes {start}-{end}/{file_size}" status_code = status.HTTP_206_PARTIAL_CONTENT - return StreamingResponse( - send_bytes_range_requests(file_path, start, end), + payload = await get_bytes_range(file_path, start, end) + + return ProxyResponse( + content=payload, headers=headers, status_code=status_code, ) -class ProxyResponse(Response): - content_type = "video/mp4" - - class ServeProxy(APIRequest): """Serve a low-res (proxy) media for a given asset.
@@ -100,16 +92,15 @@ class ServeProxy(APIRequest): name: str = "proxy" path: str = "/proxy/{id_asset}" title: str = "Serve proxy" - response_class = ProxyResponse methods = ["GET"] async def handle( self, request: Request, id_asset: int, - user: CurrentUserInQuery, - range: str = Header(None), - ): + user: CurrentUser, + ) -> ProxyResponse: + assert user sys_settings = nebula.settings.system proxy_storage_path = nebula.storages[sys_settings.proxy_storage].local_path proxy_path_template = os.path.join(proxy_storage_path, sys_settings.proxy_path) @@ -123,10 +114,10 @@ async def handle( if not os.path.exists(video_path): # maybe return content too? with a placeholder image? - return Response(status_code=404, content="Not found") + raise nebula.NotFoundException("Proxy not found") try: return await range_requests_response(request, video_path, "video/mp4") - except Exception: + except Exception as e: nebula.log.traceback("Error serving proxy") - return Response(status_code=500, content="Internal server error") + raise nebula.NebulaException("Error serving proxy") from e diff --git a/backend/api/sessions.py b/backend/api/sessions.py index c598e3ea..04b0930f 100644 --- a/backend/api/sessions.py +++ b/backend/api/sessions.py @@ -16,7 +16,6 @@ class Sessions(APIRequest): name = "sessions" title = "List sessions" - response_model = list[SessionModel] async def handle( self, @@ -55,15 +54,14 @@ class InvalidateSession(APIRequest): """ name = "invalidate_session" - title = "Invalidate a session" - responses = [204, 201] + title = "Invalidate session" + responses = [204] async def handle( self, payload: InvalidateSessionRequest, user: CurrentUser, ) -> Response: - session = await Session.check(payload.token) if session is None: raise nebula.NotFoundException("Session not found") diff --git a/backend/manage b/backend/manage index f831be22..4f4122dd 100755 --- a/backend/manage +++ b/backend/manage @@ -1,5 +1,7 @@ #!/bin/bash +SERVER_TYPE=${NEBULA_SERVER_TYPE:-gunicorn} + if [ $# 
-ne 1 ]; then echo "Error: a single argument is required" exit 1 @@ -12,21 +14,45 @@ function start_server () { # Run setup to make sure database is up to date python -m setup + - # Start gunicorn - exec gunicorn \ - -k uvicorn.workers.UvicornWorker \ - --log-level warning \ - -b :80 \ - server.server:app + if [ $SERVER_TYPE = "gunicorn" ]; then + exec gunicorn \ + -k uvicorn.workers.UvicornWorker \ + --log-level warning \ + -b :80 \ + server.server:app + elif [ $SERVER_TYPE = "granian" ]; then + exec granian \ + --interface asgi \ + --log-level warning \ + --host 0.0.0.0 \ + --port 80 \ + server.server:app + else + echo "" + echo "Error: invalid server type '$SERVER_TYPE'. Expected 'gunicorn' or 'granian'" + echo "" + exit 1 + fi +} + + +function get_server_pid () { + if [ $SERVER_TYPE = "gunicorn" ]; then + pid=$(ps aux | grep 'gunicorn' | awk '{print $2}') + elif [ $SERVER_TYPE = "granian" ]; then + pid=$(ps aux | grep 'granian' | awk '{print $2}') + fi + echo $pid } + function stop_server () { echo "" echo "SIGTERM signal received. Shutting down..." echo "" - pid=$(ps aux | grep 'gunicorn' | awk '{print $2}') - kill -TERM $pid 2> /dev/null + kill -TERM $(get_server_pid) 2> /dev/null exit 0 } @@ -34,12 +60,12 @@ function reload_server () { echo "" echo "Reloading the server..." 
echo "" - pid=$(ps aux | grep 'gunicorn' | awk '{print $2}') - kill -HUP $pid 2> /dev/null + kill -HUP $(get_server_pid) 2> /dev/null exit 0 } trap stop_server SIGTERM SIGINT +trap reload_server SIGHUP if [ $1 = "start" ]; then diff --git a/backend/mypy.ini b/backend/mypy.ini index bcc0be60..314cfb65 100644 --- a/backend/mypy.ini +++ b/backend/mypy.ini @@ -1,7 +1,7 @@ [mypy] python_version = 3.10 ignore_missing_imports = false -check_untyped_defs = false +check_untyped_defs = true strict=false files=./**/*.py exclude=(tests/|venv/) diff --git a/backend/nebula/enum.py b/backend/nebula/enum.py index 27669672..3eb34a68 100644 --- a/backend/nebula/enum.py +++ b/backend/nebula/enum.py @@ -2,16 +2,42 @@ class ObjectStatus(enum.IntEnum): + """Object status enumeration. + + This enumeration is used to indicate the status of an object. + Objects can be in one of the following states: + + - OFFLINE: Object is in the database, but not available on the filesystem. + - ONLINE: Object is in the database and available on the filesystem. + - CREATING: Media file exists, but was changed recently, so its metadata + is being updated. + - TRASHED: Object has been marked as deleted, but is still available on + the filesystem. It will be deleted permanently after some time. + - ARCHIVED: Object has been marked as archived, but is still available on + the filesystem. It will be deleted permanently after some time. + - RESET: User has requested to reset the metadata of the object, + this triggers a re-scan of the media file metadata. + - CORRUPTED: Object is corrupted, and cannot be used. + - REMOTE: Object is not available on the filesystem, but is available on + a remote storage (typically a playout item whose media file is on a + production storage, but it hasn't been copied to the playout storage yet). + - UNKNOWN: Object status is unknown. + - AIRED: Only for items. Item has been broadcasted. + - ONAIR: Only for items. Item is currently being broadcasted.
+ - RETRIEVING: Asset is marked for retrieval from a remote/archive storage. + + """ + OFFLINE = 0 ONLINE = 1 - CREATING = 2 # File exists, but was changed recently. - TRASHED = 3 # File has been moved to trash location. - ARCHIVED = 4 # File has been moved to archive location. - RESET = 5 # Reset metadata action has been invoked. + CREATING = 2 + TRASHED = 3 + ARCHIVED = 4 + RESET = 5 CORRUPTED = 6 REMOTE = 7 UNKNOWN = 8 - AIRED = 9 # Auxiliary value. + AIRED = 9 ONAIR = 10 RETRIEVING = 11 diff --git a/backend/nebula/log.py b/backend/nebula/log.py index bea54a0c..c6ff9062 100644 --- a/backend/nebula/log.py +++ b/backend/nebula/log.py @@ -1,4 +1,5 @@ import enum +import logging import sys import traceback @@ -70,3 +71,22 @@ def critical(self, *args, **kwargs): log = Logger() + +# Add custom logging handler to standard logging module +# This allows us to use the standard logging module with +# the same format, log level and consumers as the primary +# Nebula logger. This is useful for 3rd party libraries. 
+ + +class CustomHandler(logging.Handler): + def emit(self, record): + log_message = self.format(record) + name = record.name + log(LogLevel(record.levelno // 10), log_message, user=name) + + +root_logger = logging.getLogger() +root_logger.setLevel(log.level * 10) + +custom_handler = CustomHandler() +root_logger.addHandler(custom_handler) diff --git a/backend/nebula/objects/event.py b/backend/nebula/objects/event.py index 344d8484..6ba1f00f 100644 --- a/backend/nebula/objects/event.py +++ b/backend/nebula/objects/event.py @@ -17,4 +17,7 @@ class Event(BaseObject): } async def delete_children(self): + assert self.connection is not None + assert hasattr(self.connection, "execute") + assert self.id await self.connection.execute("DELETE FROM bins WHERE id_magic = $1", self.id) diff --git a/backend/nebula/plugins/solver.py b/backend/nebula/plugins/solver.py index 782be751..3e822bb5 100644 --- a/backend/nebula/plugins/solver.py +++ b/backend/nebula/plugins/solver.py @@ -210,14 +210,14 @@ async def main(self): item["position"] = i await item.save(notify=False) - if self.bin.id not in self.affected_bins: + if self.bin.id and self.bin.id not in self.affected_bins: self.affected_bins.append(self.bin.id) # save event in case solver updated its metadata await self.event.save() # another paceholder was created, so we need to solve it - if self._solve_next: + if self._solve_next and self._solve_next.id: await self(self._solve_next.id) return diff --git a/backend/nebula/settings/models.py b/backend/nebula/settings/models.py index cd937a62..40e23a41 100644 --- a/backend/nebula/settings/models.py +++ b/backend/nebula/settings/models.py @@ -15,7 +15,7 @@ class CSAlias(SettingsModel): class CSItemModel(SettingsModel): - role: CSItemRole | None = Field(None) + role: CSItemRole | None = Field(default=None) aliases: dict[str, CSAlias] = Field(default_factory=dict) @classmethod @@ -53,40 +53,40 @@ class BaseSystemSettings(SettingsModel): """ site_name: str = Field( - "nebula", + 
default="nebula", regex=r"^[a-zA-Z0-9_]+$", title="Site name", description="A name used as the site (instance) identification", ) language: LanguageCode = Field( - "en", + default="en", title="Default language", example="en", ) ui_asset_create: bool = Field( - True, + default=True, title="Create assets in UI", description="Allow creating assets in the UI" "(when set to false, assets can only be created via API and watch folders)", ) ui_asset_preview: bool = Field( - True, + default=True, title="Preview assets in UI", description="Allow previewing low-res proxies of assets in the UI", ) ui_asset_upload: bool = Field( - False, + default=False, title="Upload assets in UI", description="Allow uploading asset media files in the UI " "(when set to false, assets can only be uploaded via API and watch folders)", ) subtitle_separator: str = Field( - ": ", + default=": ", title="Subtitle separator", description="String used to separate title and subtitle in displayed title", ) @@ -99,21 +99,25 @@ class SystemSettings(BaseSystemSettings): Contains settings that are used only by the server. 
""" - proxy_storage: int = Field(1, title="Proxy storage", example=1) - proxy_path: str = Field(".nx/proxy/{id1000:04d}/{id}.mp4") - worker_plugin_storage: int = Field(1) - worker_plugin_path: str = Field(".nx/plugins") - upload_storage: int | None = Field(None) - upload_dir: str | None = Field(None) - upload_base_name: str = Field("{id}") + proxy_storage: int = Field(default=1, title="Proxy storage", example=1) + proxy_path: str = Field(default=".nx/proxy/{id1000:04d}/{id}.mp4") + worker_plugin_storage: int = Field(default=1) + worker_plugin_path: str = Field(default=".nx/plugins") + upload_storage: int | None = Field(default=None) + upload_dir: str | None = Field(default=None) + upload_base_name: str = Field(default="{id}") - smtp_host: str | None = Field(None, title="SMTP host", example="smtp.example.com") - smtp_port: int | None = Field(None, title="SMTP port", example=465) - smtp_user: str | None = Field(None, title="SMTP user", example="smtpuser") - smtp_pass: str | None = Field(None, title="SMTP password", example="smtppass.1") + smtp_host: str | None = Field( + default=None, title="SMTP host", example="smtp.example.com" + ) + smtp_port: int | None = Field(default=None, title="SMTP port", example=465) + smtp_user: str | None = Field(default=None, title="SMTP user", example="smtpuser") + smtp_pass: str | None = Field( + default=None, title="SMTP password", example="smtppass.1" + ) mail_from: str | None = Field( - "Nebula ", + default="Nebula ", title="Mail from", description="Email address used as the sender", example="Nebula ", @@ -185,7 +189,7 @@ class StorageSettings(BaseStorageSettings): class FolderField(SettingsModel): name: str = Field(..., title="Field name") - section: str | None = Field(None, title="Section") + section: str | None = Field(default=None, title="Section") mode: str | None = None format: str | None = None order: str | None = None @@ -212,11 +216,11 @@ class ViewSettings(SettingsModel): id: int = Field(...) name: str = Field(...) 
position: int = Field(...) - folders: list[int] | None = Field(None) - states: list[int] | None = Field(None) - columns: list[str] | None = Field(None) - conditions: list[str] | None = Field(None) - separator: bool = Field(False) + folders: list[int] | None = Field(default=None) + states: list[int] | None = Field(default=None) + columns: list[str] | None = Field(default=None) + conditions: list[str] | None = Field(default=None) + separator: bool = Field(default=False) DayStart = tuple[int, int] @@ -224,31 +228,31 @@ class ViewSettings(SettingsModel): class AcceptModel(SettingsModel): folders: list[int] | None = Field( - None, + default=None, title="Folders", description="List of folder IDs", ) content_types: list[ContentType] | None = Field( + default_factory=lambda: [ContentType.VIDEO], title="Content types", description="List of content types that are accepted. " "None means all types are accepted.", - default_factory=lambda: [ContentType.VIDEO], ) media_types: list[MediaType] | None = Field( + default_factory=lambda: [MediaType.FILE], title="Media types", description="List of media types that are accepted. " "None means all types are accepted.", - default_factory=lambda: [MediaType.FILE], ) class BasePlayoutChannelSettings(SettingsModel): id: int = Field(...) name: str = Field(...) 
- fps: float = Field(25.0) + fps: float = Field(default=25.0) plugins: list[str] = Field(default_factory=list) solvers: list[str] = Field(default_factory=list) - day_start: DayStart = Field((7, 0)) + day_start: DayStart = Field(default=(7, 0)) rundown_columns: list[str] = Field(default_factory=list) fields: list[FolderField] = Field( fields="Fields", @@ -260,7 +264,7 @@ class BasePlayoutChannelSettings(SettingsModel): FolderField(name="color"), # to distinguish events in the scheduler view ], ) - send_action: int | None = None + send_action: int | None = Field(default=None) scheduler_accepts: AcceptModel = Field(default_factory=AcceptModel) rundown_accepts: AcceptModel = Field(default_factory=AcceptModel) @@ -285,16 +289,9 @@ class PlayoutChannelSettings(BasePlayoutChannelSettings): # -def find_id(data: list[SettingsModel], id: int) -> SettingsModel | None: - for item in data: - if item.id == id: - return item - return None - - class ServerSettings(SettingsModel): installed: bool = True - system: SystemSettings = Field(default_factory=SystemSettings) + system: SystemSettings = Field(default_factory=lambda: SystemSettings()) storages: list[StorageSettings] = Field(default_factory=list) folders: list[FolderSettings] = Field(default_factory=list) views: list[ViewSettings] = Field(default_factory=list) @@ -306,13 +303,25 @@ class ServerSettings(SettingsModel): playout_channels: list[PlayoutChannelSettings] = Field(default_factory=list) def get_folder(self, id_folder: int) -> FolderSettings | None: - return find_id(self.folders, id_folder) + for item in self.folders: + if item.id == id_folder: + return item + return None def get_view(self, id_view: int) -> ViewSettings | None: - return find_id(self.views, id_view) + for item in self.views: + if item.id == id_view: + return item + return None def get_storage(self, id_storage: int) -> StorageSettings | None: - return find_id(self.storages, id_storage) + for item in self.storages: + if item.id == id_storage: + return item 
+ return None def get_playout_channel(self, id_channel: int) -> PlayoutChannelSettings | None: - return find_id(self.playout_channels, id_channel) + for item in self.playout_channels: + if item.id == id_channel: + return item + return None diff --git a/backend/nebula/storages.py b/backend/nebula/storages.py index 89830935..9a29fffa 100644 --- a/backend/nebula/storages.py +++ b/backend/nebula/storages.py @@ -14,9 +14,9 @@ def __init__(self, storage_config): self.protocol = storage_config.protocol self.path = storage_config.path self.options = storage_config.options - self.read_only = None - self.last_mount_attempt = 0 - self.mount_attempts = 0 + self.read_only: bool | None = None + self.last_mount_attempt: float = 0 + self.mount_attempts: int = 0 def __str__(self): res = f"storage {self.id}" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index c486d8a9..1d94c3d3 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -24,6 +24,7 @@ user-agents = "^2.2.0" httpx = "^0.24.1" requests = "^2.31.0" gunicorn = "^20.1.0" +granian = "^1.0.1" [tool.poetry.dev-dependencies] pytest = "^7.0" diff --git a/backend/server/clientinfo.py b/backend/server/clientinfo.py index 3b40815b..8744b3d5 100644 --- a/backend/server/clientinfo.py +++ b/backend/server/clientinfo.py @@ -1,5 +1,3 @@ -import contextlib -import ipaddress import os import geoip2 @@ -8,6 +6,8 @@ from fastapi import Request from pydantic import BaseModel, Field +from server.utils import is_internal_ip + class LocationInfo(BaseModel): country: str = Field(None, title="Country") @@ -29,11 +29,12 @@ class ClientInfo(BaseModel): def get_real_ip(request: Request) -> str: - xff = request.headers.get("x-forwarded-for", request.client.host) + host_ip = request.client.host if request.client else "127.0.0.1" + xff = request.headers.get("x-forwarded-for", host_ip) return xff.split(",")[0].strip() -def geo_lookup(ip: str): +def geo_lookup(ip: str) -> LocationInfo | None: geoip_db_path = None # TODO if 
geoip_db_path is None: @@ -48,23 +49,11 @@ def geo_lookup(ip: str): except geoip2.errors.AddressNotFoundError: return None - return LocationInfo( - country=response.country.name, - subdivision=response.subdivisions.most_specific.name, - city=response.city.name, - ) - return None - - -def is_internal_ip(ip: str) -> bool: - with contextlib.suppress(ValueError): - if ipaddress.IPv4Address(ip).is_private: - return True - - with contextlib.suppress(ValueError): - if ipaddress.IPv6Address(ip).is_private: - return True - return False + return LocationInfo( + country=response.country.name, + subdivision=response.subdivisions.most_specific.name, + city=response.city.name, + ) def get_ua_data(request) -> AgentInfo | None: diff --git a/backend/server/dependencies.py b/backend/server/dependencies.py index cdce3e78..7fcf217a 100644 --- a/backend/server/dependencies.py +++ b/backend/server/dependencies.py @@ -7,9 +7,16 @@ from server.utils import parse_access_token -async def access_token(authorization: str = Header(None)) -> str | None: - """Parse and return an access token provided in the request headers.""" - access_token = parse_access_token(authorization) +async def access_token( + authorization: str | None = Header(None), + token: str | None = Query(None), +) -> str | None: + """Parse and return an access token. + + Access token may be provided either in the Authorization header + or in the query parameters. 
+ """ + access_token = token or parse_access_token(authorization or "") if not access_token: return None return access_token @@ -37,18 +44,6 @@ async def request_initiator(x_client_id: str | None = Header(None)) -> str | Non RequestInitiator = Annotated[str, Depends(request_initiator)] -async def current_user_query(token: str = Query(None)) -> nebula.User: - if token is None: - raise nebula.UnauthorizedException("No access token provided") - session = await Session.check(token, None) - if session is None: - raise nebula.UnauthorizedException("Invalid access token") - return nebula.User(meta=session.user) - - -CurrentUserInQuery = Annotated[nebula.User, Depends(current_user_query)] - - async def current_user( request: Request, access_token: AccessToken, diff --git a/backend/server/endpoints.py b/backend/server/endpoints.py index fee326cf..98740cbb 100644 --- a/backend/server/endpoints.py +++ b/backend/server/endpoints.py @@ -1,8 +1,10 @@ +import inspect import os from typing import Generator import fastapi from nxtools import slugify +from pydantic import BaseModel import nebula from nebula.common import classes_from_module, import_module @@ -47,6 +49,7 @@ def find_api_endpoints() -> Generator[APIRequest, None, None]: module = import_module(module_name, module_path) except ImportError: nebula.log.traceback(f"Failed to load endpoint {module_name}") + continue # Find API endpoints in module and yield them @@ -70,16 +73,34 @@ def install_endpoints(app: fastapi.FastAPI): nebula.log.warn(f"Duplicate endpoint name {endpoint.name}") continue + if not hasattr(endpoint, "handle"): + nebula.log.warn(f"Endpoint {endpoint.name} doesn't have a handle method") + continue + + if not callable(endpoint.handle): # type: ignore + nebula.log.warn(f"Endpoint {endpoint.name} handle is not callable") + continue + + # use inspect to get the return type of the handle method + # this is used to determine the response model + + sig = inspect.signature(endpoint.handle) # type: ignore + if 
sig.return_annotation is not inspect.Signature.empty: + response_model = sig.return_annotation + else: + response_model = None + + # + # Set the endpoint path and name + # + endpoint_names.add(endpoint.name) route = endpoint.path or f"/api/{endpoint.name}" nebula.log.trace("Adding endpoint", route) additional_params = {} - if endpoint.response_model is None: - additional_params["response_class"] = fastapi.Response - else: - additional_params["response_model"] = endpoint.response_model + if isinstance(response_model, BaseModel): additional_params["response_model_exclude_none"] = endpoint.exclude_none if isinstance(endpoint.__doc__, str): diff --git a/backend/server/server.py b/backend/server/server.py index f630d58e..0957c216 100644 --- a/backend/server/server.py +++ b/backend/server/server.py @@ -118,6 +118,8 @@ async def catchall_exception_handler( @app.websocket("/ws") async def ws_endpoint(websocket: WebSocket) -> None: client = await messaging.join(websocket) + if client is None: + return try: while True: message = await client.receive() @@ -125,15 +127,11 @@ async def ws_endpoint(websocket: WebSocket) -> None: continue if message["topic"] == "auth": - await client.authorize( - message.get("token"), - topics=message.get("subscribe", []), - ) - # if client.user_name: - # nebula.log.trace(f"{client.user_name} connected") + token = message.get("token") + subscribe = message.get("subscribe", []) + if token: + await client.authorize(token, subscribe) except WebSocketDisconnect: - # if client.user_name: - # nebula.log.trace(f"{client.user_name} disconnected") with contextlib.suppress(KeyError): del messaging.clients[client.id] diff --git a/backend/server/session.py b/backend/server/session.py index 8c4179bf..0cc743db 100644 --- a/backend/server/session.py +++ b/backend/server/session.py @@ -10,15 +10,7 @@ from nebula.common import create_hash, json_dumps, json_loads from nebula.exceptions import LoginFailedException from server.clientinfo import ClientInfo, 
get_client_info, get_real_ip - - -def is_local_ip(ip: str) -> bool: - return ( - ip.startswith("127.") - or ip.startswith("10.") - or ip.startswith("192.168.") - or ip.startswith("172.") - ) +from server.utils import is_internal_ip class SessionModel(BaseModel): @@ -65,7 +57,7 @@ async def check( await nebula.redis.set(cls.ns, token, session.json()) else: real_ip = get_real_ip(request) - if not is_local_ip(real_ip) and session.client_info.ip != real_ip: + if not is_internal_ip(real_ip) and session.client_info.ip != real_ip: nebula.log.warning( "Session IP mismatch. " f"Stored: {session.client_info.ip}, current: {real_ip}" @@ -94,7 +86,7 @@ async def create( client_info = get_client_info(request) if request else None if client_info: - if user["local_network_only"] and not is_local_ip(client_info.ip): + if user["local_network_only"] and not is_internal_ip(client_info.ip): raise LoginFailedException("You can only log in from local network") token = create_hash() @@ -142,7 +134,8 @@ async def list( from the database. """ - async for _session_id, data in nebula.redis.iterate(cls.ns): + async for _, data in nebula.redis.iterate(cls.ns): + assert isinstance(data, str) session = SessionModel(**json_loads(data)) if cls.is_expired(session): nebula.log.info( diff --git a/backend/server/utils.py b/backend/server/utils.py index 4a4fd0c9..62f8d508 100644 --- a/backend/server/utils.py +++ b/backend/server/utils.py @@ -1,3 +1,7 @@ +import contextlib +import ipaddress + + def parse_access_token(authorization: str) -> str | None: """Parse an authorization header value. 
@@ -16,3 +20,15 @@ def parse_access_token(authorization: str) -> str | None: if len(token) != 64: return None return token + + +def is_internal_ip(ip: str) -> bool: + """Return true if the given IP address is private""" + with contextlib.suppress(ValueError): + if ipaddress.IPv4Address(ip).is_private: + return True + + with contextlib.suppress(ValueError): + if ipaddress.IPv6Address(ip).is_private: + return True + return False diff --git a/backend/server/websocket.py b/backend/server/websocket.py index d5ba0938..ec026415 100644 --- a/backend/server/websocket.py +++ b/backend/server/websocket.py @@ -38,12 +38,6 @@ async def authorize(self, access_token: str, topics: list[str]) -> bool: self.topics = [*topics, *ALWAYS_SUBSCRIBE] if "*" not in topics else ["*"] self.authorized = True self.user = nebula.User(meta=session_data.user) - # logging.info( - # "Authorized connection", - # session_data.user["login"], - # "topics:", - # self.topics, - # ) return True return False @@ -59,7 +53,7 @@ async def send(self, message: dict[str, Any], auth_only: bool = True): except Exception as e: nebula.log.trace("WS: Error sending message", e) - async def receive(self): + async def receive(self) -> dict[str, Any] | None: data = await self.sock.receive_text() try: message = json_loads(data) @@ -88,10 +82,10 @@ def initialize(self) -> None: self.clients: dict[str, Client] = {} self.error_rate_data = [] - async def join(self, websocket: WebSocket): + async def join(self, websocket: WebSocket) -> Client | None: if not self.is_running: await websocket.close() - return + return None await websocket.accept() client = Client(websocket) self.clients[client.id] = client diff --git a/backend/setup/metatypes.py b/backend/setup/metatypes.py index 69f2d208..ee50a227 100644 --- a/backend/setup/metatypes.py +++ b/backend/setup/metatypes.py @@ -1,11 +1,12 @@ import json import os +from typing import Any async def setup_metatypes(meta_types, db): languages = ["en", "cs"] - aliases = {} + aliases: 
dict[str, dict[str, Any]] = {} for lang in languages: aliases[lang] = {} trans_table_fname = os.path.join("schema", f"meta-aliases-{lang}.json") diff --git a/backend/setup/settings.py b/backend/setup/settings.py index 073ccc39..83166794 100644 --- a/backend/setup/settings.py +++ b/backend/setup/settings.py @@ -23,7 +23,7 @@ from setup.defaults.views import VIEWS from setup.metatypes import setup_metatypes -TEMPLATE = { +TEMPLATE: dict[str, Any] = { "actions": ACTIONS, "channels": CHANNELS, "folders": FOLDERS, @@ -53,6 +53,7 @@ def load_overrides(): log.info(f"Found overrides for {key}") if isinstance(override, dict) and isinstance(TEMPLATE[key], dict): + assert hasattr(TEMPLATE[key], "update") TEMPLATE[key].update(override) elif isinstance(override, list) and isinstance(TEMPLATE[key], list): TEMPLATE[key] = override @@ -61,7 +62,6 @@ def load_overrides(): async def setup_settings(db): - load_overrides() log.info("Applying system settings") @@ -147,7 +147,7 @@ async def setup_settings(db): # Setup classifications used_urns = set() - for _meta_type, mset in TEMPLATE["meta_types"].items(): + for mset in TEMPLATE["meta_types"].values(): if mset.get("cs"): used_urns.add(mset["cs"]) diff --git a/frontend/src/containers/Browser/Browser.jsx b/frontend/src/containers/Browser/Browser.jsx index 42c0d2a0..7da822a1 100644 --- a/frontend/src/containers/Browser/Browser.jsx +++ b/frontend/src/containers/Browser/Browser.jsx @@ -65,7 +65,7 @@ const BrowserTable = () => { const dataRef = useRef(data) useEffect(() => { - dataRef.current = data; + dataRef.current = data }, [data]) const loadData = () => { @@ -101,20 +101,23 @@ const BrowserTable = () => { const debouncingLoadData = useCallback(debounce(loadData, 100), [loadData]) - const handlePubSub = useCallback((topic, message) => { - if (topic !== 'objects_changed') return - if (message.object_type !== 'asset') return - let changed = false - for (const obj of message.objects) { - if (dataRef.current.find((row) => row.id === obj)) { 
- changed = true; - break; + const handlePubSub = useCallback( + (topic, message) => { + if (topic !== 'objects_changed') return + if (message.object_type !== 'asset') return + let changed = false + for (const obj of message.objects) { + if (dataRef.current.find((row) => row.id === obj)) { + changed = true + break + } } - } - if (changed){ - debouncingLoadData() - } - }, [loadData]) + if (changed) { + debouncingLoadData() + } + }, + [loadData] + ) useEffect(() => { const token = PubSub.subscribe('objects_changed', handlePubSub)