diff --git a/hyper/common/connection.py b/hyper/common/connection.py index dee18d68..c2deb757 100644 --- a/hyper/common/connection.py +++ b/hyper/common/connection.py @@ -62,7 +62,8 @@ def __init__(self, self._port = port self._h1_kwargs = { 'secure': secure, 'ssl_context': ssl_context, - 'proxy_host': proxy_host, 'proxy_port': proxy_port + 'proxy_host': proxy_host, 'proxy_port': proxy_port, + 'enable_push': enable_push } self._h2_kwargs = { 'window_manager': window_manager, 'enable_push': enable_push, @@ -143,6 +144,22 @@ def get_response(self, *args, **kwargs): return self._conn.get_response(1) + def get_pushes(self, *args, **kwargs): + try: + return self._conn.get_pushes(*args, **kwargs) + except HTTPUpgrade as e: + assert e.negotiated == H2C_PROTOCOL + + self._conn = HTTP20Connection( + self._host, self._port, **self._h2_kwargs + ) + + self._conn._connect_upgrade(e.sock, True) + # stream id 1 is used by the upgrade request and response + # and is half-closed by the client + + return self._conn.get_pushes(*args, **kwargs) + # The following two methods are the implementation of the context manager # protocol. def __enter__(self): # pragma: no cover diff --git a/hyper/http11/connection.py b/hyper/http11/connection.py index 48bde2f0..52377401 100644 --- a/hyper/http11/connection.py +++ b/hyper/http11/connection.py @@ -58,6 +58,7 @@ class HTTP11Connection(object): """ version = HTTPVersion.http11 + _response = None def __init__(self, host, port=None, secure=None, ssl_context=None, proxy_host=None, proxy_port=None, **kwargs): @@ -78,6 +79,7 @@ def __init__(self, host, port=None, secure=None, ssl_context=None, # only send http upgrade headers for non-secure connection self._send_http_upgrade = not self.secure + self._enable_push = kwargs.get('enable_push') self.ssl_context = ssl_context self._sock = None @@ -104,6 +106,12 @@ def __init__(self, host, port=None, secure=None, ssl_context=None, #: the standard hyper parsing interface. self.parser = Parser() + def get_pushes(self, stream_id=None, capture_all=False): + """ + Dummy method to trigger h2c upgrade. + """ + self._get_response() + def connect(self): """ Connect to the server specified when the object was created. This is a @@ -188,6 +196,7 @@ def request(self, method, url, body=None, headers=None): # Next, send the request body. if body: self._send_body(body, body_type) + self._response = None return @@ -198,31 +207,39 @@ def get_response(self): This is an early beta, so the response object is pretty stupid. That's ok, we'll fix it later. """ - headers = HTTPHeaderMap() + resp = self._get_response() + self._response = None + return resp - response = None - while response is None: - # 'encourage' the socket to receive data. - self._sock.fill() - response = self.parser.parse_response(self._sock.buffer) + def _get_response(self): + if self._response is None: - for n, v in response.headers: - headers[n.tobytes()] = v.tobytes() + headers = HTTPHeaderMap() - self._sock.advance_buffer(response.consumed) + response = None + while response is None: + # 'encourage' the socket to receive data. + self._sock.fill() + response = self.parser.parse_response(self._sock.buffer) - if (response.status == 101 and + for n, v in response.headers: + headers[n.tobytes()] = v.tobytes() + + self._sock.advance_buffer(response.consumed) + + if (response.status == 101 and b'upgrade' in headers['connection'] and - H2C_PROTOCOL.encode('utf-8') in headers['upgrade']): - raise HTTPUpgrade(H2C_PROTOCOL, self._sock) - - return HTTP11Response( - response.status, - response.msg.tobytes(), - headers, - self._sock, - self - ) + H2C_PROTOCOL.encode('utf-8') in headers['upgrade']): + raise HTTPUpgrade(H2C_PROTOCOL, self._sock) + + self._response = HTTP11Response( + response.status, + response.msg.tobytes(), + headers, + self._sock, + self + ) + return self._response def _send_headers(self, method, url, headers): """ @@ -276,6 +293,10 @@ def _add_upgrade_headers(self, headers): # Settings header. http2_settings = SettingsFrame(0) http2_settings.settings[SettingsFrame.INITIAL_WINDOW_SIZE] = 65535 + if self._enable_push is not None: + http2_settings.settings[SettingsFrame.ENABLE_PUSH] = ( + int(self._enable_push) + ) encoded_settings = base64.urlsafe_b64encode( http2_settings.serialize_body() ) @@ -348,7 +369,7 @@ def _send_file_like_obj(self, fobj): Handles streaming a file-like object to the network. """ while True: - block = fobj.read(16*1024) + block = fobj.read(16 * 1024) if not block: break diff --git a/hyper/http20/connection.py b/hyper/http20/connection.py index 8b2a71e8..7223833d 100644 --- a/hyper/http20/connection.py +++ b/hyper/http20/connection.py @@ -114,6 +114,7 @@ def __init__(self, host, port=None, secure=None, window_manager=None, else: self.secure = False + self._delay_recv = False self._enable_push = enable_push self.ssl_context = ssl_context @@ -313,6 +314,9 @@ def get_response(self, stream_id=None): get a response. :returns: A :class:`HTTP20Response ` object. """ + if self._delay_recv: + self._recv_cb() + self._delay_recv = False stream = self._get_stream(stream_id) return HTTP20Response(stream.getheaders(), stream) @@ -384,7 +388,7 @@ def connect(self): self._send_preamble() - def _connect_upgrade(self, sock): + def _connect_upgrade(self, sock, no_recv=False): """ Called by the generic HTTP connection when we're being upgraded. Locks in a new socket and places the backing state machine into an upgrade @@ -405,7 +409,10 @@ def _connect_upgrade(self, sock): s = self._new_stream(local_closed=True) self.recent_stream = s - self._recv_cb() + if no_recv: # To delay I/O operation + self._delay_recv = True + else: + self._recv_cb() def _send_preamble(self): """ diff --git a/test/test_abstraction.py b/test/test_abstraction.py index cd0e0645..7c2cad1a 100644 --- a/test/test_abstraction.py +++ b/test/test_abstraction.py @@ -19,6 +19,7 @@ def test_h1_kwargs(self): 'proxy_host': False, 'proxy_port': False, 'other_kwarg': True, + 'enable_push': True, } def test_h2_kwargs(self): diff --git a/test/test_hyper.py b/test/test_hyper.py index 6a18d592..a1faf89f 100644 --- a/test/test_hyper.py +++ b/test/test_hyper.py @@ -9,6 +9,7 @@ PingFrame, FRAME_MAX_ALLOWED_LEN ) from hpack.hpack_compat import Encoder +from hyper import HTTPConnection from hyper.http20.connection import HTTP20Connection from hyper.http20.response import HTTP20Response, HTTP20Push from hyper.http20.exceptions import ConnectionError, StreamResetError @@ -731,8 +732,8 @@ def add_data_frame(self, stream_id, data, end_stream=False): frame.flags.add('END_STREAM') self.frames.append(frame) - def request(self): - self.conn = HTTP20Connection('www.google.com', enable_push=True) + def request(self, enable_push=True): + self.conn = HTTP20Connection('www.google.com', enable_push=enable_push) self.conn._sock = DummySocket() self.conn._sock.buffer = BytesIO( b''.join([frame.serialize() for frame in self.frames]) @@ -934,13 +935,13 @@ def test_reset_pushed_streams_when_push_disabled(self): 1, [(':status', '200'), ('content-type', 'text/html')] ) - self.request() - self.conn._enable_push = False + self.request(enable_push=False) self.conn.get_response() f = RstStreamFrame(2) f.error_code = 7 - assert self.conn._sock.queue[-1] == f.serialize() + print(self.conn._sock.queue) + assert self.conn._sock.queue[-1].endswith(f.serialize()) def test_pushed_requests_ignore_unexpected_headers(self): headers = HTTPHeaderMap([ @@ -956,7 +957,29 @@ def test_pushed_requests_ignore_unexpected_headers(self): assert p.request_headers == HTTPHeaderMap([('no', 'no')]) +class TestUpgradingPush(TestServerPush): + http101 = (b"HTTP/1.1 101 Switching Protocols\r\n" + b"Connection: upgrade\r\n" + b"Upgrade: h2c\r\n" + b"\r\n") + + def setup_method(self, method): + self.frames = [SettingsFrame(0)] # Server-side preface + self.encoder = Encoder() + self.conn = None + + def request(self, enable_push=True): + self.conn = HTTPConnection('www.google.com', enable_push=enable_push) + self.conn._conn._sock = DummySocket() + self.conn._conn._sock.buffer = BytesIO( + self.http101 + b''.join([frame.serialize() + for frame in self.frames]) + ) + self.conn.request('GET', '/') + + class TestResponse(object): + def test_status_is_stripped_from_headers(self): headers = HTTPHeaderMap([(':status', '200')]) resp = HTTP20Response(headers, None) diff --git a/test/test_integration.py b/test/test_integration.py index 96a8af76..f78107aa 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -12,6 +12,7 @@ import hyper import hyper.http11.connection import pytest +from contextlib import contextmanager from mock import patch from h2.frame_buffer import FrameBuffer from hyper.compat import ssl @@ -64,17 +65,30 @@ def frame_buffer(): return buffer +@contextmanager +def reusable_frame_buffer(buffer): + # FrameBuffer does not return new iterator for iteration. + data = buffer.data + yield buffer + buffer.data = data + + def receive_preamble(sock): # Receive the HTTP/2 'preamble'. - first = sock.recv(65535) + client_preface = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' + timeout = time.time() + 5 + got = b'' + while len(got) < len(client_preface) and time.time() < timeout: + got += sock.recv(len(client_preface) - len(got)) + + assert got == client_preface, "client preface mismatch" - # Work around some bugs: if the first message received was only the PRI - # string, aim to receive a settings frame as well. - if len(first) <= len(b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'): - sock.recv(65535) + # Send server side HTTP/2 preface sock.send(SettingsFrame(0).serialize()) - sock.recv(65535) - return + # Drain to let the client proceed. + # Note that in the lower socket level, this method is not + # just doing "receive". + return sock.recv(65535) @patch('hyper.http20.connection.H2_NPN_PROTOCOLS', PROTOCOLS) @@ -138,7 +152,7 @@ def socket_handler(listener): self._start_server(socket_handler) conn = self.get_connection() conn.connect() - send_event.wait() + send_event.wait(5) # Get the chunk of data after the preamble and decode it into frames. # We actually expect two, but only the second one contains ENABLE_PUSH. @@ -242,7 +256,7 @@ def socket_handler(listener): f = SettingsFrame(0) sock.send(f.serialize()) - send_event.wait() + send_event.wait(5) sock.recv(65535) sock.close() @@ -260,6 +274,7 @@ def socket_handler(listener): def test_closed_responses_remove_their_streams_from_conn(self): self.set_up() + req_event = threading.Event() recv_event = threading.Event() def socket_handler(listener): @@ -270,6 +285,8 @@ def socket_handler(listener): receive_preamble(sock) sock.recv(65535) + # Wait for request + req_event.wait(5) # Now, send the headers for the response. f = build_headers_frame([(':status', '200')]) f.stream_id = 1 @@ -282,6 +299,7 @@ def socket_handler(listener): self._start_server(socket_handler) conn = self.get_connection() conn.request('GET', '/') + req_event.set() resp = conn.get_response() # Close the response. @@ -296,6 +314,7 @@ def socket_handler(listener): def test_receiving_responses_with_no_body(self): self.set_up() + req_event = threading.Event() recv_event = threading.Event() def socket_handler(listener): @@ -306,6 +325,8 @@ def socket_handler(listener): receive_preamble(sock) sock.recv(65535) + # Wait for request + req_event.wait(5) # Now, send the headers for the response. This response has no body f = build_headers_frame( [(':status', '204'), ('content-length', '0')] @@ -321,6 +342,7 @@ def socket_handler(listener): self._start_server(socket_handler) conn = self.get_connection() conn.request('GET', '/') + req_event.set() resp = conn.get_response() # Confirm the status code. @@ -338,6 +360,7 @@ def socket_handler(listener): def test_receiving_trailers(self): self.set_up() + req_event = threading.Event() recv_event = threading.Event() def socket_handler(listener): @@ -350,6 +373,8 @@ def socket_handler(listener): receive_preamble(sock) sock.recv(65535) + # Wait for request + req_event.wait(5) # Now, send the headers for the response. f = build_headers_frame( [(':status', '200'), ('content-length', '14')], @@ -372,12 +397,13 @@ def socket_handler(listener): sock.send(f.serialize()) # Wait for the message from the main thread. - recv_event.set() + recv_event.wait(5) sock.close() self._start_server(socket_handler) conn = self.get_connection() conn.request('GET', '/') + req_event.set() resp = conn.get_response() # Confirm the status code. @@ -396,13 +422,14 @@ def socket_handler(listener): assert len(resp.trailers) == 2 # Awesome, we're done now. - recv_event.wait(5) + recv_event.set() self.tear_down() def test_receiving_trailers_before_reading(self): self.set_up() + req_event = threading.Event() recv_event = threading.Event() wait_event = threading.Event() @@ -416,6 +443,8 @@ def socket_handler(listener): receive_preamble(sock) sock.recv(65535) + # Wait for request + req_event.wait(5) # Now, send the headers for the response. f = build_headers_frame( [(':status', '200'), ('content-length', '14')], @@ -449,6 +478,7 @@ def socket_handler(listener): self._start_server(socket_handler) conn = self.get_connection() conn.request('GET', '/') + req_event.set() resp = conn.get_response() # Confirm the status code. @@ -647,6 +677,7 @@ def test_resetting_stream_with_frames_in_flight(self): """ self.set_up() + req_event = threading.Event() recv_event = threading.Event() def socket_handler(listener): @@ -657,6 +688,8 @@ def socket_handler(listener): receive_preamble(sock) sock.recv(65535) + # Wait for request + req_event.wait(5) # Now, send the headers for the response. This response has no # body. f = build_headers_frame( @@ -673,6 +706,7 @@ def socket_handler(listener): self._start_server(socket_handler) conn = self.get_connection() stream_id = conn.request('GET', '/') + req_event.set() # Now, trigger the RST_STREAM frame by closing the stream. conn._send_rst_frame(stream_id, 0) @@ -696,6 +730,7 @@ def test_stream_can_be_reset_multiple_times(self): """ self.set_up() + req_event = threading.Event() recv_event = threading.Event() def socket_handler(listener): @@ -706,6 +741,8 @@ def socket_handler(listener): receive_preamble(sock) sock.recv(65535) + # Wait for request + req_event.wait(5) # Now, send two RST_STREAM frames. for _ in range(0, 2): f = RstStreamFrame(1) @@ -718,6 +755,7 @@ def socket_handler(listener): self._start_server(socket_handler) conn = self.get_connection() conn.request('GET', '/') + req_event.set() # Now, eat the Rst frames. These should not cause an exception. conn._single_read() @@ -737,6 +775,7 @@ def socket_handler(listener): def test_read_chunked_http2(self): self.set_up() + req_event = threading.Event() recv_event = threading.Event() wait_event = threading.Event() @@ -748,6 +787,8 @@ def socket_handler(listener): receive_preamble(sock) sock.recv(65535) + # Wait for request + req_event.wait(5) # Now, send the headers for the response. This response has a body. f = build_headers_frame([(':status', '200')]) f.stream_id = 1 @@ -777,6 +818,7 @@ def socket_handler(listener): self._start_server(socket_handler) conn = self.get_connection() conn.request('GET', '/') + req_event.set() resp = conn.get_response() # Confirm the status code. @@ -805,6 +847,7 @@ def socket_handler(listener): def test_read_delayed(self): self.set_up() + req_event = threading.Event() recv_event = threading.Event() wait_event = threading.Event() @@ -816,6 +859,8 @@ def socket_handler(listener): receive_preamble(sock) sock.recv(65535) + # Wait for request + req_event.wait(5) # Now, send the headers for the response. This response has a body. f = build_headers_frame([(':status', '200')]) f.stream_id = 1 @@ -845,6 +890,7 @@ def socket_handler(listener): self._start_server(socket_handler) conn = self.get_connection() conn.request('GET', '/') + req_event.set() resp = conn.get_response() # Confirm the status code. @@ -958,6 +1004,8 @@ def socket_handler(listener): receive_preamble(sock) + # Wait for the message from the main thread. + send_event.wait() # Send the headers for the response. This response has no body. f = build_headers_frame( [(':status', '200'), ('content-length', '0')] @@ -965,9 +1013,6 @@ def socket_handler(listener): f.flags.add('END_STREAM') f.stream_id = 1 sock.sendall(f.serialize()) - - # Wait for the message from the main thread. - send_event.wait() sock.close() self._start_server(socket_handler) @@ -996,7 +1041,7 @@ def socket_handler(listener): data += sock.recv(65535) assert b'upgrade: h2c\r\n' in data - send_event.wait() + send_event.wait(5) # We need to send back a response. resp = ( @@ -1038,7 +1083,7 @@ class TestRequestsAdapter(SocketLevelTest): # This uses HTTP/2. h2 = True - def test_adapter_received_values(self, monkeypatch): + def test_adapter_received_values(self, monkeypatch, frame_buffer): self.set_up() # We need to patch the ssl_wrap_socket method to ensure that we @@ -1051,17 +1096,20 @@ def wrap(*args): monkeypatch.setattr(hyper.http11.connection, 'wrap_socket', wrap) - data = [] - send_event = threading.Event() - def socket_handler(listener): sock = listener.accept()[0] # Do the handshake: conn header, settings, send settings, recv ack. - receive_preamble(sock) + frame_buffer.add_data(receive_preamble(sock)) # Now expect some data. One headers frame. - data.append(sock.recv(65535)) + req_wait = True + while req_wait: + frame_buffer.add_data(sock.recv(65535)) + with reusable_frame_buffer(frame_buffer) as fr: + for f in fr: + if isinstance(f, HeadersFrame): + req_wait = False # Respond! h = HeadersFrame(1) @@ -1078,8 +1126,6 @@ def socket_handler(listener): d.data = b'1234567890' * 2 d.flags.add('END_STREAM') sock.send(d.serialize()) - - send_event.wait(5) sock.close() self._start_server(socket_handler) @@ -1093,11 +1139,9 @@ def socket_handler(listener): assert r.headers[b'Content-Type'] == b'not/real' assert r.content == b'1234567890' * 2 - send_event.set() - self.tear_down() - def test_adapter_sending_values(self, monkeypatch): + def test_adapter_sending_values(self, monkeypatch, frame_buffer): self.set_up() # We need to patch the ssl_wrap_socket method to ensure that we @@ -1110,17 +1154,20 @@ def wrap(*args): monkeypatch.setattr(hyper.http11.connection, 'wrap_socket', wrap) - data = [] - def socket_handler(listener): sock = listener.accept()[0] # Do the handshake: conn header, settings, send settings, recv ack. - receive_preamble(sock) + frame_buffer.add_data(receive_preamble(sock)) # Now expect some data. One headers frame and one data frame. - data.append(sock.recv(65535)) - data.append(sock.recv(65535)) + req_wait = True + while req_wait: + frame_buffer.add_data(sock.recv(65535)) + with reusable_frame_buffer(frame_buffer) as fr: + for f in fr: + if isinstance(f, DataFrame): + req_wait = False # Respond! h = HeadersFrame(1) @@ -1137,7 +1184,6 @@ def socket_handler(listener): d.data = b'1234567890' * 2 d.flags.add('END_STREAM') sock.send(d.serialize()) - sock.close() self._start_server(socket_handler) @@ -1152,11 +1198,10 @@ def socket_handler(listener): # Assert about the sent values. assert r.status_code == 200 - f = decode_frame(data[0]) - assert isinstance(f, HeadersFrame) + frames = list(frame_buffer) + assert isinstance(frames[-2], HeadersFrame) - f = decode_frame(data[1]) - assert isinstance(f, DataFrame) - assert f.data == b'hi there' + assert isinstance(frames[-1], DataFrame) + assert frames[-1].data == b'hi there' self.tear_down()