Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Issue #69] Remove the BASE_RESPONSE_SCHEMA #70

Merged
merged 4 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
656 changes: 180 additions & 476 deletions api/openapi.generated.yml

Large diffs are not rendered by default.

28 changes: 17 additions & 11 deletions api/src/api/healthcheck.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
import logging
from typing import Tuple

from apiflask import APIBlueprint
from flask import current_app
from sqlalchemy import text
from werkzeug.exceptions import ServiceUnavailable

import src.adapters.db as db
import src.adapters.db.flask_db as flask_db
from src.api import response
from src.api.schemas.extension import Schema, fields
from src.api.route_utils import raise_flask_error
from src.api.schemas.extension import fields
from src.api.schemas.response_schema import AbstractResponseSchema

logger = logging.getLogger(__name__)


class HealthcheckSchema(Schema):
message = fields.String()
class HealthcheckResponseSchema(AbstractResponseSchema):
# We don't have any data to return with the healthcheck endpoint
data = fields.MixinField(metadata={"example": None})


healthcheck_blueprint = APIBlueprint("healthcheck", __name__, tag="Health")


@healthcheck_blueprint.get("/health")
@healthcheck_blueprint.output(HealthcheckSchema)
@healthcheck_blueprint.output(HealthcheckResponseSchema)
@healthcheck_blueprint.doc(responses=[200, ServiceUnavailable.code])
def health() -> Tuple[response.ApiResponse, int]:
@flask_db.with_db_session()
def health(db_session: db.Session) -> response.ApiResponse:
try:
with flask_db.get_db(current_app).get_connection() as conn:
assert conn.scalar(text("SELECT 1 AS healthy")) == 1
return response.ApiResponse(message="Service healthy"), 200
with db_session.begin():
if db_session.scalar(text("SELECT 1 AS healthy")) != 1:
raise Exception("Connection to Postgres DB failure")

except Exception:
logger.exception("Connection to DB failure")
return response.ApiResponse(message="Service unavailable"), ServiceUnavailable.code
raise_flask_error(ServiceUnavailable.code, message="Service Unavailable")

return response.ApiResponse(message="Service healthy")
4 changes: 2 additions & 2 deletions api/src/api/opportunities_v0/opportunity_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
arg_name="feature_flag_config",
)
# many=True allows us to return a list of opportunity objects
@opportunity_blueprint.output(opportunity_schemas.OpportunityV0Schema(many=True))
@opportunity_blueprint.output(opportunity_schemas.OpportunitySearchResponseV0Schema)
@opportunity_blueprint.auth_required(api_key_auth)
@opportunity_blueprint.doc(description=SHARED_ALPHA_DESCRIPTION)
@flask_db.with_db_session()
Expand Down Expand Up @@ -65,7 +65,7 @@ def opportunity_search(


@opportunity_blueprint.get("/opportunities/<int:opportunity_id>")
@opportunity_blueprint.output(opportunity_schemas.OpportunityV0Schema)
@opportunity_blueprint.output(opportunity_schemas.OpportunityGetResponseV0Schema)
@opportunity_blueprint.auth_required(api_key_auth)
@opportunity_blueprint.doc(description=SHARED_ALPHA_DESCRIPTION)
@flask_db.with_db_session()
Expand Down
9 changes: 9 additions & 0 deletions api/src/api/opportunities_v0/opportunity_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from src.api.feature_flags.feature_flag import FeatureFlag
from src.api.feature_flags.feature_flag_config import FeatureFlagConfig, get_feature_flag_config
from src.api.schemas.extension import Schema, fields
from src.api.schemas.response_schema import AbstractResponseSchema, PaginationMixinSchema
from src.constants.lookup_constants import OpportunityCategoryLegacy
from src.pagination.pagination_schema import PaginationSchema, generate_sorting_schema

Expand Down Expand Up @@ -113,3 +114,11 @@ def post_load(self, data: dict, **kwargs: Any) -> FeatureFlagConfig:
feature_flag_config.enable_opportunity_log_msg = enable_opportunity_log_msg

return feature_flag_config


class OpportunityGetResponseV0Schema(AbstractResponseSchema):
data = fields.Nested(OpportunityV0Schema())


class OpportunitySearchResponseV0Schema(AbstractResponseSchema, PaginationMixinSchema):
data = fields.Nested(OpportunityV0Schema(many=True))
4 changes: 2 additions & 2 deletions api/src/api/opportunities_v0_1/opportunity_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
examples=examples,
)
# many=True allows us to return a list of opportunity objects
@opportunity_blueprint.output(opportunity_schemas.OpportunityV01Schema(many=True))
@opportunity_blueprint.output(opportunity_schemas.OpportunitySearchResponseV01Schema)
@opportunity_blueprint.auth_required(api_key_auth)
@opportunity_blueprint.doc(description=SHARED_ALPHA_DESCRIPTION)
@flask_db.with_db_session()
Expand All @@ -92,7 +92,7 @@ def opportunity_search(db_session: db.Session, search_params: dict) -> response.


@opportunity_blueprint.get("/opportunities/<int:opportunity_id>")
@opportunity_blueprint.output(opportunity_schemas.OpportunityV01Schema)
@opportunity_blueprint.output(opportunity_schemas.OpportunityGetResponseV01Schema)
@opportunity_blueprint.auth_required(api_key_auth)
@opportunity_blueprint.doc(description=SHARED_ALPHA_DESCRIPTION)
@flask_db.with_db_session()
Expand Down
9 changes: 9 additions & 0 deletions api/src/api/opportunities_v0_1/opportunity_schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from src.api.schemas.extension import Schema, fields, validators
from src.api.schemas.response_schema import AbstractResponseSchema, PaginationMixinSchema
from src.api.schemas.search_schema import StrSearchSchemaBuilder
from src.constants.lookup_constants import (
ApplicantType,
Expand Down Expand Up @@ -296,3 +297,11 @@ class OpportunitySearchRequestV01Schema(Schema):
),
required=True,
)


class OpportunityGetResponseV01Schema(AbstractResponseSchema):
data = fields.Nested(OpportunityV01Schema())


class OpportunitySearchResponseV01Schema(AbstractResponseSchema, PaginationMixinSchema):
data = fields.Nested(OpportunityV01Schema(many=True))
4 changes: 2 additions & 2 deletions api/src/api/opportunities_v1/opportunity_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
opportunity_schemas.OpportunitySearchRequestV1Schema, arg_name="search_params"
)
# many=True allows us to return a list of opportunity objects
@opportunity_blueprint.output(opportunity_schemas.OpportunityV1Schema(many=True))
@opportunity_blueprint.output(opportunity_schemas.OpportunitySearchResponseV1Schema)
@opportunity_blueprint.auth_required(api_key_auth)
@opportunity_blueprint.doc(description=SHARED_ALPHA_DESCRIPTION)
def opportunity_search(search_params: dict) -> response.ApiResponse:
Expand All @@ -53,7 +53,7 @@ def opportunity_search(search_params: dict) -> response.ApiResponse:


@opportunity_blueprint.get("/opportunities/<int:opportunity_id>")
@opportunity_blueprint.output(opportunity_schemas.OpportunityV1Schema)
@opportunity_blueprint.output(opportunity_schemas.OpportunityGetResponseV1Schema)
@opportunity_blueprint.auth_required(api_key_auth)
@opportunity_blueprint.doc(description=SHARED_ALPHA_DESCRIPTION)
@flask_db.with_db_session()
Expand Down
9 changes: 9 additions & 0 deletions api/src/api/opportunities_v1/opportunity_schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from src.api.schemas.extension import Schema, fields, validators
from src.api.schemas.response_schema import AbstractResponseSchema, PaginationMixinSchema
from src.api.schemas.search_schema import StrSearchSchemaBuilder
from src.constants.lookup_constants import (
ApplicantType,
Expand Down Expand Up @@ -296,3 +297,11 @@ class OpportunitySearchRequestV1Schema(Schema):
),
required=True,
)


class OpportunityGetResponseV1Schema(AbstractResponseSchema):
data = fields.Nested(OpportunityV1Schema())


class OpportunitySearchResponseV1Schema(AbstractResponseSchema, PaginationMixinSchema):
data = fields.Nested(OpportunityV1Schema(many=True))
47 changes: 37 additions & 10 deletions api/src/api/schemas/response_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,52 @@


class ValidationIssueSchema(Schema):
type = fields.String(metadata={"description": "The type of error"})
message = fields.String(metadata={"description": "The message to return"})
field = fields.String(metadata={"description": "The field that failed"})
type = fields.String(metadata={"description": "The type of error", "example": "invalid"})
message = fields.String(
metadata={"description": "The message to return", "example": "Not a valid string."}
)
field = fields.String(
metadata={"description": "The field that failed", "example": "summary.summary_description"}
)


class BaseResponseSchema(Schema):
message = fields.String(metadata={"description": "The message to return"})
class AbstractResponseSchema(Schema):
message = fields.String(metadata={"description": "The message to return", "example": "Success"})
data = fields.MixinField(metadata={"description": "The REST resource object"}, dump_default={})
status_code = fields.Integer(metadata={"description": "The HTTP status code"}, dump_default=200)
status_code = fields.Integer(
metadata={"description": "The HTTP status code", "example": 200}, dump_default=200
)


class ErrorResponseSchema(BaseResponseSchema):
errors = fields.List(fields.Nested(ValidationIssueSchema()), dump_default=[])
class WarningMixinSchema(Schema):
warnings = fields.List(
fields.Nested(ValidationIssueSchema()),
metadata={
"description": "A list of warnings - indicating something you may want to be aware of, but did not prevent handling of the request"
},
dump_default=[],
)


class ResponseSchema(BaseResponseSchema):
class PaginationMixinSchema(Schema):
pagination_info = fields.Nested(
PaginationInfoSchema(),
metadata={"description": "The pagination information for paginated endpoints"},
)

warnings = fields.List(fields.Nested(ValidationIssueSchema()), dump_default=[])

class ErrorResponseSchema(Schema):
data = fields.MixinField(
metadata={
"description": "Additional data that might be useful in resolving an error (see specific endpoints for details, this is used infrequently)",
"example": {},
},
dump_default={},
)
message = fields.String(
metadata={"description": "General description of the error", "example": "Error"}
)
status_code = fields.Integer(metadata={"description": "The HTTP status code of the error"})
errors = fields.List(
fields.Nested(ValidationIssueSchema()), metadata={"example": []}, dump_default=[]
)
2 changes: 1 addition & 1 deletion api/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def configure_app(app: APIFlask) -> None:
# Modify the response schema to instead use the format of our ApiResponse class
# which adds additional details to the object.
# https://apiflask.com/schema/#base-response-schema-customization
app.config["BASE_RESPONSE_SCHEMA"] = response_schema.ResponseSchema
# app.config["BASE_RESPONSE_SCHEMA"] = response_schema.ResponseSchema
app.config["HTTP_ERROR_SCHEMA"] = response_schema.ErrorResponseSchema
app.config["VALIDATION_ERROR_SCHEMA"] = response_schema.ErrorResponseSchema
app.config["SWAGGER_UI_CSS"] = "/static/swagger-ui.min.css"
Expand Down
6 changes: 4 additions & 2 deletions api/tests/src/api/test_healthcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
def test_get_healthcheck_200(client):
response = client.get("/health")
assert response.status_code == 200
assert response.get_json()["message"] == "Service healthy"


def test_get_healthcheck_503_db_bad_state(client, monkeypatch):
# Make fetching the DB session fail
def err_method(*args):
raise Exception("Fake Error")

# Mock db.DB.get_session method to fail
monkeypatch.setattr(db.DBClient, "get_connection", err_method)
# Mock db_session.Scalar to fail
monkeypatch.setattr(db.Session, "scalar", err_method)

response = client.get("/health")
assert response.status_code == 503
assert response.get_json()["message"] == "Service Unavailable"
8 changes: 6 additions & 2 deletions api/tests/src/api/test_route_error_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from src.api.response import ApiResponse, ValidationErrorDetail
from src.api.route_utils import raise_flask_error
from src.api.schemas.extension import Schema, fields
from src.api.schemas.response_schema import AbstractResponseSchema, WarningMixinSchema
from src.auth.api_key_auth import api_key_auth
from src.util.dict_util import flatten_dict
from tests.src.schemas.schema_validation_utils import (
Expand All @@ -35,10 +36,14 @@ def header(api_auth_token):
return {"X-Auth": api_auth_token}


class OutputSchema(Schema):
class OutputData(Schema):
output_val = fields.String()


class OutputSchema(AbstractResponseSchema, WarningMixinSchema):
data = fields.Nested(OutputData())


test_blueprint = APIBlueprint("test", __name__, tag="test")


Expand Down Expand Up @@ -111,7 +116,6 @@ def override(self):

assert resp.status_code == 500
resp_json = resp.get_json()
assert resp_json["data"] == {}
assert resp_json["errors"] == []
assert resp_json["message"] == "Internal Server Error"

Expand Down
10 changes: 7 additions & 3 deletions documentation/api/api-details.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ We would define the Marshmallow schema in-python like so:
```py
from enum import StrEnum
from src.api.schemas.extension import Schema, fields, validators
from src.api.schemas.response_schema import AbstractResponseSchema

class Suffix(StrEnum):
SENIOR = "SR"
Expand All @@ -65,6 +66,10 @@ class NameSchema(Schema):
class ExampleSchema(Schema):
name = fields.Nested(NameSchema())
birth_date = fields.Date(metadata={"description": "Their birth date"})

class ExampleResponseSchema(AbstractResponseSchema):
# Note that AbstractResponseSchema defines a message and status_code field as well
data = fields.Nested(ExampleSchema())
```

Anything specified in the metadata field is passed to the OpenAPI file that backs the swagger endpoint. The values
Expand All @@ -77,9 +82,8 @@ but it's recommended you try to populate the following:
You can specify validators that will be run when the request is being serialized by APIFlask

Defining a response works the exact same way however field validation does not occur on response, only formatting.
The response schema only dictates the data portion of the response, the rest of the response is defined in
[ResponseSchema](../../api/src/api/schemas/response_schema.py) which is connected to APIFlask via the `BASE_RESPONSE_SCHEMA` config.

To keep our response schema following a consistent pattern, we have a few base schema classes like [AbstractResponseSchema](../../api/src/api/schemas/response_schema.py)
that you can derive from for shared values like the message.

### Schema tips

Expand Down