Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge with stac-fastapi-pgstac settings and defaults to all extensions #24

Merged
merged 3 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 61 additions & 95 deletions runtimes/eoapi/stac/eoapi/stac/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from fastapi.responses import ORJSONResponse
from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import (
EmptyRequest,
ItemCollectionUri,
create_get_request_model,
create_post_request_model,
Expand All @@ -25,17 +24,14 @@
SearchFilterExtension,
SortExtension,
TokenPaginationExtension,
TransactionExtension,
)
from stac_fastapi.extensions.core.fields import FieldsConformanceClasses
from stac_fastapi.extensions.core.free_text import FreeTextConformanceClasses
from stac_fastapi.extensions.core.query import QueryConformanceClasses
from stac_fastapi.extensions.core.sort import SortConformanceClasses
from stac_fastapi.extensions.third_party import BulkTransactionExtension
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
from stac_fastapi.pgstac.extensions import QueryExtension
from stac_fastapi.pgstac.extensions.filter import FiltersClient
from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient
from stac_fastapi.pgstac.types.search import PgstacSearch
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
Expand All @@ -44,8 +40,9 @@
from starlette.templating import Jinja2Templates
from starlette_cramjam.middleware import CompressionMiddleware

from . import __version__ as eoapi_devseed_version
from .client import PgSTACClient
from .config import ApiSettings
from .config import Settings
from .extension import TiTilerExtension
from .logs import init_logging

Expand All @@ -58,115 +55,78 @@
)
templates = Jinja2Templates(env=jinja2_env)

api_settings = ApiSettings()
settings = Settings()
auth_settings = OpenIdConnectSettings()
settings = api_settings.load_postgres_settings()

enabled_extensions = api_settings.extensions or []

# Logs
init_logging(debug=api_settings.debug)
init_logging(debug=settings.debug)
logger = logging.getLogger(__name__)

# Extensions
# application extensions
application_extensions_map = {
"transaction": TransactionExtension(
client=TransactionsClient(),
settings=settings,
response_class=ORJSONResponse,
),
"bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()),
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the transaction extensions because I assume we will:

  • either add the ingestor api
  • add them back once we have some auth defined

if "titiler" in enabled_extensions and api_settings.titiler_endpoint:
application_extensions_map["titiler"] = TiTilerExtension(
titiler_endpoint=api_settings.titiler_endpoint
application_extensions = []

if settings.titiler_endpoint:
application_extensions.append(
TiTilerExtension(titiler_endpoint=settings.titiler_endpoint)
)

# search extensions
search_extensions_map = {
"query": QueryExtension(),
"sort": SortExtension(),
"fields": FieldsExtension(),
"filter": SearchFilterExtension(client=FiltersClient()),
"pagination": TokenPaginationExtension(),
}
search_extensions = [
QueryExtension(),
SortExtension(),
FieldsExtension(),
SearchFilterExtension(client=FiltersClient()),
TokenPaginationExtension(),
]

# collection_search extensions
cs_extensions_map = {
"query": QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
"sort": SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
"fields": FieldsExtension(
conformance_classes=[FieldsConformanceClasses.COLLECTIONS]
),
"filter": CollectionSearchFilterExtension(client=FiltersClient()),
"free_text": FreeTextExtension(
cs_extensions = [
QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]),
SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]),
FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]),
CollectionSearchFilterExtension(client=FiltersClient()),
FreeTextExtension(
conformance_classes=[FreeTextConformanceClasses.COLLECTIONS],
),
"pagination": OffsetPaginationExtension(),
}
OffsetPaginationExtension(),
]

# item_collection extensions
itm_col_extensions_map = {
"query": QueryExtension(
itm_col_extensions = [
QueryExtension(
conformance_classes=[QueryConformanceClasses.ITEMS],
),
"sort": SortExtension(
SortExtension(
conformance_classes=[SortConformanceClasses.ITEMS],
),
"fields": FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]),
"filter": ItemCollectionFilterExtension(client=FiltersClient()),
"pagination": TokenPaginationExtension(),
}

application_extensions = [
extension
for key, extension in application_extensions_map.items()
if key in enabled_extensions
FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]),
ItemCollectionFilterExtension(client=FiltersClient()),
TokenPaginationExtension(),
]

# Request Models
# /search models
search_extensions = [
extension
for key, extension in search_extensions_map.items()
if key in enabled_extensions
]
post_request_model = create_post_request_model(
search_post_model = create_post_request_model(
search_extensions, base_model=PgstacSearch
)
get_request_model = create_get_request_model(search_extensions)
search_get_model = create_get_request_model(search_extensions)
application_extensions.extend(search_extensions)

# /collections/{collectionId}/items model
items_get_request_model = ItemCollectionUri
itm_col_extensions = [
extension
for key, extension in itm_col_extensions_map.items()
if key in enabled_extensions
]
if itm_col_extensions:
items_get_request_model = create_request_model(
model_name="ItemCollectionUri",
base_model=ItemCollectionUri,
extensions=itm_col_extensions,
request_type="GET",
)
application_extensions.extend(itm_col_extensions)
items_get_model = create_request_model(
model_name="ItemCollectionUri",
base_model=ItemCollectionUri,
extensions=itm_col_extensions,
request_type="GET",
)
application_extensions.extend(itm_col_extensions)

# /collections model
collections_get_request_model = EmptyRequest
if "collection_search" in enabled_extensions:
cs_extensions = [
extension
for key, extension in cs_extensions_map.items()
if key in enabled_extensions
]
collection_search_extension = CollectionSearchExtension.from_extensions(
cs_extensions
)
collections_get_request_model = collection_search_extension.GET
application_extensions.append(collection_search_extension)
collection_search_extension = CollectionSearchExtension.from_extensions(cs_extensions)
collections_get_model = collection_search_extension.GET
application_extensions.append(collection_search_extension)


@asynccontextmanager
Expand All @@ -179,38 +139,44 @@ async def lifespan(app: FastAPI):

# Middlewares
middlewares = [Middleware(CompressionMiddleware)]
if api_settings.cors_origins:
if settings.cors_origins:
middlewares.append(
Middleware(
CORSMiddleware,
allow_origins=api_settings.cors_origins,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=api_settings.cors_methods,
allow_methods=settings.cors_methods,
allow_headers=["*"],
)
)

api = StacApi(
app=FastAPI(
title=api_settings.name,
title=settings.stac_fastapi_title,
description=settings.stac_fastapi_description,
version=eoapi_devseed_version,
lifespan=lifespan,
openapi_url="/api",
docs_url="/api.html",
openapi_url=settings.openapi_url,
docs_url=settings.docs_url,
redoc_url=None,
swagger_ui_init_oauth={
"clientId": auth_settings.client_id,
"usePkceWithAuthorizationCodeGrant": auth_settings.use_pkce,
},
),
title=api_settings.name,
description=api_settings.name,
api_version=eoapi_devseed_version,
settings=settings,
extensions=application_extensions,
client=PgSTACClient(pgstac_search_model=post_request_model),
items_get_request_model=items_get_request_model,
search_get_request_model=get_request_model,
search_post_request_model=post_request_model,
collections_get_request_model=collections_get_request_model,
client=PgSTACClient( # type: ignore
landing_page_id=settings.stac_fastapi_landing_id,
title=settings.stac_fastapi_title,
description=settings.stac_fastapi_description,
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
pgstac_search_model=search_post_model,
),
items_get_request_model=items_get_model,
search_get_request_model=search_get_model,
search_post_request_model=search_post_model,
collections_get_request_model=collections_get_model,
response_class=ORJSONResponse,
middlewares=middlewares,
)
Expand Down
78 changes: 26 additions & 52 deletions runtimes/eoapi/stac/eoapi/stac/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import base64
import json
from typing import List, Optional
from typing import Any, Optional

import boto3
from pydantic import field_validator
from pydantic_settings import BaseSettings
from stac_fastapi.pgstac.config import Settings
from pydantic import model_validator
from stac_fastapi.pgstac import config


def get_secret_dict(secret_name: str):
Expand All @@ -33,59 +32,34 @@ def get_secret_dict(secret_name: str):
return json.loads(base64.b64decode(get_secret_value_response["SecretBinary"]))


class ApiSettings(BaseSettings):
"""API settings"""
class Settings(config.Settings):
"""Extent stac-fastapi-pgstac settings"""

stac_fastapi_title: str = "eoAPI-stac"
stac_fastapi_description: str = "Custom stac-fastapi application for eoAPI-Devseed"
stac_fastapi_landing_id: str = "eoapi-devseed-stac"

name: str = "eoAPI-stac"
cors_origins: str = "*"
cors_methods: str = "GET,POST,OPTIONS"
cachecontrol: str = "public, max-age=3600"
debug: bool = False

pgstac_secret_arn: Optional[str] = None

titiler_endpoint: Optional[str] = None

extensions: List[str] = [
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
"filter",
"query",
"sort",
"fields",
"pagination",
"titiler",
"free_text",
"transaction",
# "bulk_transactions",
"collection_search",
]

@field_validator("cors_origins")
def parse_cors_origin(cls, v):
"""Parse CORS origins."""
return [origin.strip() for origin in v.split(",")]

@field_validator("cors_methods")
def parse_cors_methods(cls, v):
"""Parse CORS methods."""
return [method.strip() for method in v.split(",")]

def load_postgres_settings(self) -> "Settings":
"""Load postgres connection params from AWS secret"""

if self.pgstac_secret_arn:
secret = get_secret_dict(self.pgstac_secret_arn)

return Settings(
postgres_host_reader=secret["host"],
postgres_host_writer=secret["host"],
postgres_dbname=secret["dbname"],
postgres_user=secret["username"],
postgres_pass=secret["password"],
postgres_port=secret["port"],
debug: bool = False

@model_validator(mode="before")
def get_postgres_setting(cls, data: Any) -> Any:
if arn := data.get("pgstac_secret_arn"):
secret = get_secret_dict(arn)
data.update(
{
"postgres_host_reader": secret["host"],
"postgres_host_writer": secret["host"],
"postgres_dbname": secret["dbname"],
"postgres_user": secret["username"],
"postgres_pass": secret["password"],
"postgres_port": secret["port"],
}
)
else:
return Settings()

model_config = {
"env_file": ".env",
"extra": "allow",
}
return data