Skip to content

Commit

Permalink
refactor: add validation to connected services
Browse files Browse the repository at this point in the history
  • Loading branch information
Panaetius committed Aug 7, 2024
1 parent 3f004bf commit 4a50714
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions components/renku_data_services/connected_services/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import renku_data_services.base_models as base_models
from renku_data_services.base_api.auth import authenticate, only_admins, only_authenticated
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_models.validation import validated_json
from renku_data_services.connected_services import apispec
from renku_data_services.connected_services.apispec_base import AuthorizeParams, CallbackParams
from renku_data_services.connected_services.db import ConnectedServicesRepository
Expand All @@ -29,9 +30,7 @@ def get_all(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse:
clients = await self.connected_services_repo.get_oauth2_clients(user=user)
return json(
[apispec.Provider.model_validate(c).model_dump(exclude_none=True, mode="json") for c in clients]
)
return validated_json(apispec.ProviderList, clients)

return "/oauth2/providers", ["GET"], _get_all

Expand All @@ -41,7 +40,7 @@ def get_one(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_one(_: Request, user: base_models.APIUser, provider_id: str) -> JSONResponse:
client = await self.connected_services_repo.get_oauth2_client(provider_id=provider_id, user=user)
return json(apispec.Provider.model_validate(client).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Provider, client)

return "/oauth2/providers/<provider_id>", ["GET"], _get_one

Expand All @@ -53,7 +52,7 @@ def post(self) -> BlueprintFactoryResponse:
@validate(json=apispec.ProviderPost)
async def _post(_: Request, user: base_models.APIUser, body: apispec.ProviderPost) -> JSONResponse:
client = await self.connected_services_repo.insert_oauth2_client(user=user, new_client=body)
return json(apispec.Provider.model_validate(client).model_dump(exclude_none=True, mode="json"), 201)
return validated_json(apispec.Provider, client, 201)

return "/oauth2/providers", ["POST"], _post

Expand All @@ -70,7 +69,7 @@ async def _patch(
client = await self.connected_services_repo.update_oauth2_client(
user=user, provider_id=provider_id, **body_dict
)
return json(apispec.Provider.model_validate(client).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Provider, client)

return "/oauth2/providers/<provider_id>", ["PATCH"], _patch

Expand Down Expand Up @@ -142,9 +141,7 @@ def get_all(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse:
connections = await self.connected_services_repo.get_oauth2_connections(user=user)
return json(
[apispec.Connection.model_validate(c).model_dump(exclude_none=True, mode="json") for c in connections]
)
return validated_json(apispec.ConnectionList, connections)

return "/oauth2/connections", ["GET"], _get_all

Expand All @@ -156,7 +153,7 @@ async def _get_one(_: Request, user: base_models.APIUser, connection_id: str) ->
connection = await self.connected_services_repo.get_oauth2_connection(
connection_id=connection_id, user=user
)
return json(apispec.Connection.model_validate(connection).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Connection, connection)

return "/oauth2/connections/<connection_id>", ["GET"], _get_one

Expand All @@ -168,7 +165,7 @@ async def _get_account(_: Request, user: base_models.APIUser, connection_id: str
account = await self.connected_services_repo.get_oauth2_connected_account(
connection_id=connection_id, user=user
)
return json(apispec.ConnectedAccount.model_validate(account).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.ConnectedAccount, account)

return "/oauth2/connections/<connection_id>/account", ["GET"], _get_account

Expand Down

0 comments on commit 4a50714

Please sign in to comment.