Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Buildout CQL2 filter tooling #17

Draft
wants to merge 29 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b83885a
Rm CEL
alukach Dec 12, 2024
fe48623
Bring eoapi-auth-utils into this lib, customize to permit optional auth
alukach Dec 12, 2024
ebe494b
Add start to filters
alukach Dec 12, 2024
fa70d68
Update requirements
alukach Dec 12, 2024
cdd4040
Refactor config
alukach Dec 12, 2024
902a8bb
fix: correct auth handler
alukach Dec 12, 2024
29b4806
Lint cleanup
alukach Dec 12, 2024
475f95a
Crude working collections filter
alukach Dec 13, 2024
e5eee66
Passing tests
alukach Dec 13, 2024
8ddb3f6
Lint fixes
alukach Dec 13, 2024
d4757da
Cleanup
alukach Dec 13, 2024
a34c370
Mv from dataclasses to higher order functions
alukach Dec 13, 2024
e621750
Legibility refactor (more higher order functions instead of dataclasses)
alukach Dec 13, 2024
1069f41
Cleanup
alukach Dec 13, 2024
4429fe4
Add stub test
alukach Dec 13, 2024
3e4dd25
fix: correct annotation
alukach Dec 13, 2024
7021043
update test server to avoid conflicts
alukach Dec 13, 2024
166ca41
Add functional test for CQL2 filter
alukach Dec 13, 2024
67f6b3a
Continue test buildout
alukach Dec 15, 2024
2ea5ab9
Reorg utils
alukach Jan 2, 2025
00b6a8a
Initial pass at DI tooling
alukach Jan 2, 2025
08e4f40
Fix import
alukach Jan 3, 2025
acbfa63
Add test config
alukach Jan 3, 2025
f0c3c6a
Reorg tests
alukach Jan 3, 2025
680b32e
combine filter logic
alukach Jan 3, 2025
71e5ae1
update item endpoint check
alukach Jan 3, 2025
39dd77b
test item filter
alukach Jan 3, 2025
b89cfce
Rm maybe auth check for public endpoints
alukach Jan 3, 2025
76e5cb8
Refactor for legibility
alukach Jan 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ STAC Auth Proxy is a proxy API that mediates between the client and and some int

- 🔐 Selectively apply OIDC auth to some or all endpoints & methods
- 📖 Augments [OpenAPI](https://swagger.io/specification/) with auth information, keeping auto-generated docs (e.g. [Swagger UI](https://swagger.io/tools/swagger-ui/)) accurate
- 💂‍♀️ Custom policies enforce complex access controls, defined with [Common Expression Language (CEL)](https://cel.dev/)

## Installation

Expand Down
12 changes: 9 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@ classifiers = [
dependencies = [
"authlib>=1.3.2",
"brotli>=1.1.0",
"cel-python>=0.1.5",
"eoapi-auth-utils>=0.4.0",
"cql2>=0.3.2",
"fastapi>=0.115.5",
"httpx>=0.28.0",
"jinja2>=3.1.4",
"pydantic-settings>=2.6.1",
"pyjwt>=2.10.1",
"uvicorn>=0.32.1",
]
description = "STAC authentication proxy with FastAPI"
keywords = ["STAC", "FastAPI", "Authentication", "Proxy"]
license = {file = "LICENSE"}
name = "stac-auth-proxy"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
version = "0.1.0"

[tool.coverage.run]
Expand All @@ -42,6 +43,11 @@ requires = ["hatchling>=1.12.0"]
dev = [
"jwcrypto>=1.5.6",
"pre-commit>=3.5.0",
"pytest-asyncio>=0.25.1",
"pytest-cov>=5.0.0",
"pytest>=8.3.3",
]

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"
asyncio_mode = "strict"
77 changes: 37 additions & 40 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import logging
from typing import Optional

from eoapi.auth_utils import OpenIdConnectAuth
from fastapi import Depends, FastAPI
from fastapi import FastAPI, Security

from .auth import OpenIdConnectAuth
from .config import Settings
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
from .handlers import ReverseProxyHandler, build_openapi_spec_handler
from .middleware import AddProcessTimeHeaderMiddleware

logger = logging.getLogger(__name__)
Expand All @@ -23,61 +23,58 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
settings = settings or Settings()

app = FastAPI(
openapi_url=None,
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
)
app.add_middleware(AddProcessTimeHeaderMiddleware)

auth_scheme = OpenIdConnectAuth(
openid_configuration_url=str(settings.oidc_discovery_url)
).valid_token_dependency

if settings.guard:
logger.info("Wrapping auth scheme")
auth_scheme = settings.guard(auth_scheme)

if settings.debug:
app.add_api_route(
"/_debug",
lambda: {"settings": settings},
methods=["GET"],
)

proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
openapi_handler = OpenApiSpecHandler(
proxy=proxy_handler, oidc_config_url=str(settings.oidc_discovery_url)
# Tooling
auth_scheme = OpenIdConnectAuth(
openid_configuration_url=settings.oidc_discovery_url
)
proxy_handler = ReverseProxyHandler(
upstream=str(settings.upstream_url),
auth_dependency=auth_scheme.maybe_validated_user,
collections_filter=settings.collections_filter,
items_filter=settings.items_filter,
)
openapi_handler = build_openapi_spec_handler(
proxy=proxy_handler,
oidc_config_url=str(settings.oidc_discovery_url),
)

# Endpoints that are explicitely marked private
for path, methods in settings.private_endpoints.items():
app.add_api_route(
path,
(
proxy_handler.stream
if path != settings.openapi_spec_endpoint
else openapi_handler.dispatch
),
methods=methods,
dependencies=[Depends(auth_scheme)],
)

# Endpoints that are explicitely marked as public
for path, methods in settings.public_endpoints.items():
app.add_api_route(
path,
(
proxy_handler.stream
if path != settings.openapi_spec_endpoint
else openapi_handler.dispatch
),
methods=methods,
)
# Configure security dependency for explicitely specified endpoints
for path_methods, dependencies in [
(settings.private_endpoints, [Security(auth_scheme.validated_user)]),
(settings.public_endpoints, []),
]:
for path, methods in path_methods.items():
endpoint = (
openapi_handler
if path == settings.openapi_spec_endpoint
else proxy_handler.stream
)
app.add_api_route(
path,
endpoint=endpoint,
methods=methods,
dependencies=dependencies,
)

# Catchall for remainder of the endpoints
app.add_api_route(
"/{path:path}",
proxy_handler.stream,
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
dependencies=([] if settings.default_public else [Depends(auth_scheme)]),
dependencies=(
[] if settings.default_public else [Security(auth_scheme.validated_user)]
),
)

return app
130 changes: 130 additions & 0 deletions src/stac_auth_proxy/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""OIDC authentication module for validating JWTs."""

import json
import logging
import urllib.request
from dataclasses import dataclass, field
from typing import Annotated, Optional, Sequence

import jwt
from fastapi import HTTPException, Security, security, status
from fastapi.security.base import SecurityBase
from pydantic import HttpUrl

logger = logging.getLogger(__name__)


@dataclass
class OpenIdConnectAuth:
"""OIDC authentication class to generate auth handlers."""

openid_configuration_url: HttpUrl
openid_configuration_internal_url: Optional[HttpUrl] = None
allowed_jwt_audiences: Optional[Sequence[str]] = None

# Generated attributes
auth_scheme: SecurityBase = field(init=False)
jwks_client: jwt.PyJWKClient = field(init=False)

def __post_init__(self):
"""Initialize the OIDC authentication class."""
logger.debug("Requesting OIDC config")
origin_url = str(
self.openid_configuration_internal_url or self.openid_configuration_url
)
with urllib.request.urlopen(origin_url) as response:
if response.status != 200:
logger.error(
"Received a non-200 response when fetching OIDC config: %s",
response.text,
)
raise OidcFetchError(
f"Request for OIDC config failed with status {response.status}"
)
oidc_config = json.load(response)
self.jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"])

self.auth_scheme = security.OpenIdConnect(
openIdConnectUrl=str(self.openid_configuration_url),
auto_error=False,
)

# Update annotations to support FastAPI's dependency injection
for endpoint in [self.validated_user, self.maybe_validated_user]:
endpoint.__annotations__["auth_header"] = Annotated[
str,
Security(self.auth_scheme),
]

def maybe_validated_user(
self,
auth_header: Annotated[str, Security(...)],
required_scopes: security.SecurityScopes,
):
"""Dependency to validate an OIDC token."""
return self.validated_user(auth_header, required_scopes, auto_error=False)

def validated_user(
self,
auth_header: Annotated[str, Security(...)],
required_scopes: security.SecurityScopes,
auto_error: bool = True,
):
"""Dependency to validate an OIDC token."""
if not auth_header:
if auto_error:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authenticated",
)
return None

# Extract token from header
token_parts = auth_header.split(" ")
if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
logger.error(f"Invalid token: {auth_header}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
[_, token] = token_parts

# Parse & validate token
try:
key = self.jwks_client.get_signing_key_from_jwt(token).key
payload = jwt.decode(
token,
key,
algorithms=["RS256"],
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
audience=self.allowed_jwt_audiences,
)
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
logger.exception(f"InvalidTokenError: {e=}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
) from e

# Validate scopes (if required)
for scope in required_scopes.scopes:
if scope not in payload["scope"]:
if auto_error:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not enough permissions",
headers={
"WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"'
},
)
return None

return payload


class OidcFetchError(Exception):
"""Error fetching OIDC configuration."""

pass
21 changes: 13 additions & 8 deletions src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib
from typing import Optional, Sequence, TypeAlias

from pydantic import BaseModel
from pydantic import BaseModel, Field
from pydantic.networks import HttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict

Expand All @@ -14,15 +14,15 @@ class ClassInput(BaseModel):
"""Input model for dynamically loading a class or function."""

cls: str
args: Optional[Sequence[str]] = []
kwargs: Optional[dict[str, str]] = {}
args: Sequence[str] = Field(default_factory=list)
kwargs: dict[str, str] = Field(default_factory=dict)

def __call__(self, token_dependency):
"""Dynamically load a class and instantiate it with kwargs."""
def __call__(self):
"""Dynamically load a class and instantiate it with args & kwargs."""
module_path, class_name = self.cls.rsplit(".", 1)
module = importlib.import_module(module_path)
cls = getattr(module, class_name)
return cls(*self.args, **self.kwargs, token_dependency=token_dependency)
return cls(*self.args, **self.kwargs)


class Settings(BaseSettings):
Expand All @@ -48,6 +48,11 @@ class Settings(BaseSettings):
public_endpoints: EndpointMethods = {"/api.html": ["GET"], "/api": ["GET"]}
openapi_spec_endpoint: Optional[str] = None

model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_")
collections_filter: Optional[ClassInput] = None
collections_filter_endpoints: Optional[EndpointMethods] = {
"/collections": ["GET"],
"/collections/{collection_id}": ["GET"],
}
items_filter: Optional[ClassInput] = None

guard: Optional[ClassInput] = None
model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_")
5 changes: 5 additions & 0 deletions src/stac_auth_proxy/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""CQL2 filter generators."""

from .template import Template

__all__ = ["Template"]
42 changes: 42 additions & 0 deletions src/stac_auth_proxy/filters/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Generate CQL2 filter expressions via Jinja2 templating."""

from typing import Annotated, Any

from cql2 import Expr
from fastapi import Request
from jinja2 import BaseLoader, Environment

from ..utils.requests import extract_variables


def Template(template_str: str):
"""Generate CQL2 filter expressions via Jinja2 templating."""
env = Environment(loader=BaseLoader).from_string(template_str)

async def dependency(
request: Request,
auth_token: Annotated[dict[str, Any], ...],
) -> Expr:
"""Render a CQL2 filter expression with the request and auth token."""
# TODO: How to handle the case where auth_token is null?
context = {
"req": {
"path": request.url.path,
"method": request.method,
"query_params": dict(request.query_params),
"path_params": extract_variables(request.url.path),
"headers": dict(request.headers),
"body": (
await request.json()
if request.headers.get("content-type") == "application/json"
else (await request.body()).decode()
),
},
"token": auth_token,
}
cql2_str = env.render(**context)
cql2_expr = Expr(cql2_str)
cql2_expr.validate()
return cql2_expr

return dependency
5 changes: 0 additions & 5 deletions src/stac_auth_proxy/guards/__init__.py

This file was deleted.

Loading
Loading