diff --git a/src/prometheus_fastapi_instrumentator/instrumentation.py b/src/prometheus_fastapi_instrumentator/instrumentation.py index fcd4f75..17e51db 100644 --- a/src/prometheus_fastapi_instrumentator/instrumentation.py +++ b/src/prometheus_fastapi_instrumentator/instrumentation.py @@ -3,10 +3,21 @@ import os import re import warnings +from base64 import b64encode from enum import Enum -from typing import Any, Awaitable, Callable, List, Optional, Sequence, Union, cast +from typing import ( + Any, + Awaitable, + Callable, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from prometheus_client import ( CONTENT_TYPE_LATEST, REGISTRY, @@ -227,6 +238,7 @@ def expose( endpoint: str = "/metrics", include_in_schema: bool = True, tags: Optional[List[Union[str, Enum]]] = None, + basic_auth: Optional[Tuple[str, str]] = None, **kwargs: Any, ) -> "PrometheusFastApiInstrumentator": """Exposes endpoint for metrics. @@ -247,6 +259,9 @@ def expose( tags (List[str], optional): If you manage your routes with tags. Defaults to None. + basic_auth (Tuple[str, str], optional): username and password for + HTTP basic authentication. Disabled if None. + kwargs: Will be passed to FastAPI route annotation. Returns: @@ -256,10 +271,27 @@ def expose( if self.should_respect_env_var and not self._should_instrumentate(): return self + authorization_value = None + if basic_auth is not None: + username, password = basic_auth + encoded_cred = b64encode(f"{username}:{password}".encode("utf-8")).decode( + "ascii" + ) + authorization_value = f"Basic {encoded_cred}" + @app.get(endpoint, include_in_schema=include_in_schema, tags=tags, **kwargs) def metrics(request: Request) -> Response: """Endpoint that serves Prometheus metrics.""" + authorization_header = request.headers.get("authorization", None) + if authorization_header != authorization_value: + raise HTTPException( + status_code=401, + headers={ + "WWW-Authenticate": 'Basic realm="Access to metrics endpoint"' + }, + ) + ephemeral_registry = self.registry if "PROMETHEUS_MULTIPROC_DIR" in os.environ: ephemeral_registry = CollectorRegistry() diff --git a/tests/test_instrumentator_expose.py b/tests/test_instrumentator_expose.py index b75f6e6..34d78f4 100644 --- a/tests/test_instrumentator_expose.py +++ b/tests/test_instrumentator_expose.py @@ -1,6 +1,7 @@ from fastapi import FastAPI from prometheus_client import REGISTRY from requests import Response as TestClientResponse +from requests.auth import HTTPBasicAuth from starlette.testclient import TestClient from prometheus_fastapi_instrumentator import Instrumentator @@ -76,3 +77,20 @@ def test_expose_custom_path(): response = get_response(client, "/custom_metrics") assert response.status_code == 200 assert b"http_request" in response.content + + +def test_expose_basic_auth(): + username = "hello" + password = "mom" + app = create_app() + Instrumentator().instrument(app).expose(app, basic_auth=(username, password)) + client = TestClient(app) + + response = client.get("/metrics") + assert response.status_code == 401 + assert b"http_request" not in response.content + + auth = HTTPBasicAuth(username, password) + response = client.get("/metrics", auth=auth) + assert response.status_code == 200 + assert b"http_request" in response.content