diff --git a/fastapi_security/oauth2.py b/fastapi_security/oauth2.py
index 224814f..ae78d35 100644
--- a/fastapi_security/oauth2.py
+++ b/fastapi_security/oauth2.py
@@ -41,9 +41,10 @@ def __init__(self):
     def init(
         self,
         jwks_url: str,
-        audiences: Iterable[str],
+        audiences: Optional[Union[str, Iterable[str]]],
         *,
         jwks_cache_period: int = DEFAULT_JWKS_RESPONSE_CACHE_PERIOD,
+        decode_options: dict = None
     ):
         """Set up Oauth 2.0 JWT validation
 
@@ -52,9 +53,11 @@ def init(
                 The JWKS endpoint to fetch the public keys from. Usually in the
                 format: "https://domain/.well-known/jwks.json"
             audiences:
-                Accepted `aud` values for incoming access tokens
+                Accepted `aud` values for incoming access tokens. Could be a list of string, a string or None.
             jwks_cache_period:
                 How many seconds to cache the JWKS response. Defaults to 1 hour.
+            decode_options:
+                Other options for PyJWT's decode function.
         """
         if aiohttp is None:
             raise MissingDependency(
@@ -66,7 +69,8 @@ def init(
             )
         self._jwks_url = jwks_url
         self._jwks_cache_period = float(jwks_cache_period)
-        self._audiences = list(audiences)
+        self._audiences = audiences
+        self._decode_options = decode_options
 
     def is_configured(self) -> bool:
         return bool(self._jwks_url)
@@ -161,4 +165,4 @@ def _decode_jwt_token(
         self, public_key: _RSAPublicKey, access_token: str
     ) -> Dict[str, Any]:
         # NOTE: jwt.decode has erroneously set key: str
-        return jwt.decode(access_token, key=public_key, audience=self._audiences, algorithms=["RS256"])  # type: ignore
+        return jwt.decode(access_token, key=public_key, audience=self._audiences, algorithms=["RS256"], **self._decode_options)  # type: ignore