Skip to content

Commit

Permalink
refactor: add validation to namespaces blueprint
Browse files Browse the repository at this point in the history
  • Loading branch information
Panaetius committed Aug 7, 2024
1 parent 2ed2828 commit fc708bd
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 40 deletions.
26 changes: 18 additions & 8 deletions components/renku_data_services/base_models/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@
from renku_data_services import errors


def validate_and_dump(
model: type[BaseModel],
data: Any,
exclude_none: bool = True,
) -> Any:
"""Validate and dump with a pydantic model, ensuring proper validation errors."""
try:
body = model.model_validate(data).model_dump(exclude_none=exclude_none, mode="json")
except PydanticValidationError as err:
parts = [".".join(str(i) for i in field["loc"]) + ": " + field["msg"] for field in err.errors()]
message = (
f"The server could not construct a valid response. Errors found in the following fields: {', '.join(parts)}"
)
raise errors.ProgrammingError(message=message) from err
return body


def validated_json(
model: type[BaseModel],
data: Any,
Expand All @@ -25,12 +42,5 @@ def validated_json(
If the input data fails validation, an HTTP status code 500 will be raised.
"""
try:
body = model.model_validate(data).model_dump(exclude_none=exclude_none, mode="json")
except PydanticValidationError as err:
parts = [".".join(str(i) for i in field["loc"]) + ": " + field["msg"] for field in err.errors()]
message = (
f"The server could not construct a valid response. Errors found in the following fields: {', '.join(parts)}"
)
raise errors.ProgrammingError(message=message) from err
body = validate_and_dump(model, data, exclude_none)
return json(body, status=status, headers=headers, content_type=content_type, dumps=dumps, **kwargs)
69 changes: 37 additions & 32 deletions components/renku_data_services/namespace/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass

from sanic import HTTPResponse, Request, json
from sanic import HTTPResponse, Request
from sanic.response import JSONResponse
from sanic_ext import validate

Expand All @@ -12,6 +12,7 @@
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.misc import validate_query
from renku_data_services.base_api.pagination import PaginationRequest, paginate
from renku_data_services.base_models.validation import validate_and_dump, validated_json
from renku_data_services.errors import errors
from renku_data_services.namespace import apispec
from renku_data_services.namespace.db import GroupRepository
Expand All @@ -35,7 +36,7 @@ async def _get_all(
) -> tuple[list[dict], int]:
groups, rec_count = await self.group_repo.get_groups(user=user, pagination=pagination)
return (
[apispec.GroupResponse.model_validate(g).model_dump(exclude_none=True, mode="json") for g in groups],
validate_and_dump(apispec.GroupResponseList, groups),
rec_count,
)

Expand All @@ -49,7 +50,7 @@ def post(self) -> BlueprintFactoryResponse:
@validate(json=apispec.GroupPostRequest)
async def _post(_: Request, user: base_models.APIUser, body: apispec.GroupPostRequest) -> JSONResponse:
result = await self.group_repo.insert_group(user=user, payload=body)
return json(apispec.GroupResponse.model_validate(result).model_dump(exclude_none=True, mode="json"), 201)
return validated_json(apispec.GroupResponse, result, 201)

return "/groups", ["POST"], _post

Expand All @@ -59,7 +60,7 @@ def get_one(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_one(_: Request, user: base_models.APIUser, slug: str) -> JSONResponse:
result = await self.group_repo.get_group(user=user, slug=slug)
return json(apispec.GroupResponse.model_validate(result).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.GroupResponse, result)

return "/groups/<slug:renku_slug>", ["GET"], _get_one

Expand All @@ -85,7 +86,7 @@ async def _patch(
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True)
res = await self.group_repo.update_group(user=user, slug=slug, payload=body_dict)
return json(apispec.GroupResponse.model_validate(res).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.GroupResponse, res)

return "/groups/<slug:renku_slug>", ["PATCH"], _patch

Expand All @@ -95,17 +96,18 @@ def get_all_members(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_all_members(_: Request, user: base_models.APIUser, slug: str) -> JSONResponse:
members = await self.group_repo.get_group_members(user, slug)
return json(
return validated_json(
apispec.GroupMemberResponseList,
[
apispec.GroupMemberResponse(
dict(
id=m.id,
email=m.email,
first_name=m.first_name,
last_name=m.last_name,
role=apispec.GroupRole(m.role.value),
).model_dump(exclude_none=True, mode="json")
)
for m in members
]
],
)

return "/groups/<slug:renku_slug>/members", ["GET"], _get_all_members
Expand All @@ -115,25 +117,24 @@ def update_members(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_authenticated
async def _update_members(
request: Request,
user: base_models.APIUser,
slug: str,
) -> JSONResponse:
async def _update_members(request: Request, user: base_models.APIUser, slug: str) -> JSONResponse:
# TODO: sanic validation does not support validating top-level json lists, switch this to @validate
# once sanic-org/sanic-ext/issues/198 is fixed
body_validated = apispec.GroupMemberPatchRequestList.model_validate(request.json)
res = await self.group_repo.update_group_members(
user=user,
slug=slug,
payload=body_validated,
)
return json(
return validated_json(
apispec.GroupMemberPatchRequestList,
[
apispec.GroupMemberPatchRequest(
dict(
id=m.member.user_id,
role=apispec.GroupRole(m.member.role.value),
).model_dump(exclude_none=True, mode="json")
)
for m in res
]
],
)

return "/groups/<slug:renku_slug>/members", ["PATCH"], _update_members
Expand Down Expand Up @@ -164,17 +165,20 @@ async def _get_namespaces(
nss, total_count = await self.group_repo.get_namespaces(
user=user, pagination=pagination, minimum_role=minimum_role
)
return [
apispec.NamespaceResponse(
id=ns.id,
name=ns.name,
slug=ns.latest_slug if ns.latest_slug else ns.slug,
created_by=ns.created_by,
creation_date=None, # NOTE: we do not save creation date in the DB
namespace_kind=apispec.NamespaceKind(ns.kind.value),
).model_dump(exclude_none=True, mode="json")
for ns in nss
], total_count
return validate_and_dump(
apispec.NamespaceResponseList,
[
dict(
id=ns.id,
name=ns.name,
slug=ns.latest_slug if ns.latest_slug else ns.slug,
created_by=ns.created_by,
creation_date=None, # NOTE: we do not save creation date in the DB
namespace_kind=apispec.NamespaceKind(ns.kind.value),
)
for ns in nss
],
), total_count

return "/namespaces", ["GET"], _get_namespaces

Expand All @@ -186,15 +190,16 @@ async def _get_namespace(_: Request, user: base_models.APIUser, slug: str) -> JS
ns = await self.group_repo.get_namespace_by_slug(user=user, slug=slug)
if not ns:
raise errors.MissingResourceError(message=f"The namespace with slug {slug} does not exist")
return json(
apispec.NamespaceResponse(
return validated_json(
apispec.NamespaceResponse,
dict(
id=ns.id,
name=ns.name,
slug=ns.latest_slug if ns.latest_slug else ns.slug,
created_by=ns.created_by,
creation_date=None, # NOTE: we do not save creation date in the DB
namespace_kind=apispec.NamespaceKind(ns.kind.value),
).model_dump(exclude_none=True, mode="json")
),
)

return "/namespaces/<slug:renku_slug>", ["GET"], _get_namespace

0 comments on commit fc708bd

Please sign in to comment.