diff --git a/gevent/fileobject.py b/gevent/fileobject.py index ea01a108b..48ff331c3 100644 --- a/gevent/fileobject.py +++ b/gevent/fileobject.py @@ -3,6 +3,7 @@ import os from gevent.hub import get_hub from gevent.hub import integer_types +from gevent.hub import PY3 from gevent.socket import EBADF from gevent.os import _read, _write, ignored_errors from gevent.lock import Semaphore, DummySemaphore @@ -25,7 +26,7 @@ else: - from gevent.socket import _fileobject, _get_memory + from gevent.socket import _get_memory cancel_wait_ex = IOError(EBADF, 'File descriptor was closed in another greenlet') from gevent.os import make_nonblocking @@ -35,8 +36,6 @@ SocketAdapter__del__ = None noop = None - from types import UnboundMethodType - class NA(object): def __repr__(self): @@ -44,163 +43,337 @@ def __repr__(self): NA = NA() - class SocketAdapter(object): - """Socket-like API on top of a file descriptor. - - The main purpose of it is to re-use _fileobject to create proper cooperative file objects - from file descriptors on POSIX platforms. - """ - - def __init__(self, fileno, mode=None, close=True): - if not isinstance(fileno, integer_types): - raise TypeError('fileno must be int: %r' % fileno) - self._fileno = fileno - self._mode = mode or 'rb' - self._close = close - self._translate = 'U' in self._mode - make_nonblocking(fileno) - self._eat_newline = False - self.hub = get_hub() - io = self.hub.loop.io - self._read_event = io(fileno, 1) - self._write_event = io(fileno, 2) + if PY3: + from io import BufferedRandom + from io import BufferedReader + from io import BufferedWriter + from io import BytesIO + from io import DEFAULT_BUFFER_SIZE + from io import RawIOBase + from io import TextIOWrapper + from io import UnsupportedOperation + + class GreenFileDescriptorIO(RawIOBase): + def __init__(self, fileno, mode='r', closefd=True): + super().__init__() + self._closed = False + self._closefd = closefd + self._fileno = fileno + make_nonblocking(fileno) + self._readable = 'r' in mode + self._writable = 'w' in mode + self.hub = get_hub() + io = self.hub.loop.io + if self._readable: + self._read_event = io(fileno, 1) + else: + self._read_event = None + if self._writable: + self._write_event = io(fileno, 2) + else: + self._write_event = None - def __repr__(self): - if self._fileno is None: - return '<%s at 0x%x closed>' % (self.__class__.__name__, id(self)) - else: - args = (self.__class__.__name__, id(self), getattr(self, '_fileno', NA), getattr(self, '_mode', NA)) - return '<%s at 0x%x (%r, %r)>' % args - - def makefile(self, *args, **kwargs): - return _fileobject(self, *args, **kwargs) - - def fileno(self): - result = self._fileno - if result is None: - raise IOError(EBADF, 'Bad file descriptor (%s object is closed)' % self.__class__.__name) - return result - - def detach(self): - x = self._fileno - self._fileno = None - return x - - def close(self): - self.hub.cancel_wait(self._read_event, cancel_wait_ex) - self.hub.cancel_wait(self._write_event, cancel_wait_ex) - fileno = self._fileno - if fileno is not None: - self._fileno = None - if self._close: - os.close(fileno) + def readable(self): + return self._readable - def sendall(self, data): - fileno = self.fileno() - bytes_total = len(data) - bytes_written = 0 - while True: - try: - bytes_written += _write(fileno, _get_memory(data, bytes_written)) - except (IOError, OSError) as ex: - code = ex.args[0] - if code not in ignored_errors: - raise - sys.exc_clear() - if bytes_written >= bytes_total: + def writable(self): + return self._writable + + def fileno(self): + return self._fileno + + @property + def closed(self): + return self._closed + + def close(self): + if self._closed: return - self.hub.wait(self._write_event) + self.flush() + self._closed = True + if self._readable: + self.hub.cancel_wait(self._read_event, cancel_wait_ex) + if self._writable: + self.hub.cancel_wait(self._write_event, cancel_wait_ex) + fileno = self._fileno + if self._closefd: + self._fileno = None + os.close(fileno) - def recv(self, size): - while True: + def read(self, n=1): + if not self._readable: + raise UnsupportedOperation('readinto') + while True: + try: + return _read(self._fileno, n) + except (IOError, OSError) as ex: + if ex.args[0] not in ignored_errors: + raise + self.hub.wait(self._read_event) + + def readall(self): + ret = BytesIO() + while True: + data = self.read(DEFAULT_BUFFER_SIZE) + if not data: + break + ret.write(data) + return ret.getvalue() + + def write(self, b): + if not self._writable: + raise UnsupportedOperation('write') + while True: + try: + return _write(self._fileno, b) + except (IOError, OSError) as ex: + if ex.args[0] not in ignored_errors: + raise + self.hub.wait(self._write_event) + + class FileObjectPosix: + default_bufsize = 8192 + + def __init__(self, fobj, mode='rb', bufsize=-1, close=True): + if isinstance(fobj, integer_types): + fileno = fobj + fobj = None + else: + fileno = fobj.fileno() + if not isinstance(fileno, integer_types): + raise TypeError('fileno must be int: %r' % fileno) + + mode = (mode or 'rb').replace('b', '') + if 'U' in mode: + self._translate = True + mode = mode.replace('U', '') + else: + self._translate = False + assert len(mode) == 1, 'mode can only be [rb, rU, wb]' + + self._fobj = fobj + self._closed = False + self._close = close + + self.fileio = GreenFileDescriptorIO(fileno, mode, closefd=close) + + if bufsize < 0: + bufsize = self.default_bufsize + if mode == 'r': + if bufsize == 0: + bufsize = 1 + elif bufsize == 1: + bufsize = self.default_bufsize + self.io = BufferedReader(self.fileio, bufsize) + elif mode == 'w': + self.io = BufferedWriter(self.fileio, bufsize) + else: + # QQQ: not used + self.io = BufferedRandom(self.fileio, bufsize) + if self._translate: + self.io = TextIOWrapper(self.io) + + @property + def closed(self): + """True if the file is cloed""" + return self._closed + + def close(self): + if self._closed: + # make sure close() is only ran once when called concurrently + return + self._closed = True try: - data = _read(self.fileno(), size) - except (IOError, OSError) as ex: - code = ex.args[0] - if code not in ignored_errors: - raise - sys.exc_clear() + self.io.close() + self.fileio.close() + finally: + self._fobj = None + + def flush(self): + self.io.flush() + + def fileno(self): + return self.io.fileno() + + def write(self, data): + self.io.write(data) + + def writelines(self, list): + self.io.writelines(list) + + def read(self, size=-1): + return self.io.read(size) + + def readline(self, size=-1): + return self.io.readline(size) + + def readlines(self, sizehint=0): + return self.io.readlines(sizehint) + + def __iter__(self): + return self.io + + else: + from gevent.socket import _fileobject + from types import UnboundMethodType + + class SocketAdapter(object): + """Socket-like API on top of a file descriptor. + + The main purpose of it is to re-use _fileobject to create proper cooperative file objects + from file descriptors on POSIX platforms. + """ + + def __init__(self, fileno, mode=None, close=True): + if not isinstance(fileno, integer_types): + raise TypeError('fileno must be int: %r' % fileno) + self._fileno = fileno + self._mode = mode or 'rb' + self._close = close + self._translate = 'U' in self._mode + make_nonblocking(fileno) + self._eat_newline = False + self.hub = get_hub() + io = self.hub.loop.io + self._read_event = io(fileno, 1) + self._write_event = io(fileno, 2) + + def __repr__(self): + if self._fileno is None: + return '<%s at 0x%x closed>' % (self.__class__.__name__, id(self)) else: - if not self._translate or not data: - return data - if self._eat_newline: - self._eat_newline = False - if data.startswith('\n'): - data = data[1:] - if not data: - return self.recv(size) - if data.endswith('\r'): - self._eat_newline = True - return self._translate_newlines(data) - self.hub.wait(self._read_event) - - def _translate_newlines(self, data): - data = data.replace("\r\n", "\n") - data = data.replace("\r", "\n") - return data - - if not SocketAdapter__del__: - - def __del__(self, close=os.close): - fileno = self._fileno - if fileno is not None: - close(fileno) + args = (self.__class__.__name__, id(self), getattr(self, '_fileno', NA), getattr(self, '_mode', NA)) + return '<%s at 0x%x (%r, %r)>' % args + + def makefile(self, *args, **kwargs): + return _fileobject(self, *args, **kwargs) - if SocketAdapter__del__: - SocketAdapter.__del__ = UnboundMethodType(SocketAdapter__del__, None, SocketAdapter) + def fileno(self): + result = self._fileno + if result is None: + raise IOError(EBADF, 'Bad file descriptor (%s object is closed)' % self.__class__.__name) + return result - class FileObjectPosix(_fileobject): + def detach(self): + x = self._fileno + self._fileno = None + return x - def __init__(self, fobj, mode='rb', bufsize=-1, close=True): - if isinstance(fobj, integer_types): - fileno = fobj - fobj = None - else: - fileno = fobj.fileno() - sock = SocketAdapter(fileno, mode, close=close) - self._fobj = fobj - self._closed = False - _fileobject.__init__(self, sock, mode=mode, bufsize=bufsize, close=close) + def close(self): + self.hub.cancel_wait(self._read_event, cancel_wait_ex) + self.hub.cancel_wait(self._write_event, cancel_wait_ex) + fileno = self._fileno + if fileno is not None: + self._fileno = None + if self._close: + os.close(fileno) + + def sendall(self, data): + fileno = self.fileno() + bytes_total = len(data) + bytes_written = 0 + while True: + try: + bytes_written += _write(fileno, _get_memory(data, bytes_written)) + except (IOError, OSError) as ex: + code = ex.args[0] + if code not in ignored_errors: + raise + sys.exc_clear() + if bytes_written >= bytes_total: + return + self.hub.wait(self._write_event) + + def recv(self, size): + while True: + try: + data = _read(self.fileno(), size) + except (IOError, OSError) as ex: + code = ex.args[0] + if code not in ignored_errors: + raise + sys.exc_clear() + else: + if not self._translate or not data: + return data + if self._eat_newline: + self._eat_newline = False + if data.startswith('\n'): + data = data[1:] + if not data: + return self.recv(size) + if data.endswith('\r'): + self._eat_newline = True + return self._translate_newlines(data) + self.hub.wait(self._read_event) + + def _translate_newlines(self, data): + data = data.replace("\r\n", "\n") + data = data.replace("\r", "\n") + return data + + if not SocketAdapter__del__: + + def __del__(self, close=os.close): + fileno = self._fileno + if fileno is not None: + close(fileno) + + if SocketAdapter__del__: + SocketAdapter.__del__ = UnboundMethodType(SocketAdapter__del__, None, SocketAdapter) + + class FileObjectPosix(_fileobject): + + def __init__(self, fobj, mode='rb', bufsize=-1, close=True): + if isinstance(fobj, integer_types): + fileno = fobj + fobj = None + else: + fileno = fobj.fileno() + sock = SocketAdapter(fileno, mode, close=close) + self._fobj = fobj + self._closed = False + _fileobject.__init__(self, sock, mode=mode, bufsize=bufsize, close=close) + + def __repr__(self): + if self._sock is None: + return '<%s closed>' % self.__class__.__name__ + elif self._fobj is None: + return '<%s %s>' % (self.__class__.__name__, self._sock) + else: + return '<%s %s _fobj=%r>' % (self.__class__.__name__, self._sock, self._fobj) - def __repr__(self): - if self._sock is None: - return '<%s closed>' % self.__class__.__name__ - elif self._fobj is None: - return '<%s %s>' % (self.__class__.__name__, self._sock) - else: - return '<%s %s _fobj=%r>' % (self.__class__.__name__, self._sock, self._fobj) - - def close(self): - if self._closed: - # make sure close() is only ran once when called concurrently - # cannot rely on self._sock for this because we need to keep that until flush() is done - return - self._closed = True - sock = self._sock - if sock is None: - return - try: - self.flush() - finally: - if self._fobj is not None or not self._close: - sock.detach() - self._sock = None - self._fobj = None - - def __getattr__(self, item): - assert item != '_fobj' - if self._fobj is None: - raise FileObjectClosed - return getattr(self._fobj, item) - - if not noop: - - def __del__(self): - # disable _fileobject's __del__ - pass - - if noop: - FileObjectPosix.__del__ = UnboundMethodType(FileObjectPosix, None, noop) + def close(self): + if self._closed: + # make sure close() is only ran once when called concurrently + # cannot rely on self._sock for this because we need to keep that until flush() is done + return + self._closed = True + sock = self._sock + if sock is None: + return + try: + self.flush() + finally: + if self._fobj is not None or not self._close: + sock.detach() + self._sock = None + self._fobj = None + + def __getattr__(self, item): + assert item != '_fobj' + if self._fobj is None: + raise FileObjectClosed + return getattr(self._fobj, item) + + if not noop: + + def __del__(self): + # disable _fileobject's __del__ + pass + + if noop: + FileObjectPosix.__del__ = UnboundMethodType(FileObjectPosix, None, noop) class FileObjectThread(object): diff --git a/gevent/socket.py b/gevent/socket.py index 69ca438ce..9c32a39d7 100644 --- a/gevent/socket.py +++ b/gevent/socket.py @@ -38,7 +38,8 @@ 'socket', 'SocketType', 'fromfd', - 'socketpair'] + 'socketpair', + 'fromshare'] __dns__ = ['getaddrinfo', 'gethostbyname', @@ -73,13 +74,21 @@ 'getservbyport', 'getdefaulttimeout', 'setdefaulttimeout', + # Python 3: + 'CMSG_LEN', + 'CMSG_SPACE', + 'dup', + 'if_indextoname', + 'if_nameindex', + 'if_nametoindex', + 'sethostname', # Windows: 'errorTab'] import sys import time -from gevent.hub import get_hub, string_types, integer_types, text_type +from gevent.hub import get_hub, string_types, integer_types, text_type, PY3 from gevent.timeout import Timeout is_windows = sys.platform == 'win32' @@ -110,7 +119,6 @@ import _socket _realsocket = _socket.socket import socket as __socket__ -_fileobject = __socket__._fileobject for name in __imports__[:]: try: @@ -216,21 +224,8 @@ def _dummy(*args, **kwargs): timeout_default = object() -class socket(object): - - def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None): - if _sock is None: - self._sock = _realsocket(family, type, proto) - self.timeout = _socket.getdefaulttimeout() - else: - if hasattr(_sock, '_sock'): - self._sock = _sock._sock - self.timeout = getattr(_sock, 'timeout', False) - if self.timeout is False: - self.timeout = _socket.getdefaulttimeout() - else: - self._sock = _sock - self.timeout = _socket.getdefaulttimeout() +class _BaseSocket(object): + def __init__(self): self._sock.setblocking(0) fileno = self._sock.fileno() self.hub = get_hub() @@ -238,17 +233,7 @@ def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None): self._read_event = io(fileno, 1) self._write_event = io(fileno, 2) - def __repr__(self): - return '<%s at %s %s>' % (type(self).__name__, hex(id(self)), self._formatinfo()) - - def __str__(self): - return '<%s %s>' % (type(self).__name__, self._formatinfo()) - def _formatinfo(self): - try: - fileno = self.fileno() - except Exception as ex: - fileno = str(ex) try: sockname = self.getsockname() sockname = '%s:%s' % sockname @@ -259,13 +244,13 @@ def _formatinfo(self): peername = '%s:%s' % peername except Exception: peername = None - result = 'fileno=%s' % fileno + result = [] if sockname is not None: - result += ' sock=' + str(sockname) + result.append(', sock=' + str(sockname)) if peername is not None: - result += ' peer=' + str(peername) + result.append(', peer=' + str(peername)) if getattr(self, 'timeout', None) is not None: - result += ' timeout=' + str(self.timeout) + result.append(', timeout=' + str(self.timeout)) return result def _get_ref(self): @@ -277,7 +262,7 @@ def _set_ref(self, value): ref = property(_get_ref, _set_ref) - def _wait(self, watcher, timeout_exc=timeout('timed out')): + def _wait(self, watcher, timeout_exc=timeout('timed out'), timeout=None): """Block the current greenlet until *watcher* has pending events. If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed. @@ -286,10 +271,12 @@ def _wait(self, watcher, timeout_exc=timeout('timed out')): If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``. """ assert watcher.callback is None, 'This socket is already used by another greenlet: %r' % (watcher.callback, ) - if self.timeout is not None: - timeout = Timeout.start_new(self.timeout, timeout_exc, ref=False) + if timeout is None: + timeout = self.timeout else: - timeout = None + timeout = timeout[0] + if timeout is not None: + timeout = Timeout.start_new(timeout, timeout_exc, ref=False) try: self.hub.wait(watcher) finally: @@ -305,7 +292,8 @@ def accept(self): except error as ex: if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0: raise - sys.exc_clear() + if not PY3: + sys.exc_clear() self._wait(self._read_event) return socket(_sock=client_socket), address @@ -357,21 +345,6 @@ def connect_ex(self, address): else: raise # gaierror is not silented by connect_ex - def dup(self): - """dup() -> socket object - - Return a new socket object connected to the same system resource. - Note, that the new socket does not inherit the timeout.""" - return socket(_sock=self._sock) - - def makefile(self, mode='r', bufsize=-1): - # Two things to look out for: - # 1) Closing the original socket object should not close the - # socket (hence creating a new instance) - # 2) The resulting fileobject must keep the timeout in order - # to be compatible with the stdlib's socket.makefile. - return _fileobject(type(self)(_sock=self), mode, bufsize) - def recv(self, *args): sock = self._sock # keeping the reference so that fd is not closed during waiting while True: @@ -380,8 +353,9 @@ def recv(self, *args): except error as ex: if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0: raise - # QQQ without clearing exc_info test__refcount.test_clean_exit fails - sys.exc_clear() + # QQQ without clearing exc_info test__refcount.test_clean_exit fails + if not PY3: + sys.exc_clear() self._wait(self._read_event) def recvfrom(self, *args): @@ -392,7 +366,8 @@ def recvfrom(self, *args): except error as ex: if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0: raise - sys.exc_clear() + if not PY3: + sys.exc_clear() self._wait(self._read_event) def recvfrom_into(self, *args): @@ -403,7 +378,8 @@ def recvfrom_into(self, *args): except error as ex: if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0: raise - sys.exc_clear() + if not PY3: + sys.exc_clear() self._wait(self._read_event) def recv_into(self, *args): @@ -414,7 +390,8 @@ def recv_into(self, *args): except error as ex: if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0: raise - sys.exc_clear() + if not PY3: + sys.exc_clear() self._wait(self._read_event) def send(self, data, flags=0, timeout=timeout_default): @@ -426,19 +403,20 @@ def send(self, data, flags=0, timeout=timeout_default): except error as ex: if ex.args[0] != EWOULDBLOCK or timeout == 0.0: raise - sys.exc_clear() - self._wait(self._write_event) - try: - return sock.send(data, flags) - except error as ex2: - if ex2.args[0] == EWOULDBLOCK: - return 0 - raise + if not PY3: + sys.exc_clear() + self._wait(self._write_event) + try: + return sock.send(data, flags) + except error as ex2: + if ex2.args[0] == EWOULDBLOCK: + return 0 + raise def sendall(self, data, flags=0): if isinstance(data, text_type): data = data.encode() - # this sendall is also reused by gevent.ssl.SSLSocket subclass, + # this sendall is also reused by gevent.ssl.SSLSocket subclass, # so it should not call self._sock methods directly if self.timeout is None: data_sent = 0 @@ -463,20 +441,15 @@ def sendto(self, *args): except error as ex: if ex.args[0] != EWOULDBLOCK or timeout == 0.0: raise - sys.exc_clear() - self._wait(self._write_event) - try: - return sock.sendto(*args) - except error as ex2: - if ex2.args[0] == EWOULDBLOCK: - return 0 - raise - - def setblocking(self, flag): - if flag: - self.timeout = None - else: - self.timeout = 0.0 + if not PY3: + sys.exc_clear() + self._wait(self._write_event) + try: + return sock.sendto(*args) + except error as ex2: + if ex2.args[0] == EWOULDBLOCK: + return 0 + raise def settimeout(self, howlong): if howlong is not None: @@ -487,7 +460,7 @@ def settimeout(self, howlong): howlong = f() if howlong < 0.0: raise ValueError('Timeout value out of range') - self.timeout = howlong + return howlong def gettimeout(self): return self.timeout @@ -502,17 +475,175 @@ def shutdown(self, how): self.hub.cancel_wait(self._write_event, cancel_wait_ex) self._sock.shutdown(how) - family = property(lambda self: self._sock.family, doc="the socket family") - type = property(lambda self: self._sock.type, doc="the socket type") - proto = property(lambda self: self._sock.proto, doc="the socket protocol") - # delegate the functions that we haven't implemented to the real socket object +if PY3: + class socket(_BaseSocket, __socket__.socket): + __slots__ = ['__sock', '_timeout', 'hub', '_read_event', + '_write_event'] - _s = ("def %s(self, *args): return self._sock.%s(*args)\n\n" - "%s.__doc__ = _realsocket.%s.__doc__\n") - for _m in set(__socket__._socketmethods) - set(locals()): - exec(_s % (_m, _m, _m, _m)) - del _m, _s + def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, + fileno=None, _sock=None): + self.__sock = None + if _sock is None: + self._sock.__init__(family, type, proto, fileno) + self._timeout = _socket.getdefaulttimeout() + else: + self._sock.__init__(_sock.family, _sock.type, _sock.proto, + _sock.detach()) + if hasattr(_sock, '_sock'): + self._timeout = getattr(_sock, 'timeout', False) + if self._timeout is False: + self._timeout = _socket.getdefaulttimeout() + else: + self._timeout = _socket.getdefaulttimeout() + super().__init__() + + def __repr__(self): + return '<%s%s>' % (self._sock.__repr__()[1:-1], + ', '.join(self._formatinfo())) + + def __str__(self): + return '<%s%s>' % (self._sock.__str__()[1:-1], + ', '.join(self._formatinfo())) + + @property + def _sock(self): + return self.__sock or super(_BaseSocket, self) + + @_sock.setter + def _sock(self, value): + self.__sock = value + + @property + def timeout(self): + return self._timeout + + def settimeout(self, howlong): + self._timeout = super().settimeout(howlong) + + def _real_close(self, _ss=_socket.socket, _closedsocket=_closedsocket, + cancel_wait_ex=cancel_wait_ex): + # This function should not reference any globals. See issue #808164 + super().close(_closedsocket, cancel_wait_ex) + super()._real_close(_ss) + + def close(self, _closedsocket=_closedsocket, + cancel_wait_ex=cancel_wait_ex): + self._sock.close() + + def detach(self): + ret = self._sock.detach() + self.close() + return ret + + def setblocking(self, flag): + if flag: + self._timeout = None + else: + self._timeout = 0.0 + + def recvmsg(self, *args): + sock = self._sock + while True: + try: + return sock.recvmsg(*args) + except error as ex: + if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0: + raise + self._wait(self._read_event) + + def recvmsg_into(self, *args): + sock = self._sock + while True: + try: + return sock.recvmsg_into(*args) + except error as ex: + if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0: + raise + self._wait(self._read_event) + + def sendmsg(self, *args): + sock = self._sock + while True: + try: + return sock.sendmsg(*args) + except error as ex: + if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0: + raise + self._wait(self._read_event) +else: + _fileobject = __socket__._fileobject + + class socket(_BaseSocket): + + def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, + _sock=None): + if _sock is None: + self._sock = _realsocket(family, type, proto) + self.timeout = _socket.getdefaulttimeout() + else: + if hasattr(_sock, '_sock'): + self._sock = _sock._sock + self.timeout = getattr(_sock, 'timeout', False) + if self.timeout is False: + self.timeout = _socket.getdefaulttimeout() + else: + self._sock = _sock + self.timeout = _socket.getdefaulttimeout() + super(socket, self).__init__() + + def __repr__(self): + return '<%s at %s %s>' % (type(self).__name__, hex(id(self)), + ' '.join(self._formatinfo())) + + def __str__(self): + return '<%s %s>' % (type(self).__name__, + ' '.join(self._formatinfo())) + + def _formatinfo(self): + try: + fileno = self.fileno() + except Exception as ex: + fileno = str(ex) + result = ['fileno=%s' % fileno] + result.extend(super(socket, self)._formatinfo()) + return result + + def dup(self): + """dup() -> socket object + + Return a new socket object connected to the same system resource. + Note, that the new socket does not inherit the timeout.""" + return socket(_sock=self._sock) + + def makefile(self, mode='r', bufsize=-1): + # Two things to look out for: + # 1) Closing the original socket object should not close the + # socket (hence creating a new instance) + # 2) The resulting fileobject must keep the timeout in order + # to be compatible with the stdlib's socket.makefile. + return _fileobject(type(self)(_sock=self), mode, bufsize) + + def setblocking(self, flag): + if flag: + self.timeout = None + else: + self.timeout = 0.0 + + def settimeout(self, howlong): + self.timeout = super(socket, self).settimeout(howlong) + + family = property(lambda self: self._sock.family, doc="the socket family") + type = property(lambda self: self._sock.type, doc="the socket type") + proto = property(lambda self: self._sock.proto, doc="the socket protocol") + + # delegate the functions that we haven't implemented to the real socket object + + _s = ("def %s(self, *args): return self._sock.%s(*args)\n\n" + "%s.__doc__ = _realsocket.%s.__doc__\n") + for _m in set(__socket__._socketmethods) - set(locals()) - set(dir(_BaseSocket)): + exec(_s % (_m, _m, _m, _m)) + del _m, _s SocketType = socket @@ -524,13 +655,33 @@ def socketpair(*args): else: __implements__.remove('socketpair') -if hasattr(_socket, 'fromfd'): +if not PY3 and hasattr(_socket, 'fromfd'): def fromfd(*args): return socket(_sock=_socket.fromfd(*args)) +elif PY3 and hasattr(__socket__, 'fromfd'): + def fromfd(fd, family, type, proto=0): + """ fromfd(fd, family, type[, proto]) -> socket object + + Create a socket object from a duplicate of the given file + descriptor. The remaining arguments are the same as for socket(). + """ + nfd = _socket.dup(fd) + return socket(family, type, proto, nfd) else: __implements__.remove('fromfd') +if hasattr(_realsocket, 'share'): + def fromshare(info): + """ fromshare(info) -> socket object + + Create a socket object from a the bytes object returned by + socket.share(pid). + """ + return socket(0, 0, 0, info) +else: + __implements__.remove('fromshare') + try: _GLOBAL_DEFAULT_TIMEOUT = __socket__._GLOBAL_DEFAULT_TIMEOUT @@ -570,9 +721,10 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=N # and the next bind() fails (see test__socket.TestCreateConnection) # that does not happen with regular sockets though, because _socket.socket.connect() is a built-in. # this is similar to "getnameinfo loses a reference" failure in test_socket.py - sys.exc_clear() - if sock is not None: - sock.close() + if not PY3: + sys.exc_clear() + if sock is not None: + sock.close() if err is not None: raise err else: @@ -603,8 +755,8 @@ def gethostbyname_ex(hostname): return get_hub().resolver.gethostbyname_ex(hostname) -def getaddrinfo(host, port, family=0, socktype=0, proto=0, flags=0): - return get_hub().resolver.getaddrinfo(host, port, family, socktype, proto, flags) +def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): + return get_hub().resolver.getaddrinfo(host, port, family, type, proto, flags) def gethostbyaddr(ip_address): diff --git a/gevent/ssl.py b/gevent/ssl.py index 357c1bff6..7c640d07f 100644 --- a/gevent/ssl.py +++ b/gevent/ssl.py @@ -17,15 +17,18 @@ import sys import errno -from gevent.socket import socket, _fileobject, timeout_default +from gevent.socket import socket, timeout_default from gevent.socket import error as socket_error +from gevent.hub import integer_types +from gevent.hub import PY3 from gevent.hub import string_types __implements__ = ['SSLSocket', 'wrap_socket', 'get_server_certificate', - 'sslwrap_simple'] + 'sslwrap_simple', + 'SSLContext'] __imports__ = ['SSLError', 'RAND_status', @@ -34,7 +37,18 @@ 'cert_time_to_seconds', 'get_protocol_name', 'DER_cert_to_PEM_cert', - 'PEM_cert_to_DER_cert'] + 'PEM_cert_to_DER_cert', + # Python 3 + 'CHANNEL_BINDING_TYPES', + 'SSLZeroReturnError', + 'SSLWantReadError', + 'SSLWantWriteError', + 'SSLSyscallError', + 'SSLEOFError', + 'CertificateError', + 'RAND_bytes', + 'RAND_pseudo_bytes', + 'match_hostname'] for name in __imports__[:]: try: @@ -46,84 +60,24 @@ for name in dir(__ssl__): if not name.startswith('_'): value = getattr(__ssl__, name) - if isinstance(value, (int, long, tuple)) or isinstance(value, string_types): + if (isinstance(value, integer_types) or isinstance(value, tuple) or + isinstance(value, string_types)): globals()[name] = value __imports__.append(name) del name, value -__all__ = __implements__ + __imports__ - - -class SSLSocket(socket): - - def __init__(self, sock, keyfile=None, certfile=None, - server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_SSLv23, ca_certs=None, - do_handshake_on_connect=True, - suppress_ragged_eofs=True, - ciphers=None): - socket.__init__(self, _sock=sock) - if certfile and not keyfile: - keyfile = certfile - # see if it's connected - try: - socket.getpeername(self) - except socket_error as e: - if e.args[0] != errno.ENOTCONN: - raise - # no, no connection yet - self._sslobj = None - else: - # yes, create the SSL object - if ciphers is None: - self._sslobj = _ssl.sslwrap(self._sock, server_side, - keyfile, certfile, - cert_reqs, ssl_version, ca_certs) - else: - self._sslobj = _ssl.sslwrap(self._sock, server_side, - keyfile, certfile, - cert_reqs, ssl_version, ca_certs, - ciphers) - if do_handshake_on_connect: - self.do_handshake() - self.keyfile = keyfile - self.certfile = certfile - self.cert_reqs = cert_reqs - self.ssl_version = ssl_version - self.ca_certs = ca_certs - self.ciphers = ciphers - self.do_handshake_on_connect = do_handshake_on_connect - self.suppress_ragged_eofs = suppress_ragged_eofs - self._makefile_refs = 0 - - def read(self, len=1024): - """Read up to LEN bytes and return them. - Return zero-length string on EOF.""" - while True: - try: - return self._sslobj.read(len) - except SSLError as ex: - if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: - return '' - elif ex.args[0] == SSL_ERROR_WANT_READ: - if self.timeout == 0.0: - raise - sys.exc_clear() - self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout) - elif ex.args[0] == SSL_ERROR_WANT_WRITE: - if self.timeout == 0.0: - raise - sys.exc_clear() - # note: using _SSLErrorReadTimeout rather than _SSLErrorWriteTimeout below is intentional - self._wait(self._write_event, timeout_exc=_SSLErrorReadTimeout) - else: - raise +class _BaseSSLSocket(socket): + def _checkClosed(self, msg=None): + # raise an exception here if you wish to check for spurious closes + pass def write(self, data): """Write DATA to the underlying SSL channel. Returns number of bytes of DATA actually transmitted.""" + + self._checkClosed() while True: try: return self._sslobj.write(data) @@ -131,30 +85,37 @@ def write(self, data): if ex.args[0] == SSL_ERROR_WANT_READ: if self.timeout == 0.0: raise - sys.exc_clear() - self._wait(self._read_event, timeout_exc=_SSLErrorWriteTimeout) + event = self._read_event + timeout_exc = _SSLErrorWriteTimeout elif ex.args[0] == SSL_ERROR_WANT_WRITE: if self.timeout == 0.0: raise - sys.exc_clear() - self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout) + event = self._write_event + timeout_exc = _SSLErrorWriteTimeout else: raise + if not PY3: + sys.exc_clear() + self._wait(event, timeout_exc=timeout_exc) def getpeercert(self, binary_form=False): """Returns a formatted version of the data in the certificate provided by the other end of the SSL channel. Return None if no certificate was provided, {} if a certificate was provided, but not validated.""" + + self._checkClosed() return self._sslobj.peer_certificate(binary_form) def cipher(self): + self._checkClosed() if not self._sslobj: return None else: return self._sslobj.cipher() def send(self, data, flags=0, timeout=timeout_default): + self._checkClosed() if timeout is timeout_default: timeout = self.timeout if self._sslobj: @@ -169,22 +130,28 @@ def send(self, data, flags=0, timeout=timeout_default): if x.args[0] == SSL_ERROR_WANT_READ: if self.timeout == 0.0: return 0 - sys.exc_clear() - self._wait(self._read_event) + event = self._read_event elif x.args[0] == SSL_ERROR_WANT_WRITE: if self.timeout == 0.0: return 0 - sys.exc_clear() - self._wait(self._write_event) + event = self._write_event else: raise + if not PY3: + sys.exc_clear() else: return v + self._wait(event) else: return socket.send(self, data, flags, timeout) + # is it possible for sendall() to send some data without encryption if another end shut down SSL? + def sendall(self, data, flags=0): + self._checkClosed() + return socket.sendall(self, data, flags) def sendto(self, *args): + self._checkClosed() if self._sslobj: raise ValueError("sendto not allowed on instances of %s" % self.__class__) @@ -192,6 +159,7 @@ def sendto(self, *args): return socket.sendto(self, *args) def recv(self, buflen=1024, flags=0): + self._checkClosed() if self._sslobj: if flags != 0: raise ValueError( @@ -202,35 +170,8 @@ def recv(self, buflen=1024, flags=0): else: return socket.recv(self, buflen, flags) - def recv_into(self, buffer, nbytes=None, flags=0): - if buffer and (nbytes is None): - nbytes = len(buffer) - elif nbytes is None: - nbytes = 1024 - if self._sslobj: - if flags != 0: - raise ValueError( - "non-zero flags not allowed in calls to recv_into() on %s" % - self.__class__) - while True: - try: - tmp_buffer = self.read(nbytes) - v = len(tmp_buffer) - buffer[:v] = tmp_buffer - return v - except SSLError as x: - if x.args[0] == SSL_ERROR_WANT_READ: - if self.timeout == 0.0: - raise - sys.exc_clear() - self._wait(self._read_event) - continue - else: - raise - else: - return socket.recv_into(self, buffer, nbytes, flags) - def recvfrom(self, *args): + self._checkClosed() if self._sslobj: raise ValueError("recvfrom not allowed on instances of %s" % self.__class__) @@ -238,6 +179,7 @@ def recvfrom(self, *args): return socket.recvfrom(self, *args) def recvfrom_into(self, *args): + self._checkClosed() if self._sslobj: raise ValueError("recvfrom_into not allowed on instances of %s" % self.__class__) @@ -245,6 +187,7 @@ def recvfrom_into(self, *args): return socket.recvfrom_into(self, *args) def pending(self): + self._checkClosed() if self._sslobj: return self._sslobj.pending() else: @@ -260,15 +203,18 @@ def _sslobj_shutdown(self): elif ex.args[0] == SSL_ERROR_WANT_READ: if self.timeout == 0.0: raise - sys.exc_clear() - self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout) + event = self._read_event + timeout_exc = _SSLErrorReadTimeout elif ex.args[0] == SSL_ERROR_WANT_WRITE: if self.timeout == 0.0: raise - sys.exc_clear() - self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout) + event = self._write_event + timeout_exc = _SSLErrorWriteTimeout else: raise + if not PY3: + sys.exc_clear() + self._wait(event, timeout_exc=timeout_exc) def unwrap(self): if self._sslobj: @@ -279,79 +225,438 @@ def unwrap(self): raise ValueError("No SSL wrapper around " + str(self)) def shutdown(self, how): + self._checkClosed() self._sslobj = None socket.shutdown(self, how) - def close(self): - if self._makefile_refs < 1: - self._sslobj = None - socket.close(self) - else: - self._makefile_refs -= 1 - - def do_handshake(self): + def do_handshake(self, block=False): """Perform a TLS/SSL handshake.""" while True: try: return self._sslobj.do_handshake() except SSLError as ex: if ex.args[0] == SSL_ERROR_WANT_READ: + timeout = None if self.timeout == 0.0: - raise - sys.exc_clear() - self._wait(self._read_event, timeout_exc=_SSLErrorHandshakeTimeout) + if block: + timeout = (None,) + else: + raise + event = self._read_event + timeout_exc = _SSLErrorHandshakeTimeout elif ex.args[0] == SSL_ERROR_WANT_WRITE: + timeout = None if self.timeout == 0.0: - raise + if block: + timeout = (None,) + else: + raise + event = self._write_event + timeout_exc = _SSLErrorHandshakeTimeout + else: + raise + if not PY3: sys.exc_clear() - self._wait(self._write_event, timeout_exc=_SSLErrorHandshakeTimeout) + self._wait(event, timeout_exc=timeout_exc, timeout=timeout) + + +if PY3: + class SSLContext(__ssl__.SSLContext): + """An SSLContext holds various SSL-related configuration options and + data, such as certificates and possibly a private key.""" + + def wrap_socket(self, sock, server_side=False, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname=None): + return SSLSocket(sock=sock, server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + server_hostname=server_hostname, + _context=self) + + class SSLSocket(_BaseSSLSocket): + def __init__(self, sock=None, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, + suppress_ragged_eofs=True, npn_protocols=None, + ciphers=None, server_hostname=None, + _context=None): + if _context: + self.context = _context + else: + if server_side and not certfile: + raise ValueError("certfile must be specified for server-side " + "operations") + if keyfile and not certfile: + raise ValueError("certfile must be specified") + if certfile and not keyfile: + keyfile = certfile + self.context = SSLContext(ssl_version) + self.context.verify_mode = cert_reqs + if ca_certs: + self.context.load_verify_locations(ca_certs) + if certfile: + self.context.load_cert_chain(certfile, keyfile) + if npn_protocols: + self.context.set_npn_protocols(npn_protocols) + if ciphers: + self.context.set_ciphers(ciphers) + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + if server_side and server_hostname: + raise ValueError("server_hostname can only be specified " + "in client mode") + self.server_side = server_side + self.server_hostname = server_hostname + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + connected = False + if sock is not None: + socket.__init__(self, + family=sock.family, + type=sock.type, + proto=sock.proto, + fileno=sock.fileno()) + self.settimeout(sock.gettimeout()) + # see if it's connected + try: + sock.getpeername() + except socket_error as e: + if e.errno != errno.ENOTCONN: + raise + else: + connected = True + sock.detach() + elif fileno is not None: + socket.__init__(self, fileno=fileno) + else: + socket.__init__(self, family=family, type=type, proto=proto) + + self._closed = False + self._sslobj = None + self._connected = connected + if connected: + # create the SSL object + try: + self._sslobj = self.context._wrap_socket(self, server_side, + server_hostname) + if do_handshake_on_connect: + timeout = self.gettimeout() + if timeout == 0.0: + # non-blocking + raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") + self.do_handshake() + + except socket_error as x: + self.close() + raise x + + def dup(self): + raise NotImplemented("Can't dup() %s instances" % + self.__class__.__name__) + + def read(self, len=0, buffer=None): + """Read up to LEN bytes and return them. + Return zero-length string on EOF.""" + + self._checkClosed() # QQQ: maybe check in while? + while True: + try: + if buffer is not None: + v = self._sslobj.read(len, buffer) + else: + v = self._sslobj.read(len or 1024) + return v + except SSLError as ex: + if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: + if buffer is not None: + return 0 + else: + return b'' + elif ex.args[0] == SSL_ERROR_WANT_READ: + if self.timeout == 0.0: + raise + event = self._read_event + timeout_exc = _SSLErrorReadTimeout + elif ex.args[0] == SSL_ERROR_WANT_WRITE: + if self.timeout == 0.0: + raise + # note: using _SSLErrorReadTimeout rather than _SSLErrorWriteTimeout below is intentional + event = self._write_event + timeout_exc = _SSLErrorReadTimeout + else: + raise + self._wait(event, timeout_exc=timeout_exc) + + def compression(self): + self._checkClosed() + if not self._sslobj: + return None + else: + return self._sslobj.compression() + + def sendmsg(self, *args, **kwargs): + # Ensure programs don't send data unencrypted if they try to + # use this method. + raise NotImplementedError("sendmsg not allowed on instances of %s" % + self.__class__) + + def recv_into(self, buffer, nbytes=None, flags=0): + self._checkClosed() + if buffer and (nbytes is None): + nbytes = len(buffer) + elif nbytes is None: + nbytes = 1024 + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv_into() on %s" % + self.__class__) + while True: + try: + return self.read(nbytes, buffer) + except SSLError as x: + if x.args[0] == SSL_ERROR_WANT_READ: + if self.timeout == 0.0: + raise + else: + raise + self._wait(self._read_event) + else: + return socket.recv_into(self, buffer, nbytes, flags) + + def recvmsg(self, *args, **kwargs): + raise NotImplementedError("recvmsg not allowed on instances of %s" % + self.__class__) + + def recvmsg_into(self, *args, **kwargs): + raise NotImplementedError("recvmsg_into not allowed on instances of " + "%s" % self.__class__) + + def _real_close(self): + self._sslobj = None + # self._closed = True + socket._real_close(self) + + def _real_connect(self, addr, connect_ex): + if self.server_side: + raise ValueError("can't connect in server-side mode") + # Here we assume that the socket is client-side, and not + # connected at the time of the call. We connect it, then wrap it. + if self._connected: + raise ValueError("attempt to connect already-connected SSLSocket!") + self._sslobj = self.context._wrap_socket(self, False, self.server_hostname) + try: + if connect_ex: + rc = socket.connect_ex(self, addr) else: + rc = None + socket.connect(self, addr) + if not rc: + if self.do_handshake_on_connect: + self.do_handshake() + self._connected = True + return rc + except socket_error: + self._sslobj = None + raise + + def connect(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + self._real_connect(addr, False) + + def connect_ex(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + return self._real_connect(addr, True) + + def accept(self): + """Accepts a new connection from a remote client, and returns + a tuple containing that new connection wrapped with a server-side + SSL channel, and the address of the remote client.""" + + newsock, addr = socket.accept(self) + newsock = self.context.wrap_socket( + newsock, + do_handshake_on_connect=self.do_handshake_on_connect, + suppress_ragged_eofs=self.suppress_ragged_eofs, + server_side=True) + return newsock, addr + + def get_channel_binding(self, cb_type="tls-unique"): + """Get channel binding data for current connection. Raise ValueError + if the requested `cb_type` is not supported. Return bytes of the data + or None if the data is not available (e.g. before the handshake). + """ + if cb_type not in CHANNEL_BINDING_TYPES: + raise ValueError("Unsupported channel binding type") + if cb_type != "tls-unique": + raise NotImplementedError( + "{0} channel binding type not implemented" + .format(cb_type)) + if self._sslobj is None: + return None + return self._sslobj.tls_unique_cb() + +else: + from gevent.socket import _fileobject + + __implements__.remove('SSLContext') + + class SSLSocket(_BaseSSLSocket): + def __init__(self, sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=CERT_NONE, + ssl_version=PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, suppress_ragged_eofs=True, + ciphers=None): + socket.__init__(self, _sock=sock) + + if certfile and not keyfile: + keyfile = certfile + # see if it's connected + try: + socket.getpeername(self) + except socket_error as e: + if e.args[0] != errno.ENOTCONN: raise + # no, no connection yet + self._sslobj = None + else: + # yes, create the SSL object + if ciphers is None: + self._sslobj = _ssl.sslwrap(self._sock, server_side, + keyfile, certfile, + cert_reqs, ssl_version, ca_certs) + else: + self._sslobj = _ssl.sslwrap(self._sock, server_side, + keyfile, certfile, + cert_reqs, ssl_version, ca_certs, + ciphers) + if do_handshake_on_connect: + self.do_handshake() + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.ciphers = ciphers + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + self._makefile_refs = 0 + + def read(self, len=1024): + """Read up to LEN bytes and return them. + Return zero-length string on EOF.""" + + self._checkClosed() + while True: + try: + return self._sslobj.read(len) + except SSLError as ex: + if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: + return '' + elif ex.args[0] == SSL_ERROR_WANT_READ: + if self.timeout == 0.0: + raise + sys.exc_clear() + self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout) + elif ex.args[0] == SSL_ERROR_WANT_WRITE: + if self.timeout == 0.0: + raise + sys.exc_clear() + # note: using _SSLErrorReadTimeout rather than _SSLErrorWriteTimeout below is intentional + self._wait(self._write_event, timeout_exc=_SSLErrorReadTimeout) + else: + raise - def connect(self, addr): - """Connects to remote ADDR, and then wraps the connection in - an SSL channel.""" - # Here we assume that the socket is client-side, and not - # connected at the time of the call. We connect it, then wrap it. - if self._sslobj: - raise ValueError("attempt to connect already-connected SSLSocket!") - socket.connect(self, addr) - if self.ciphers is None: - self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, - self.cert_reqs, self.ssl_version, - self.ca_certs) - else: - self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, - self.cert_reqs, self.ssl_version, - self.ca_certs, self.ciphers) - if self.do_handshake_on_connect: - self.do_handshake() - - def accept(self): - """Accepts a new connection from a remote client, and returns - a tuple containing that new connection wrapped with a server-side - SSL channel, and the address of the remote client.""" - newsock, addr = socket.accept(self) - return (SSLSocket(newsock._sock, - keyfile=self.keyfile, - certfile=self.certfile, - server_side=True, - cert_reqs=self.cert_reqs, - ssl_version=self.ssl_version, - ca_certs=self.ca_certs, - do_handshake_on_connect=self.do_handshake_on_connect, - suppress_ragged_eofs=self.suppress_ragged_eofs, - ciphers=self.ciphers), - addr) - - def makefile(self, mode='r', bufsize=-1): - """Make and return a file-like object that - works with the SSL connection. Just use the code - from the socket module.""" - self._makefile_refs += 1 - # close=True so as to decrement the reference count when done with - # the file-like object. - return _fileobject(self, mode, bufsize, close=True) + def recv_into(self, buffer, nbytes=None, flags=0): + self._checkClosed() + if buffer and (nbytes is None): + nbytes = len(buffer) + elif nbytes is None: + nbytes = 1024 + if self._sslobj: + if flags != 0: + raise ValueError( + "non-zero flags not allowed in calls to recv_into() on %s" % + self.__class__) + while True: + try: + tmp_buffer = self.read(nbytes) + v = len(tmp_buffer) + buffer[:v] = tmp_buffer + return v + except SSLError as x: + if x.args[0] == SSL_ERROR_WANT_READ: + if self.timeout == 0.0: + raise + sys.exc_clear() + self._wait(self._read_event) + continue + else: + raise + else: + return socket.recv_into(self, buffer, nbytes, flags) + + def close(self): + if self._makefile_refs < 1: + self._sslobj = None + socket.close(self) + else: + self._makefile_refs -= 1 + + def connect(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + # Here we assume that the socket is client-side, and not + # connected at the time of the call. We connect it, then wrap it. + if self._sslobj: + raise ValueError("attempt to connect already-connected SSLSocket!") + socket.connect(self, addr) + if self.ciphers is None: + self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, + self.cert_reqs, self.ssl_version, + self.ca_certs) + else: + self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, + self.cert_reqs, self.ssl_version, + self.ca_certs, self.ciphers) + if self.do_handshake_on_connect: + self.do_handshake() + + def accept(self): + """Accepts a new connection from a remote client, and returns + a tuple containing that new connection wrapped with a server-side + SSL channel, and the address of the remote client.""" + newsock, addr = socket.accept(self) + return (SSLSocket(newsock._sock, + keyfile=self.keyfile, + certfile=self.certfile, + server_side=True, + cert_reqs=self.cert_reqs, + ssl_version=self.ssl_version, + ca_certs=self.ca_certs, + do_handshake_on_connect=self.do_handshake_on_connect, + suppress_ragged_eofs=self.suppress_ragged_eofs, + ciphers=self.ciphers), + addr) + + def makefile(self, mode='r', bufsize=-1): + """Make and return a file-like object that + works with the SSL connection. Just use the code + from the socket module.""" + self._makefile_refs += 1 + # close=True so as to decrement the reference count when done with + # the file-like object. + return _fileobject(self, mode, bufsize, close=True) _SSLErrorReadTimeout = SSLError('The read operation timed out') @@ -392,8 +697,14 @@ def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): return DER_cert_to_PEM_cert(dercert) -def sslwrap_simple(sock, keyfile=None, certfile=None): - """A replacement for the old socket.ssl function. Designed - for compability with Python 2.5 and earlier. Will disappear in - Python 3.0.""" - return SSLSocket(sock, keyfile, certfile) +if PY3: + __implements__.remove('sslwrap_simple') +else: + def sslwrap_simple(sock, keyfile=None, certfile=None): + """A replacement for the old socket.ssl function. Designed + for compability with Python 2.5 and earlier. Will disappear in + Python 3.0.""" + return SSLSocket(sock, keyfile, certfile) + + +__all__ = __implements__ + __imports__