From 0e7e9179e644a9ea4c3378382874e62fa3362cdf Mon Sep 17 00:00:00 2001 From: Ralf Grubenmann Date: Mon, 7 Oct 2024 11:41:31 +0200 Subject: [PATCH] refactor: add validation to project, storage, repo and session blueprints (#347) --- components/renku_data_services/authz/authz.py | 2 +- .../renku_data_services/authz/models.py | 2 +- .../renku_data_services/base_api/auth.py | 24 ------- components/renku_data_services/crc/apispec.py | 2 +- .../renku_data_services/project/api.spec.yaml | 2 +- .../project/apispec_base.py | 9 ++- .../renku_data_services/project/blueprints.py | 42 ++++++------- components/renku_data_services/project/db.py | 1 - components/renku_data_services/project/orm.py | 2 +- .../repositories/blueprints.py | 8 +-- .../renku_data_services/secrets/apispec.py | 2 +- .../session/apispec_base.py | 26 +++++++- .../renku_data_services/session/blueprints.py | 48 ++++++-------- components/renku_data_services/session/db.py | 17 +++-- .../renku_data_services/session/models.py | 2 - components/renku_data_services/session/orm.py | 20 +++--- .../renku_data_services/storage/api.spec.yaml | 26 +++++--- .../renku_data_services/storage/apispec.py | 20 +++--- .../renku_data_services/storage/blueprints.py | 63 +++++++++---------- .../data_api/test_projects.py | 2 +- .../data_api/test_storage.py | 10 +-- .../authz/test_authorization.py | 22 ++++--- 22 files changed, 178 insertions(+), 174 deletions(-) diff --git a/components/renku_data_services/authz/authz.py b/components/renku_data_services/authz/authz.py index 6df3170c6..30f3f996f 100644 --- a/components/renku_data_services/authz/authz.py +++ b/components/renku_data_services/authz/authz.py @@ -586,7 +586,7 @@ async def _get_members_helper( member = Member( user_id=response.relationship.subject.object.object_id, role=member_role, - resource_id=response.relationship.resource.object_id, + resource_id=ULID.from_str(response.relationship.resource.object_id), ) yield member diff --git a/components/renku_data_services/authz/models.py b/components/renku_data_services/authz/models.py index 64399ea97..cfd93536a 100644 --- a/components/renku_data_services/authz/models.py +++ b/components/renku_data_services/authz/models.py @@ -74,7 +74,7 @@ def with_group(self, group_id: ULID) -> "Member": class Member(UnsavedMember): """Member stored in the database.""" - resource_id: str | ULID + resource_id: ULID class Change(Enum): diff --git a/components/renku_data_services/base_api/auth.py b/components/renku_data_services/base_api/auth.py index f468296f4..16b76b09d 100644 --- a/components/renku_data_services/base_api/auth.py +++ b/components/renku_data_services/base_api/auth.py @@ -71,30 +71,6 @@ async def decorated_function(request: Request, *args: _P.args, **kwargs: _P.kwar return decorator -def validate_path_project_id( - f: Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]], -) -> Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]]: - """Decorator for a Sanic handler that validates the project_id path parameter.""" - _path_project_id_regex = re.compile(r"^[A-Za-z0-9]{26}$") - - @wraps(f) - async def decorated_function(request: Request, *args: _P.args, **kwargs: _P.kwargs) -> _T: - project_id = cast(str | None, kwargs.get("project_id")) - if not project_id: - raise errors.ProgrammingError( - message="Could not find 'project_id' in the keyword arguments for the handler in order to validate it." - ) - if not _path_project_id_regex.match(project_id): - raise errors.ValidationError( - message=f"The 'project_id' path parameter {project_id} does not match the required " - f"regex {_path_project_id_regex}" - ) - - return await f(request, *args, **kwargs) - - return decorated_function - - def validate_path_user_id( f: Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]], ) -> Callable[Concatenate[Request, _P], Coroutine[Any, Any, _T]]: diff --git a/components/renku_data_services/crc/apispec.py b/components/renku_data_services/crc/apispec.py index c9d2dd497..199734dd8 100644 --- a/components/renku_data_services/crc/apispec.py +++ b/components/renku_data_services/crc/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-10-18T11:06:20+00:00 +# timestamp: 2024-08-20T07:15:17+00:00 from __future__ import annotations diff --git a/components/renku_data_services/project/api.spec.yaml b/components/renku_data_services/project/api.spec.yaml index a814afc90..d68c93f80 100644 --- a/components/renku_data_services/project/api.spec.yaml +++ b/components/renku_data_services/project/api.spec.yaml @@ -143,7 +143,7 @@ paths: $ref: "#/components/responses/Error" tags: - projects - /projects/{namespace}/{slug}: + /namespaces/{namespace}/projects/{slug}: get: summary: Get a project by namespace and project slug parameters: diff --git a/components/renku_data_services/project/apispec_base.py b/components/renku_data_services/project/apispec_base.py index c888c3ba5..476c07927 100644 --- a/components/renku_data_services/project/apispec_base.py +++ b/components/renku_data_services/project/apispec_base.py @@ -1,6 +1,7 @@ """Base models for API specifications.""" -from pydantic import BaseModel +from pydantic import BaseModel, field_validator +from ulid import ULID class BaseAPISpec(BaseModel): @@ -13,3 +14,9 @@ class Config: # NOTE: By default the pydantic library does not use python for regex but a rust crate # this rust crate does not support lookahead regex syntax but we need it in this component regex_engine = "python-re" + + @field_validator("id", mode="before", check_fields=False) + @classmethod + def serialize_id(cls, id: str | ULID) -> str: + """Custom serializer that can handle ULIDs.""" + return str(id) diff --git a/components/renku_data_services/project/blueprints.py b/components/renku_data_services/project/blueprints.py index 43200cd92..c72c2f8d7 100644 --- a/components/renku_data_services/project/blueprints.py +++ b/components/renku_data_services/project/blueprints.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Any -from sanic import HTTPResponse, Request, json +from sanic import HTTPResponse, Request from sanic.response import JSONResponse from sanic_ext import validate from ulid import ULID @@ -13,7 +13,6 @@ from renku_data_services.base_api.auth import ( authenticate, only_authenticated, - validate_path_project_id, validate_path_user_id, ) from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint @@ -94,7 +93,7 @@ async def _get_one( headers = {"ETag": project.etag} if project.etag is not None else None return validated_json(apispec.Project, self._dump_project(project), headers=headers) - return "/projects/", ["GET"], _get_one + return "/projects/", ["GET"], _get_one def get_one_by_namespace_slug(self) -> BlueprintFactoryResponse: """Get a specific project by namespace/slug.""" @@ -112,34 +111,32 @@ async def _get_one_by_namespace_slug( headers = {"ETag": project.etag} if project.etag is not None else None return validated_json(apispec.Project, self._dump_project(project), headers=headers) - return "/projects//", ["GET"], _get_one_by_namespace_slug + return "/namespaces//projects/", ["GET"], _get_one_by_namespace_slug def delete(self) -> BlueprintFactoryResponse: """Delete a specific project.""" @authenticate(self.authenticator) @only_authenticated - @validate_path_project_id - async def _delete(_: Request, user: base_models.APIUser, project_id: str) -> HTTPResponse: - await self.project_repo.delete_project(user=user, project_id=ULID.from_str(project_id)) + async def _delete(_: Request, user: base_models.APIUser, project_id: ULID) -> HTTPResponse: + await self.project_repo.delete_project(user=user, project_id=project_id) return HTTPResponse(status=204) - return "/projects/", ["DELETE"], _delete + return "/projects/", ["DELETE"], _delete def patch(self) -> BlueprintFactoryResponse: """Partially update a specific project.""" @authenticate(self.authenticator) @only_authenticated - @validate_path_project_id @if_match_required @validate(json=apispec.ProjectPatch) async def _patch( - _: Request, user: base_models.APIUser, project_id: str, body: apispec.ProjectPatch, etag: str + _: Request, user: base_models.APIUser, project_id: ULID, body: apispec.ProjectPatch, etag: str ) -> JSONResponse: project_patch = validate_project_patch(body) project_update = await self.project_repo.update_project( - user=user, project_id=ULID.from_str(project_id), etag=etag, patch=project_patch + user=user, project_id=project_id, etag=etag, patch=project_patch ) if not isinstance(project_update, project_models.ProjectUpdate): @@ -151,15 +148,14 @@ async def _patch( updated_project = project_update.new return validated_json(apispec.Project, self._dump_project(updated_project)) - return "/projects/", ["PATCH"], _patch + return "/projects/", ["PATCH"], _patch def get_all_members(self) -> BlueprintFactoryResponse: """List all project members.""" @authenticate(self.authenticator) - @validate_path_project_id - async def _get_all_members(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse: - members = await self.project_member_repo.get_members(user, ULID.from_str(project_id)) + async def _get_all_members(_: Request, user: base_models.APIUser, project_id: ULID) -> JSONResponse: + members = await self.project_member_repo.get_members(user, project_id) users = [] @@ -179,35 +175,33 @@ async def _get_all_members(_: Request, user: base_models.APIUser, project_id: st ).model_dump(exclude_none=True, mode="json") users.append(user_with_id) - return json(users) + return validated_json(apispec.ProjectMemberListResponse, users) - return "/projects//members", ["GET"], _get_all_members + return "/projects//members", ["GET"], _get_all_members def update_members(self) -> BlueprintFactoryResponse: """Update or add project members.""" @authenticate(self.authenticator) - @validate_path_project_id @validate_body_root_model(json=apispec.ProjectMemberListPatchRequest) async def _update_members( - _: Request, user: base_models.APIUser, project_id: str, body: apispec.ProjectMemberListPatchRequest + _: Request, user: base_models.APIUser, project_id: ULID, body: apispec.ProjectMemberListPatchRequest ) -> HTTPResponse: members = [Member(Role(i.role.value), i.id, project_id) for i in body.root] - await self.project_member_repo.update_members(user, ULID.from_str(project_id), members) + await self.project_member_repo.update_members(user, project_id, members) return HTTPResponse(status=200) - return "/projects//members", ["PATCH"], _update_members + return "/projects//members", ["PATCH"], _update_members def delete_member(self) -> BlueprintFactoryResponse: """Delete a specific project.""" @authenticate(self.authenticator) - @validate_path_project_id @validate_path_user_id async def _delete_member( - _: Request, user: base_models.APIUser, project_id: str, member_id: str + _: Request, user: base_models.APIUser, project_id: ULID, member_id: str ) -> HTTPResponse: - await self.project_member_repo.delete_members(user, ULID.from_str(project_id), [member_id]) + await self.project_member_repo.delete_members(user, project_id, [member_id]) return HTTPResponse(status=204) return "/projects//members/", ["DELETE"], _delete_member diff --git a/components/renku_data_services/project/db.py b/components/renku_data_services/project/db.py index 16f1acaa8..88a72d7cc 100644 --- a/components/renku_data_services/project/db.py +++ b/components/renku_data_services/project/db.py @@ -221,7 +221,6 @@ async def update_project( session: AsyncSession | None = None, ) -> models.ProjectUpdate: """Update a project entry.""" - project_id_str: str = str(project_id) if not session: raise errors.ProgrammingError(message="A database session is required") result = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)) diff --git a/components/renku_data_services/project/orm.py b/components/renku_data_services/project/orm.py index 4b936a403..de6139d9d 100644 --- a/components/renku_data_services/project/orm.py +++ b/components/renku_data_services/project/orm.py @@ -81,7 +81,7 @@ class ProjectRepositoryORM(BaseORM): id: Mapped[int] = mapped_column("id", Integer, Identity(always=True), primary_key=True, default=None, init=False) url: Mapped[str] = mapped_column("url", String(2000)) - project_id: Mapped[Optional[str]] = mapped_column( + project_id: Mapped[Optional[ULID]] = mapped_column( ForeignKey("projects.id", ondelete="CASCADE"), default=None, index=True ) project: Mapped[Optional[ProjectORM]] = relationship(back_populates="repositories", default=None, repr=False) diff --git a/components/renku_data_services/repositories/blueprints.py b/components/renku_data_services/repositories/blueprints.py index 48029b1b1..c07e22c67 100644 --- a/components/renku_data_services/repositories/blueprints.py +++ b/components/renku_data_services/repositories/blueprints.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from urllib.parse import unquote -from sanic import HTTPResponse, Request, json +from sanic import HTTPResponse, Request from sanic.response import JSONResponse import renku_data_services.base_models as base_models @@ -11,6 +11,7 @@ from renku_data_services.base_api.auth import authenticate_2 from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint from renku_data_services.base_api.etag import extract_if_none_match +from renku_data_services.base_models.validation import validated_json from renku_data_services.repositories import apispec from renku_data_services.repositories.apispec_base import RepositoryParams from renku_data_services.repositories.db import GitRepositoriesRepository @@ -53,10 +54,7 @@ async def _get_one_repository( if result.repository_metadata and result.repository_metadata.etag is not None else None ) - return json( - apispec.RepositoryProviderMatch.model_validate(result).model_dump(exclude_none=True, mode="json"), - headers=headers, - ) + return validated_json(apispec.RepositoryProviderMatch, result, headers=headers) return "/repositories/", ["GET"], _get_one_repository diff --git a/components/renku_data_services/secrets/apispec.py b/components/renku_data_services/secrets/apispec.py index 5cc27a959..9238dbede 100644 --- a/components/renku_data_services/secrets/apispec.py +++ b/components/renku_data_services/secrets/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-13T13:29:49+00:00 +# timestamp: 2024-08-20T07:15:21+00:00 from __future__ import annotations diff --git a/components/renku_data_services/session/apispec_base.py b/components/renku_data_services/session/apispec_base.py index d91e73fb9..14eaa0a01 100644 --- a/components/renku_data_services/session/apispec_base.py +++ b/components/renku_data_services/session/apispec_base.py @@ -5,6 +5,8 @@ from pydantic import BaseModel, field_validator from ulid import ULID +from renku_data_services.session import models + class BaseAPISpec(BaseModel): """Base API specification.""" @@ -14,12 +16,34 @@ class Config: from_attributes = True - @field_validator("id", "project_id", mode="before", check_fields=False) + @field_validator("id", mode="before", check_fields=False) @classmethod def serialize_id(cls, id: str | ULID) -> str: """Custom serializer that can handle ULIDs.""" return str(id) + @field_validator("project_id", mode="before", check_fields=False) + @classmethod + def serialize_project_id(cls, project_id: str | ULID) -> str: + """Custom serializer that can handle ULIDs.""" + return str(project_id) + + @field_validator("environment_id", mode="before", check_fields=False) + @classmethod + def serialize_environment_id(cls, environment_id: str | ULID | None) -> str | None: + """Custom serializer that can handle ULIDs.""" + if environment_id is None: + return None + return str(environment_id) + + @field_validator("environment_kind", mode="before", check_fields=False) + @classmethod + def serialize_environment_kind(cls, environment_kind: models.EnvironmentKind | str) -> str: + """Custom serializer that can handle ULIDs.""" + if isinstance(environment_kind, models.EnvironmentKind): + return environment_kind.value + return environment_kind + @field_validator("working_directory", "mount_directory", check_fields=False, mode="before") @classmethod def convert_path_to_string(cls, val: str | PurePosixPath) -> str: diff --git a/components/renku_data_services/session/blueprints.py b/components/renku_data_services/session/blueprints.py index 75fbf25a4..772630aba 100644 --- a/components/renku_data_services/session/blueprints.py +++ b/components/renku_data_services/session/blueprints.py @@ -3,14 +3,15 @@ from dataclasses import dataclass from pathlib import PurePosixPath -from sanic import HTTPResponse, Request, json +from sanic import HTTPResponse, Request from sanic.response import JSONResponse from sanic_ext import validate from ulid import ULID import renku_data_services.base_models as base_models -from renku_data_services.base_api.auth import authenticate, validate_path_project_id +from renku_data_services.base_api.auth import authenticate, only_authenticated from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint +from renku_data_services.base_models.validation import validated_json from renku_data_services.session import apispec, models from renku_data_services.session.db import SessionRepository @@ -27,9 +28,7 @@ def get_all(self) -> BlueprintFactoryResponse: async def _get_all(_: Request) -> JSONResponse: environments = await self.session_repo.get_environments() - return json( - [apispec.Environment.model_validate(e).model_dump(exclude_none=True, mode="json") for e in environments] - ) + return validated_json(apispec.EnvironmentList, environments) return "/environments", ["GET"], _get_all @@ -38,7 +37,7 @@ def get_one(self) -> BlueprintFactoryResponse: async def _get_one(_: Request, environment_id: ULID) -> JSONResponse: environment = await self.session_repo.get_environment(environment_id=environment_id) - return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json")) + return validated_json(apispec.Environment, environment) return "/environments/", ["GET"], _get_one @@ -46,6 +45,7 @@ def post(self) -> BlueprintFactoryResponse: """Create a new session environment.""" @authenticate(self.authenticator) + @only_authenticated @validate(json=apispec.EnvironmentPost) async def _post(_: Request, user: base_models.APIUser, body: apispec.EnvironmentPost) -> JSONResponse: unsaved_environment = models.UnsavedEnvironment( @@ -63,7 +63,7 @@ async def _post(_: Request, user: base_models.APIUser, body: apispec.Environment args=body.args, ) environment = await self.session_repo.insert_environment(user=user, new_environment=unsaved_environment) - return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"), 201) + return validated_json(apispec.Environment, environment, 201) return "/environments", ["POST"], _post @@ -71,6 +71,7 @@ def patch(self) -> BlueprintFactoryResponse: """Partially update a specific session environment.""" @authenticate(self.authenticator) + @only_authenticated @validate(json=apispec.EnvironmentPatch) async def _patch( _: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch @@ -79,7 +80,7 @@ async def _patch( environment = await self.session_repo.update_environment( user=user, environment_id=environment_id, **body_dict ) - return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json")) + return validated_json(apispec.Environment, environment) return "/environments/", ["PATCH"], _patch @@ -87,6 +88,7 @@ def delete(self) -> BlueprintFactoryResponse: """Delete a specific session environment.""" @authenticate(self.authenticator) + @only_authenticated async def _delete(_: Request, user: base_models.APIUser, environment_id: ULID) -> HTTPResponse: await self.session_repo.delete_environment(user=user, environment_id=environment_id) return HTTPResponse(status=204) @@ -107,12 +109,7 @@ def get_all(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse: launchers = await self.session_repo.get_launchers(user=user) - return json( - [ - apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json") - for item in launchers - ] - ) + return validated_json(apispec.SessionLaunchersList, launchers) return "/session_launchers", ["GET"], _get_all @@ -122,7 +119,7 @@ def get_one(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) async def _get_one(_: Request, user: base_models.APIUser, launcher_id: ULID) -> JSONResponse: launcher = await self.session_repo.get_launcher(user=user, launcher_id=launcher_id) - return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json")) + return validated_json(apispec.SessionLauncher, launcher) return "/session_launchers/", ["GET"], _get_one @@ -130,6 +127,7 @@ def post(self) -> BlueprintFactoryResponse: """Create a new session launcher.""" @authenticate(self.authenticator) + @only_authenticated @validate(json=apispec.SessionLauncherPost) async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLauncherPost) -> JSONResponse: environment: str | models.UnsavedEnvironment @@ -158,9 +156,7 @@ async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLaun resource_class_id=body.resource_class_id, ) launcher = await self.session_repo.insert_launcher(user=user, new_launcher=new_launcher) - return json( - apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"), 201 - ) + return validated_json(apispec.SessionLauncher, launcher, 201) return "/session_launchers", ["POST"], _post @@ -168,6 +164,7 @@ def patch(self) -> BlueprintFactoryResponse: """Partially update a specific session launcher.""" @authenticate(self.authenticator) + @only_authenticated @validate(json=apispec.SessionLauncherPatch) async def _patch( _: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch @@ -201,7 +198,7 @@ async def _patch( launcher = await self.session_repo.update_launcher( user=user, launcher_id=launcher_id, new_custom_environment=new_env, session=session, **body_dict ) - return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json")) + return validated_json(apispec.SessionLauncher, launcher) return "/session_launchers/", ["PATCH"], _patch @@ -209,6 +206,7 @@ def delete(self) -> BlueprintFactoryResponse: """Delete a specific session launcher.""" @authenticate(self.authenticator) + @only_authenticated async def _delete(_: Request, user: base_models.APIUser, launcher_id: ULID) -> HTTPResponse: await self.session_repo.delete_launcher(user=user, launcher_id=launcher_id) return HTTPResponse(status=204) @@ -219,14 +217,8 @@ def get_project_launchers(self) -> BlueprintFactoryResponse: """Get all launchers belonging to a project.""" @authenticate(self.authenticator) - @validate_path_project_id - async def _get_launcher(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse: + async def _get_launcher(_: Request, user: base_models.APIUser, project_id: ULID) -> JSONResponse: launchers = await self.session_repo.get_project_launchers(user=user, project_id=project_id) - return json( - [ - apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json") - for item in launchers - ] - ) + return validated_json(apispec.SessionLaunchersList, launchers) - return "/projects//session_launchers", ["GET"], _get_launcher + return "/projects//session_launchers", ["GET"], _get_launcher diff --git a/components/renku_data_services/session/db.py b/components/renku_data_services/session/db.py index 417820e40..b7bd592d0 100644 --- a/components/renku_data_services/session/db.py +++ b/components/renku_data_services/session/db.py @@ -4,7 +4,6 @@ from collections.abc import Callable from contextlib import AbstractAsyncContextManager, nullcontext -from datetime import UTC, datetime from typing import Any from sqlalchemy import select @@ -46,7 +45,7 @@ async def get_environment(self, environment_id: ULID) -> models.Environment: async with self.session_maker() as session: res = await session.scalars( select(schemas.EnvironmentORM) - .where(schemas.EnvironmentORM.id == str(environment_id)) + .where(schemas.EnvironmentORM.id == environment_id) .where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL.value) ) environment = res.one_or_none() @@ -69,7 +68,6 @@ async def __insert_environment( environment = schemas.EnvironmentORM( name=new_environment.name, created_by_id=user.id, - creation_date=datetime.now(UTC), description=new_environment.description, container_image=new_environment.container_image, default_url=new_environment.default_url, @@ -99,6 +97,8 @@ async def insert_environment( async with self.session_maker() as session, session.begin(): env = await self.__insert_environment(user, session, new_environment) + await session.flush() + await session.refresh(env) return env.dump() async def __update_environment( @@ -157,7 +157,7 @@ async def delete_environment(self, user: base_models.APIUser, environment_id: UL async with self.session_maker() as session, session.begin(): res = await session.scalars( select(schemas.EnvironmentORM) - .where(schemas.EnvironmentORM.id == str(environment_id)) + .where(schemas.EnvironmentORM.id == environment_id) .where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL.value) ) environment = res.one_or_none() @@ -182,11 +182,9 @@ async def get_launchers(self, user: base_models.APIUser) -> list[models.SessionL launcher = res.all() return [item.dump() for item in launcher] - async def get_project_launchers(self, user: base_models.APIUser, project_id: str) -> list[models.SessionLauncher]: + async def get_project_launchers(self, user: base_models.APIUser, project_id: ULID) -> list[models.SessionLauncher]: """Get all session launchers in a project from the database.""" - authorized = await self.project_authz.has_permission( - user, ResourceType.project, ULID.from_str(project_id), Scope.READ - ) + authorized = await self.project_authz.has_permission(user, ResourceType.project, project_id, Scope.READ) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id}' does not exist or you do not have access to it." @@ -285,7 +283,6 @@ async def insert_launcher( launcher = schemas.SessionLauncherORM( name=new_launcher.name, created_by_id=user.id, - creation_date=datetime.now(UTC), description=new_launcher.description, project_id=new_launcher.project_id, environment_id=environment_id, @@ -365,6 +362,8 @@ async def update_launcher( env_payload = kwargs.get("environment", {}) await self.__update_launcher_environment(user, launcher, session, new_custom_environment, **env_payload) + await session.flush() + await session.refresh(launcher) return launcher.dump() async def __update_launcher_environment( diff --git a/components/renku_data_services/session/models.py b/components/renku_data_services/session/models.py index 6dcff46c2..22d714e63 100644 --- a/components/renku_data_services/session/models.py +++ b/components/renku_data_services/session/models.py @@ -74,7 +74,6 @@ class Environment(BaseEnvironment): class BaseSessionLauncher: """Session launcher model.""" - id: ULID | None project_id: ULID name: str description: str | None @@ -86,7 +85,6 @@ class BaseSessionLauncher: class UnsavedSessionLauncher(BaseSessionLauncher): """Session launcher model that has not been persisted in the DB.""" - id: ULID | None = None environment: str | UnsavedEnvironment """When a string is passed for the environment it should be the ID of an existing environment.""" diff --git a/components/renku_data_services/session/orm.py b/components/renku_data_services/session/orm.py index 2a7cc855d..01ed3e59f 100644 --- a/components/renku_data_services/session/orm.py +++ b/components/renku_data_services/session/orm.py @@ -3,7 +3,7 @@ from datetime import datetime from pathlib import PurePosixPath -from sqlalchemy import JSON, DateTime, MetaData, String +from sqlalchemy import JSON, DateTime, MetaData, String, func from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column, relationship from sqlalchemy.schema import ForeignKey @@ -38,9 +38,6 @@ class EnvironmentORM(BaseORM): created_by_id: Mapped[str] = mapped_column("created_by_id", String()) """Id of the user who created the session environment.""" - creation_date: Mapped[datetime] = mapped_column("creation_date", DateTime(timezone=True)) - """Creation date and time.""" - description: Mapped[str | None] = mapped_column("description", String(500)) """Human-readable description of the session environment.""" @@ -59,6 +56,11 @@ class EnvironmentORM(BaseORM): args: Mapped[list[str] | None] = mapped_column("args", JSONVariant, nullable=True) command: Mapped[list[str] | None] = mapped_column("command", JSONVariant, nullable=True) + creation_date: Mapped[datetime] = mapped_column( + "creation_date", DateTime(timezone=True), default=func.now(), nullable=False + ) + """Creation date and time.""" + def dump(self) -> models.Environment: """Create a session environment model from the EnvironmentORM.""" return models.Environment( @@ -94,15 +96,17 @@ class SessionLauncherORM(BaseORM): created_by_id: Mapped[str] = mapped_column("created_by_id", String()) """Id of the user who created the session launcher.""" - creation_date: Mapped[datetime] = mapped_column("creation_date", DateTime(timezone=True)) - """Creation date and time.""" - description: Mapped[str | None] = mapped_column("description", String(500)) """Human-readable description of the session launcher.""" project: Mapped[ProjectORM] = relationship(init=False) environment: Mapped[EnvironmentORM] = relationship(init=False, lazy="joined") + creation_date: Mapped[datetime] = mapped_column( + "creation_date", DateTime(timezone=True), default=func.now(), nullable=False + ) + """Creation date and time.""" + project_id: Mapped[ULID] = mapped_column( "project_id", ForeignKey(ProjectORM.id, ondelete="CASCADE"), default=None, index=True ) @@ -130,7 +134,7 @@ def load(cls, launcher: models.SessionLauncher) -> "SessionLauncherORM": created_by_id=launcher.created_by, creation_date=launcher.creation_date, description=launcher.description, - project_id=ULID.from_str(launcher.project_id), + project_id=launcher.project_id, environment_id=launcher.environment.id, resource_class_id=launcher.resource_class_id, ) diff --git a/components/renku_data_services/storage/api.spec.yaml b/components/renku_data_services/storage/api.spec.yaml index 55ec8e9e9..0ff7658a3 100644 --- a/components/renku_data_services/storage/api.spec.yaml +++ b/components/renku_data_services/storage/api.spec.yaml @@ -14,7 +14,7 @@ paths: name: storage_id required: true schema: - $ref: "#/components/schemas/UlidId" + $ref: "#/components/schemas/Ulid" description: the id of the storage get: summary: get cloud storage details @@ -154,7 +154,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/RCloneConfig" + $ref: "#/components/schemas/RCloneConfigValidate" responses: "204": description: The configuration is valid @@ -223,7 +223,7 @@ components: project_id: oneOf: - $ref: "#/components/schemas/GitlabProjectId" - - $ref: "#/components/schemas/UlidId" + - $ref: "#/components/schemas/Ulid" required: - project_id example: @@ -238,6 +238,16 @@ components: nullable: true - type: boolean - type: object + RCloneConfigValidate: #this is the same as RCloneConfig but duplicated so a class gets generated + type: object + description: Dictionary of rclone key:value pairs (based on schema from '/storage_schema') + additionalProperties: + oneOf: + - type: integer + - type: string + nullable: true + - type: boolean + - type: object CloudStorageUrl: allOf: - $ref: "#/components/schemas/ProjectId" @@ -295,7 +305,7 @@ components: project_id: oneOf: - $ref: "#/components/schemas/GitlabProjectId" - - $ref: "#/components/schemas/UlidId" + - $ref: "#/components/schemas/Ulid" storage_type: $ref: "#/components/schemas/StorageType" name: @@ -319,7 +329,7 @@ components: - storage_id properties: storage_id: - $ref: "#/components/schemas/UlidId" + $ref: "#/components/schemas/Ulid" CloudStorageGet: type: object description: Get response for a cloud storage. Contains storage and information about fields that are required if the storage is private. @@ -414,12 +424,12 @@ components: type: string description: data type of option value. RClone has more options but they map to the ones listed here. enum: ["int", "bool", "string", "Time"] - UlidId: - description: ULID identifier of an object + Ulid: + description: ULID identifier type: string minLength: 26 maxLength: 26 - pattern: "^[A-Z0-9]+$" + pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive GitlabProjectId: description: Project id of a gitlab project (only int project id allowed, encoded as string for future-proofing) type: string diff --git a/components/renku_data_services/storage/apispec.py b/components/renku_data_services/storage/apispec.py index b748a586a..0580dda06 100644 --- a/components/renku_data_services/storage/apispec.py +++ b/components/renku_data_services/storage/apispec.py @@ -11,6 +11,12 @@ from renku_data_services.storage.apispec_base import BaseAPISpec +class RCloneConfigValidate( + RootModel[Optional[Dict[str, Union[int, Optional[str], bool, Dict[str, Any]]]]] +): + root: Optional[Dict[str, Union[int, Optional[str], bool, Dict[str, Any]]]] = None + + class Example(BaseAPISpec): value: Optional[str] = Field( None, description="a potential value for the option (think enum)" @@ -70,13 +76,13 @@ class RCloneOption(BaseAPISpec): ) -class UlidId(RootModel[str]): +class Ulid(RootModel[str]): root: str = Field( ..., - description="ULID identifier of an object", + description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]+$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) @@ -131,7 +137,7 @@ class StorageSchemaObscurePostRequest(BaseAPISpec): class ProjectId(BaseAPISpec): - project_id: Union[GitlabProjectId, UlidId] + project_id: Union[GitlabProjectId, Ulid] class CloudStorageUrl(ProjectId): @@ -174,7 +180,7 @@ class CloudStorage(ProjectId): class CloudStoragePatch(BaseAPISpec): - project_id: Optional[Union[GitlabProjectId, UlidId]] = None + project_id: Optional[Union[GitlabProjectId, Ulid]] = None storage_type: Optional[str] = Field( None, description="same as rclone prefix/ rclone config type. Ignored in requests, but returned in responses for convenience.", @@ -206,10 +212,10 @@ class CloudStoragePatch(BaseAPISpec): class CloudStorageWithId(CloudStorage): storage_id: str = Field( ..., - description="ULID identifier of an object", + description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]+$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) diff --git a/components/renku_data_services/storage/blueprints.py b/components/renku_data_services/storage/blueprints.py index 4ae20b1d1..6aa0a46c0 100644 --- a/components/renku_data_services/storage/blueprints.py +++ b/components/renku_data_services/storage/blueprints.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Any -from sanic import HTTPResponse, Request, empty, json +from sanic import HTTPResponse, Request, empty from sanic.response import JSONResponse from sanic_ext import validate from ulid import ULID @@ -13,6 +13,7 @@ from renku_data_services.base_api.auth import authenticate from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint from renku_data_services.base_api.misc import validate_query +from renku_data_services.base_models.validation import validated_json from renku_data_services.storage import apispec, models from renku_data_services.storage.db import StorageRepository from renku_data_services.storage.rclone import RCloneValidator @@ -48,7 +49,9 @@ async def _get( ) -> JSONResponse: storage = await self.storage_repo.get_storage(user=user, project_id=query.project_id) - return json([dump_storage_with_sensitive_fields(s, validator) for s in storage]) + return validated_json( + apispec.StorageGetResponse, [dump_storage_with_sensitive_fields(s, validator) for s in storage] + ) return "/storage", ["GET"], _get @@ -64,7 +67,7 @@ async def _get_one( ) -> JSONResponse: storage = await self.storage_repo.get_storage_by_id(storage_id, user=user) - return json(dump_storage_with_sensitive_fields(storage, validator)) + return validated_json(apispec.CloudStorageGet, dump_storage_with_sensitive_fields(storage, validator)) return "/storage/", ["GET"], _get_one @@ -95,7 +98,7 @@ async def _post(request: Request, user: base_models.APIUser, validator: RCloneVa validator.validate(storage.configuration.model_dump()) res = await self.storage_repo.insert_storage(storage=storage, user=user) - return json(dump_storage_with_sensitive_fields(res, validator), 201) + return validated_json(apispec.CloudStorageGet, dump_storage_with_sensitive_fields(res, validator), 201) return "/storage", ["POST"], _post @@ -129,7 +132,7 @@ async def _put( validator.validate(new_storage.configuration.model_dump()) body_dict = new_storage.model_dump() res = await self.storage_repo.update_storage(storage_id=storage_id, user=user, **body_dict) - return json(dump_storage_with_sensitive_fields(res, validator)) + return validated_json(apispec.CloudStorageGet, dump_storage_with_sensitive_fields(res, validator)) return "/storage/", ["PUT"], _put @@ -159,7 +162,7 @@ async def _patch( body_dict = body.model_dump(exclude_none=True) res = await self.storage_repo.update_storage(storage_id=storage_id, user=user, **body_dict) - return json(dump_storage_with_sensitive_fields(res, validator)) + return validated_json(apispec.CloudStorageGet, dump_storage_with_sensitive_fields(res, validator)) return "/storage/", ["PATCH"], _patch @@ -182,29 +185,19 @@ def get(self) -> BlueprintFactoryResponse: """Get cloud storage for a repository.""" async def _get(_: Request, validator: RCloneValidator) -> JSONResponse: - return json(validator.asdict()) + return validated_json(apispec.RCloneSchema, validator.asdict()) return "/storage_schema", ["GET"], _get def test_connection(self) -> BlueprintFactoryResponse: """Validate an RClone config.""" - async def _test_connection(request: Request, validator: RCloneValidator) -> HTTPResponse: - if not request.json: - raise errors.ValidationError(message="The request body is empty. Please provide a valid JSON object.") - if not isinstance(request.json, dict): - raise errors.ValidationError(message="The request body is not a valid JSON object.") - if not request.json.get("configuration"): - raise errors.ValidationError(message="No 'configuration' sent.") - if not isinstance(request.json.get("configuration"), dict): - config_type = type(request.json.get("configuration")) - raise errors.ValidationError( - message=f"The R clone configuration should be a dictionary, not {config_type.__name__}" - ) - if not request.json.get("source_path"): - raise errors.ValidationError(message="'source_path' is required to test the connection.") - validator.validate(request.json["configuration"], keep_sensitive=True) - result = await validator.test_connection(request.json["configuration"], request.json["source_path"]) + @validate(json=apispec.StorageSchemaTestConnectionPostRequest) + async def _test_connection( + request: Request, validator: RCloneValidator, body: apispec.StorageSchemaTestConnectionPostRequest + ) -> HTTPResponse: + validator.validate(body.configuration, keep_sensitive=True) + result = await validator.test_connection(body.configuration, body.source_path) if not result.success: raise errors.ValidationError(message=result.error) return empty(204) @@ -214,12 +207,13 @@ async def _test_connection(request: Request, validator: RCloneValidator) -> HTTP def validate(self) -> BlueprintFactoryResponse: """Validate an RClone config.""" - async def _validate(request: Request, validator: RCloneValidator) -> HTTPResponse: - if not request.json: + @validate(json=apispec.RCloneConfigValidate) + async def _validate( + request: Request, validator: RCloneValidator, body: apispec.RCloneConfigValidate + ) -> HTTPResponse: + if body.root is None: raise errors.ValidationError(message="The request body is empty. Please provide a valid JSON object.") - if not isinstance(request.json, dict): - raise errors.ValidationError(message="The request body is not a valid JSON object.") - validator.validate(request.json, keep_sensitive=True) + validator.validate(body.root, keep_sensitive=True) return empty(204) return "/storage_schema/validate", ["POST"], _validate @@ -227,12 +221,11 @@ async def _validate(request: Request, validator: RCloneValidator) -> HTTPRespons def obscure(self) -> BlueprintFactoryResponse: """Obscure values in config.""" - async def _obscure(request: Request, validator: RCloneValidator) -> JSONResponse: - if not request.json: - raise errors.ValidationError(message="The request body is empty. Please provide a valid JSON object.") - if not isinstance(request.json, dict): - raise errors.ValidationError(message="The request body is not a valid JSON object.") - config = await validator.obscure_config(request.json) - return json(config) + @validate(json=apispec.StorageSchemaObscurePostRequest) + async def _obscure( + request: Request, validator: RCloneValidator, body: apispec.StorageSchemaObscurePostRequest + ) -> JSONResponse: + config = await validator.obscure_config(body.configuration) + return validated_json(apispec.RCloneConfigValidate, config) return "/storage_schema/obscure", ["POST"], _obscure diff --git a/test/bases/renku_data_services/data_api/test_projects.py b/test/bases/renku_data_services/data_api/test_projects.py index 390dd3f33..32da82357 100644 --- a/test/bases/renku_data_services/data_api/test_projects.py +++ b/test/bases/renku_data_services/data_api/test_projects.py @@ -89,7 +89,7 @@ async def test_project_creation(sanic_client, user_headers, regular_user: UserIn # same as above, but using namespace/slug to retreive the pr _, response = await sanic_client.get( - f"/api/data/projects/{payload['namespace']}/{payload['slug']}", headers=user_headers + f"/api/data/namespaces/{payload['namespace']}/projects/{payload['slug']}", headers=user_headers ) assert response.status_code == 200, response.text diff --git a/test/bases/renku_data_services/data_api/test_storage.py b/test/bases/renku_data_services/data_api/test_storage.py index dcd15e2a5..ae63560e5 100644 --- a/test/bases/renku_data_services/data_api/test_storage.py +++ b/test/bases/renku_data_services/data_api/test_storage.py @@ -500,10 +500,12 @@ async def test_storage_patch_unauthorized(storage_test_client, valid_storage_pay async def test_storage_obscure(storage_test_client) -> None: storage_test_client, _ = storage_test_client body = { - "type": "seafile", - "provider": "Other", - "user": "abcdefg", - "pass": "123456", + "configuration": { + "type": "seafile", + "provider": "Other", + "user": "abcdefg", + "pass": "123456", + } } _, res = await storage_test_client.post("/api/data/storage_schema/obscure", data=json.dumps(body)) assert res.status_code == 200 diff --git a/test/components/renku_data_services/authz/test_authorization.py b/test/components/renku_data_services/authz/test_authorization.py index 830129be4..a70987ba1 100644 --- a/test/components/renku_data_services/authz/test_authorization.py +++ b/test/components/renku_data_services/authz/test_authorization.py @@ -145,8 +145,8 @@ async def test_listing_users_with_access(app_config: Config, public_project: boo project1_id = ULID() project1 = Project( id=project1_id, - name=project1_id, - slug=project1_id, + name=str(project1_id), + slug=str(project1_id), namespace=Namespace( project_owner.id, project_owner.id, @@ -160,8 +160,8 @@ async def test_listing_users_with_access(app_config: Config, public_project: boo project2_id = ULID() project2 = Project( id=project2_id, - name=project2_id, - slug=project2_id, + name=str(project2_id), + slug=str(project2_id), namespace=Namespace( regular_user2.id, regular_user2.id, @@ -190,9 +190,11 @@ async def test_listing_projects_with_access(app_config: Config, bootstrap_admins public_project_id = ULID() private_project_id1 = ULID() private_project_id2 = ULID() + public_project_id_str = str(public_project_id) private_project_id1_str = str(private_project_id1) private_project_id2_str = str(private_project_id2) + project_owner = regular_user1 namespace = Namespace( project_owner.id, @@ -205,24 +207,24 @@ async def test_listing_projects_with_access(app_config: Config, bootstrap_admins assert regular_user2.id public_project = Project( id=public_project_id, - name=public_project_id, - slug=public_project_id, + name=public_project_id_str, + slug=public_project_id_str, namespace=namespace, visibility=Visibility.PUBLIC, created_by=project_owner.id, ) private_project1 = Project( id=private_project_id1, - name=private_project_id1, - slug=private_project_id1, + name=private_project_id1_str, + slug=private_project_id1_str, namespace=namespace, visibility=Visibility.PRIVATE, created_by=project_owner.id, ) private_project2 = Project( id=private_project_id2, - name=private_project_id2, - slug=private_project_id2, + name=private_project_id2_str, + slug=private_project_id2_str, namespace=namespace, visibility=Visibility.PRIVATE, created_by=project_owner.id,