diff --git a/fastapi_security/oauth2.py b/fastapi_security/oauth2.py index 224814f..ae78d35 100644 --- a/fastapi_security/oauth2.py +++ b/fastapi_security/oauth2.py @@ -41,9 +41,10 @@ def __init__(self): def init( self, jwks_url: str, - audiences: Iterable[str], + audiences: Optional[Union[str, Iterable[str]]], *, jwks_cache_period: int = DEFAULT_JWKS_RESPONSE_CACHE_PERIOD, + decode_options: dict = None ): """Set up Oauth 2.0 JWT validation @@ -52,9 +53,11 @@ def init( The JWKS endpoint to fetch the public keys from. Usually in the format: "https://domain/.well-known/jwks.json" audiences: - Accepted `aud` values for incoming access tokens + Accepted `aud` values for incoming access tokens. Could be a list of string, a string or None. jwks_cache_period: How many seconds to cache the JWKS response. Defaults to 1 hour. + decode_options: + Other options for PyJWT's decode function. """ if aiohttp is None: raise MissingDependency( @@ -66,7 +69,8 @@ def init( ) self._jwks_url = jwks_url self._jwks_cache_period = float(jwks_cache_period) - self._audiences = list(audiences) + self._audiences = audiences + self._decode_options = decode_options def is_configured(self) -> bool: return bool(self._jwks_url) @@ -161,4 +165,4 @@ def _decode_jwt_token( self, public_key: _RSAPublicKey, access_token: str ) -> Dict[str, Any]: # NOTE: jwt.decode has erroneously set key: str - return jwt.decode(access_token, key=public_key, audience=self._audiences, algorithms=["RS256"]) # type: ignore + return jwt.decode(access_token, key=public_key, audience=self._audiences, algorithms=["RS256"], **self._decode_options) # type: ignore