From 10ca0f337a33cba70a6f9c8b787548a6c942d06d Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 19 Oct 2016 15:27:18 -0400 Subject: [PATCH 1/2] Detach socket in create_server, create_connection, and other APIs --- asyncio/base_events.py | 17 ++++++++++++++--- tests/test_events.py | 3 ++- tests/test_streams.py | 2 +- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 648b9b9b..860f60e0 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -87,6 +87,12 @@ def _set_reuseport(sock): 'SO_REUSEPORT defined but not implemented.') +def _copy_and_detach_socket(sock): + new_sock = socket.socket(sock.family, sock.type, sock.proto, sock.fileno()) + sock.detach() + return new_sock + + # Linux's sock.type is a bitmask that can include extra info about socket. _SOCKET_TYPE_MASK = 0 if hasattr(socket, 'SOCK_NONBLOCK'): @@ -768,9 +774,11 @@ def create_connection(self, protocol_factory, host=None, port=None, *, raise OSError('Multiple exceptions: {}'.format( ', '.join(str(exc) for exc in exceptions))) - elif sock is None: - raise ValueError( - 'host and port was not specified and no sock specified') + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sock = _copy_and_detach_socket(sock) transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, server_hostname) @@ -827,6 +835,7 @@ def create_datagram_endpoint(self, protocol_factory, raise ValueError( 'socket modifier keyword arguments can not be used ' 'when sock is specified. ({})'.format(problems)) + sock = _copy_and_detach_socket(sock) sock.setblocking(False) r_addr = None else: @@ -1024,6 +1033,7 @@ def create_server(self, protocol_factory, host=None, port=None, else: if sock is None: raise ValueError('Neither host/port nor sock were specified') + sock = _copy_and_detach_socket(sock) sockets = [sock] server = Server(self, sockets) @@ -1045,6 +1055,7 @@ def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): This method is a coroutine. When completed, the coroutine returns a (transport, protocol) pair. """ + sock = _copy_and_detach_socket(sock) transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True) if self._debug: diff --git a/tests/test_events.py b/tests/test_events.py index 7df926f1..9df96081 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1240,11 +1240,12 @@ def connection_made(self, transport): sock_ob = socket.socket(type=socket.SOCK_STREAM) sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock_ob.bind(('0.0.0.0', 0)) + sock_ob_fd = sock_ob.fileno() f = self.loop.create_server(TestMyProto, sock=sock_ob) server = self.loop.run_until_complete(f) sock = server.sockets[0] - self.assertIs(sock, sock_ob) + self.assertEqual(sock.fileno(), sock_ob_fd) host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') diff --git a/tests/test_streams.py b/tests/test_streams.py index 35557c3c..f45d47b7 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -582,7 +582,7 @@ def start(self): asyncio.start_server(self.handle_client, sock=sock, loop=self.loop)) - return sock.getsockname() + return self.server.sockets[0].getsockname() def handle_client_callback(self, client_reader, client_writer): self.loop.create_task(self.handle_client(client_reader, From f1202eb441005e18df6be0826a5d090da05cd6e8 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 19 Oct 2016 16:07:23 -0400 Subject: [PATCH 2/2] Address Victor's comments --- asyncio/base_events.py | 8 ++++---- tests/test_events.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 860f60e0..640fdc6f 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -88,8 +88,8 @@ def _set_reuseport(sock): def _copy_and_detach_socket(sock): - new_sock = socket.socket(sock.family, sock.type, sock.proto, sock.fileno()) - sock.detach() + fd = sock.detach() + new_sock = socket.socket(sock.family, sock.type, sock.proto, fd) return new_sock @@ -835,8 +835,8 @@ def create_datagram_endpoint(self, protocol_factory, raise ValueError( 'socket modifier keyword arguments can not be used ' 'when sock is specified. ({})'.format(problems)) - sock = _copy_and_detach_socket(sock) sock.setblocking(False) + sock = _copy_and_detach_socket(sock) r_addr = None else: if not (local_addr or remote_addr): @@ -1055,7 +1055,7 @@ def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): This method is a coroutine. When completed, the coroutine returns a (transport, protocol) pair. """ - sock = _copy_and_detach_socket(sock) + transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True) if self._debug: diff --git a/tests/test_events.py b/tests/test_events.py index 9df96081..40d71102 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1246,6 +1246,7 @@ def connection_made(self, transport): server = self.loop.run_until_complete(f) sock = server.sockets[0] self.assertEqual(sock.fileno(), sock_ob_fd) + self.assertEqual(sock_ob.fileno(), -1) host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0')