Skip to content

Commit ece9cf9

Browse files
committed
run linting
1 parent 091d660 commit ece9cf9

File tree

5 files changed

+37
-33
lines changed

5 files changed

+37
-33
lines changed

tests/conftest.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -113,24 +113,20 @@ def tls_ca_ssl_context(tls_certificate_authority: trustme.CA) -> ssl.SSLContext:
113113
tls_certificate_authority.configure_trust(ssl_ctx)
114114
return ssl_ctx
115115

116+
116117
@pytest.fixture
117-
def tls_client_ssl_context(tls_certificate_authority: trustme.CA, tls_client_certificate: trustme.LeafCert) -> ssl.SSLContext:
118+
def tls_client_ssl_context(
119+
tls_certificate_authority: trustme.CA, tls_client_certificate: trustme.LeafCert
120+
) -> ssl.SSLContext:
118121
ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
119122
tls_certificate_authority.configure_trust(ssl_ctx)
120123

121124
# Load the client certificate chain into the SSL context
122125
with tls_client_certificate.private_key_and_cert_chain_pem.tempfile() as client_cert_pem:
123126
ssl_ctx.load_cert_chain(certfile=client_cert_pem)
124-
125127

126128
return ssl_ctx
127129

128-
@pytest.fixture
129-
def tls_client_certificate_pem_path(tls_client_certificate: trustme.LeafCert):
130-
private_key_and_cert_chain = tls_client_certificate.private_key_and_cert_chain_pem
131-
with private_key_and_cert_chain.tempfile() as client_cert_pem:
132-
yield client_cert_pem
133-
134130

135131
@pytest.fixture(scope="package")
136132
def reload_directory_structure(tmp_path_factory: pytest.TempPathFactory):
@@ -283,8 +279,3 @@ def ws_protocol_cls(request: pytest.FixtureRequest):
283279
)
284280
def http_protocol_cls(request: pytest.FixtureRequest):
285281
return import_from_string(request.param)
286-
@pytest.fixture
287-
def tls_client_certificate_pem_path(tls_client_certificate: trustme.LeafCert):
288-
private_key_and_cert_chain = tls_client_certificate.private_key_and_cert_chain_pem
289-
with private_key_and_cert_chain.tempfile() as client_cert_pem:
290-
yield client_cert_pem

tests/test_ssl.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import httpx
44
import pytest
5-
65
from cryptography import x509
76

87
from tests.utils import run_server
@@ -46,8 +45,6 @@ async def test_run(
4645
],
4746
indirect=["tls_client_certificate"],
4847
)
49-
50-
5148
@pytest.mark.anyio
5249
async def test_run_httptools_client_cert(
5350
tls_client_ssl_context,
@@ -59,11 +56,13 @@ async def test_run_httptools_client_cert(
5956
async def app(scope, receive, send):
6057
assert scope["type"] == "http"
6158
assert len(scope["extensions"]["tls"]["client_cert_chain"]) >= 1
62-
cert = x509.load_pem_x509_certificate(scope["extensions"]["tls"]["client_cert_chain"][0].encode('utf-8'))
59+
cert = x509.load_pem_x509_certificate(scope["extensions"]["tls"]["client_cert_chain"][0].encode("utf-8"))
6360
assert cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value == expected_common_name
64-
cipher_suites = [cipher['name'] for cipher in ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER).get_ciphers()]
61+
cipher_suites = [cipher["name"] for cipher in ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER).get_ciphers()]
6562
assert scope["extensions"]["tls"]["cipher_suite"] in cipher_suites
66-
assert (scope["extensions"]["tls"]["tls_version"].startswith("TLSv") or scope["extensions"]["tls"]["tls_version"].startswith("SSLv"))
63+
assert scope["extensions"]["tls"]["tls_version"].startswith("TLSv") or scope["extensions"]["tls"][
64+
"tls_version"
65+
].startswith("SSLv")
6766

6867
await send({"type": "http.response.start", "status": 204, "headers": []})
6968
await send({"type": "http.response.body", "body": b"", "more_body": False})

uvicorn/protocols/http/h11_impl.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
service_unavailable,
2828
)
2929
from uvicorn.protocols.utils import (
30+
TLSInfo,
3031
get_client_addr,
3132
get_local_addr,
3233
get_path_with_query_string,
3334
get_remote_addr,
3435
get_tls_info,
3536
is_ssl,
3637
)
38+
from uvicorn.server import ServerState
3739

3840

3941
def _get_status_phrase(status_code: int) -> bytes:
@@ -89,7 +91,7 @@ def __init__(
8991
self.server: tuple[str, int] | None = None
9092
self.client: tuple[str, int] | None = None
9193
self.scheme: Literal["http", "https"] | None = None
92-
self.tls: dict[object, object] = {}
94+
self.tls: TLSInfo = TLSInfo()
9395

9496
# Per-request state
9597
self.scope: HTTPScope = None # type: ignore[assignment]
@@ -231,8 +233,8 @@ def handle_events(self) -> None:
231233
}
232234

233235
if self.config.is_ssl:
234-
self.scope["extensions"]["tls"] = self.tls
235-
236+
self.scope["extensions"]["tls"] = dict(self.tls)
237+
236238
if self._should_upgrade():
237239
self.handle_websocket_upgrade(event)
238240
return

uvicorn/protocols/http/httptools_impl.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@
3333
service_unavailable,
3434
)
3535
from uvicorn.protocols.utils import (
36+
TLSInfo,
3637
get_client_addr,
3738
get_local_addr,
3839
get_path_with_query_string,
3940
get_remote_addr,
4041
get_tls_info,
4142
is_ssl,
4243
)
44+
from uvicorn.server import ServerState
4345

4446
HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]')
4547
HEADER_VALUE_RE = re.compile(b"[\x00-\x08\x0a-\x1f\x7f]")
@@ -103,7 +105,7 @@ def __init__(
103105
self.client: tuple[str, int] | None = None
104106
self.scheme: Literal["http", "https"] | None = None
105107
self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
106-
self.tls: dict[object, object] = {}
108+
self.tls: TLSInfo = TLSInfo()
107109

108110
# Per-request state
109111
self.scope: HTTPScope = None # type: ignore[assignment]
@@ -124,7 +126,7 @@ def connection_made( # type: ignore[override]
124126
self.scheme = "https" if is_ssl(transport) else "http"
125127

126128
if self.config.is_ssl:
127-
self.tls = get_tls_info(transport,self.config)
129+
self.tls = get_tls_info(transport, self.config)
128130

129131
if self.logger.level <= TRACE_LOG_LEVEL:
130132
prefix = "%s:%d - " % self.client if self.client else ""
@@ -250,11 +252,11 @@ def on_message_begin(self) -> None:
250252
"root_path": self.root_path,
251253
"headers": self.headers,
252254
"state": self.app_state.copy(),
253-
"extensions": {},
255+
"extensions": {},
254256
}
255257

256258
if self.config.is_ssl:
257-
self.scope["extensions"]["tls"] = self.tls
259+
self.scope["extensions"]["tls"] = dict(self.tls)
258260

259261
# Parser callbacks
260262
def on_url(self, url: bytes) -> None:

uvicorn/protocols/utils.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import ssl
55
import urllib.parse
6-
import sys
6+
from typing import TypedDict
77

88
from uvicorn._types import WWWScope
99
from uvicorn.config import Config
@@ -59,16 +59,22 @@ def get_path_with_query_string(scope: WWWScope) -> str:
5959
return path_with_query_string
6060

6161

62+
class TLSInfo(TypedDict):
63+
server_cert: str | None
64+
client_cert_chain: list[str]
65+
tls_version: str | None
66+
cipher_suite: str | None
6267

63-
def get_tls_info(transport: asyncio.Transport, server_config: Config) -> dict[object, object]:
68+
69+
def get_tls_info(transport: asyncio.Transport, server_config: Config) -> TLSInfo:
6470
###
6571
# server_cert: Unable to set from transport information, need to set from server_config
66-
# client_cert_chain:
72+
# client_cert_chain:
6773
# tls_version:
6874
# cipher_suite:
6975
###
7076

71-
ssl_info: dict[object, object] = {
77+
ssl_info: TLSInfo = {
7278
"server_cert": None,
7379
"client_cert_chain": [],
7480
"tls_version": None,
@@ -79,12 +85,16 @@ def get_tls_info(transport: asyncio.Transport, server_config: Config) -> dict[ob
7985

8086
ssl_object = transport.get_extra_info("ssl_object")
8187
if ssl_object is not None:
82-
client_chain = ssl_object.get_verified_chain() if hasattr(ssl_object, "get_verified_chain") else [ssl_object.getpeercert(binary_form=True)]
88+
client_chain = (
89+
ssl_object.get_verified_chain()
90+
if hasattr(ssl_object, "get_verified_chain")
91+
else [ssl_object.getpeercert(binary_form=True)]
92+
)
8393
for cert in client_chain:
8494
if cert is not None:
8595
ssl_info["client_cert_chain"].append(ssl.DER_cert_to_PEM_cert(cert))
86-
96+
8797
ssl_info["tls_version"] = ssl_object.version()
8898
ssl_info["cipher_suite"] = ssl_object.cipher()[0] if ssl_object.cipher() else None
8999

90-
return ssl_info
100+
return ssl_info

0 commit comments

Comments
 (0)