diff options
Diffstat (limited to 'urllib3/contrib')
-rw-r--r-- | urllib3/contrib/ntlmpool.py | 2 | ||||
-rw-r--r-- | urllib3/contrib/pyopenssl.py | 193 |
2 files changed, 186 insertions, 9 deletions
diff --git a/urllib3/contrib/ntlmpool.py b/urllib3/contrib/ntlmpool.py index 277ee0b..b8cd933 100644 --- a/urllib3/contrib/ntlmpool.py +++ b/urllib3/contrib/ntlmpool.py @@ -33,7 +33,7 @@ class NTLMConnectionPool(HTTPSConnectionPool): def __init__(self, user, pw, authurl, *args, **kwargs): """ authurl is a random URL on the server that is protected by NTLM. - user is the Windows user, probably in the DOMAIN\username format. + user is the Windows user, probably in the DOMAIN\\username format. pw is the password for the user. """ super(NTLMConnectionPool, self).__init__(*args, **kwargs) diff --git a/urllib3/contrib/pyopenssl.py b/urllib3/contrib/pyopenssl.py index 5c4c6d8..d43bcd6 100644 --- a/urllib3/contrib/pyopenssl.py +++ b/urllib3/contrib/pyopenssl.py @@ -20,13 +20,13 @@ Now you can use :mod:`urllib3` as you normally would, and it will support SNI when the required modules are installed. ''' -from ndg.httpsclient.ssl_peer_verification import (ServerSSLCertVerification, - SUBJ_ALT_NAME_SUPPORT) +from ndg.httpsclient.ssl_peer_verification import SUBJ_ALT_NAME_SUPPORT from ndg.httpsclient.subj_alt_name import SubjectAltName import OpenSSL.SSL from pyasn1.codec.der import decoder as der_decoder from socket import _fileobject import ssl +from cStringIO import StringIO from .. import connectionpool from .. import util @@ -99,6 +99,172 @@ def get_subj_alt_name(peer_cert): return dns_name +class fileobject(_fileobject): + + def read(self, size=-1): + # Use max, disallow tiny reads in a loop as they are very inefficient. + # We never leave read() with any leftover data from a new recv() call + # in our internal buffer. + rbufsize = max(self._rbufsize, self.default_bufsize) + # Our use of StringIO rather than lists of string objects returned by + # recv() minimizes memory usage and fragmentation that occurs when + # rbufsize is large compared to the typical return value of recv(). + buf = self._rbuf + buf.seek(0, 2) # seek end + if size < 0: + # Read until EOF + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + while True: + try: + data = self._sock.recv(rbufsize) + except OpenSSL.SSL.WantReadError: + continue + if not data: + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = buf.tell() + if buf_len >= size: + # Already have size bytes in our buffer? Extract and return. + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO() + self._rbuf.write(buf.read()) + return rv + + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + while True: + left = size - buf_len + # recv() will malloc the amount of memory given as its + # parameter even though it often returns much less data + # than that. The returned data string is short lived + # as we copy it into a StringIO and free it. This avoids + # fragmentation issues on many platforms. + try: + data = self._sock.recv(left) + except OpenSSL.SSL.WantReadError: + continue + if not data: + break + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid buffer data copies when: + # - We have no data in our buffer. + # AND + # - Our call to recv returned exactly the + # number of bytes we were asked to read. + return data + if n == left: + buf.write(data) + del data # explicit free + break + assert n <= left, "recv(%d) returned %d bytes" % (left, n) + buf.write(data) + buf_len += n + del data # explicit free + #assert buf_len == buf.tell() + return buf.getvalue() + + def readline(self, size=-1): + buf = self._rbuf + buf.seek(0, 2) # seek end + if buf.tell() > 0: + # check if we already have it in our buffer + buf.seek(0) + bline = buf.readline(size) + if bline.endswith('\n') or len(bline) == size: + self._rbuf = StringIO() + self._rbuf.write(buf.read()) + return bline + del bline + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + buf.seek(0) + buffers = [buf.read()] + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + data = None + recv = self._sock.recv + while True: + try: + while data != "\n": + data = recv(1) + if not data: + break + buffers.append(data) + except OpenSSL.SSL.WantReadError: + continue + break + return "".join(buffers) + + buf.seek(0, 2) # seek end + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + while True: + try: + data = self._sock.recv(self._rbufsize) + except OpenSSL.SSL.WantReadError: + continue + if not data: + break + nl = data.find('\n') + if nl >= 0: + nl += 1 + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) + del data + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or \n or EOF seen, whichever comes first + buf.seek(0, 2) # seek end + buf_len = buf.tell() + if buf_len >= size: + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO() + self._rbuf.write(buf.read()) + return rv + self._rbuf = StringIO() # reset _rbuf. we consume it via buf. + while True: + try: + data = self._sock.recv(self._rbufsize) + except OpenSSL.SSL.WantReadError: + continue + if not data: + break + left = size - buf_len + # did we just receive a newline? + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + # save the excess data to _rbuf + self._rbuf.write(data[nl:]) + if buf_len: + buf.write(data[:nl]) + break + else: + # Shortcut. Avoid data copy through buf when returning + # a substring of our first recv(). + return data[:nl] + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid data copy through buf when + # returning exactly all of our first recv(). + return data + if n >= left: + buf.write(data[:left]) + self._rbuf.write(data[left:]) + break + buf.write(data) + buf_len += n + #assert buf_len == buf.tell() + return buf.getvalue() + + class WrappedSocket(object): '''API-compatibility wrapper for Python OpenSSL's Connection-class.''' @@ -106,8 +272,11 @@ class WrappedSocket(object): self.connection = connection self.socket = socket + def fileno(self): + return self.socket.fileno() + def makefile(self, mode, bufsize=-1): - return _fileobject(self.connection, mode, bufsize) + return fileobject(self.connection, mode, bufsize) def settimeout(self, timeout): return self.socket.settimeout(timeout) @@ -115,10 +284,14 @@ class WrappedSocket(object): def sendall(self, data): return self.connection.sendall(data) + def close(self): + return self.connection.shutdown() + def getpeercert(self, binary_form=False): x509 = self.connection.get_peer_certificate() + if not x509: - raise ssl.SSLError('') + return x509 if binary_form: return OpenSSL.crypto.dump_certificate( @@ -159,9 +332,13 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None, cnx = OpenSSL.SSL.Connection(ctx, sock) cnx.set_tlsext_host_name(server_hostname) cnx.set_connect_state() - try: - cnx.do_handshake() - except OpenSSL.SSL.Error as e: - raise ssl.SSLError('bad handshake', e) + while True: + try: + cnx.do_handshake() + except OpenSSL.SSL.WantReadError: + continue + except OpenSSL.SSL.Error as e: + raise ssl.SSLError('bad handshake', e) + break return WrappedSocket(cnx, sock) |