aboutsummaryrefslogtreecommitdiff
path: root/urllib3/contrib/pyopenssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'urllib3/contrib/pyopenssl.py')
-rw-r--r--urllib3/contrib/pyopenssl.py193
1 files changed, 185 insertions, 8 deletions
diff --git a/urllib3/contrib/pyopenssl.py b/urllib3/contrib/pyopenssl.py
index 5c4c6d8..d43bcd6 100644
--- a/urllib3/contrib/pyopenssl.py
+++ b/urllib3/contrib/pyopenssl.py
@@ -20,13 +20,13 @@ Now you can use :mod:`urllib3` as you normally would, and it will support SNI
when the required modules are installed.
'''
-from ndg.httpsclient.ssl_peer_verification import (ServerSSLCertVerification,
- SUBJ_ALT_NAME_SUPPORT)
+from ndg.httpsclient.ssl_peer_verification import SUBJ_ALT_NAME_SUPPORT
from ndg.httpsclient.subj_alt_name import SubjectAltName
import OpenSSL.SSL
from pyasn1.codec.der import decoder as der_decoder
from socket import _fileobject
import ssl
+from cStringIO import StringIO
from .. import connectionpool
from .. import util
@@ -99,6 +99,172 @@ def get_subj_alt_name(peer_cert):
return dns_name
+class fileobject(_fileobject):
+
+ def read(self, size=-1):
+ # Use max, disallow tiny reads in a loop as they are very inefficient.
+ # We never leave read() with any leftover data from a new recv() call
+ # in our internal buffer.
+ rbufsize = max(self._rbufsize, self.default_bufsize)
+ # Our use of StringIO rather than lists of string objects returned by
+ # recv() minimizes memory usage and fragmentation that occurs when
+ # rbufsize is large compared to the typical return value of recv().
+ buf = self._rbuf
+ buf.seek(0, 2) # seek end
+ if size < 0:
+ # Read until EOF
+ self._rbuf = StringIO() # reset _rbuf. we consume it via buf.
+ while True:
+ try:
+ data = self._sock.recv(rbufsize)
+ except OpenSSL.SSL.WantReadError:
+ continue
+ if not data:
+ break
+ buf.write(data)
+ return buf.getvalue()
+ else:
+ # Read until size bytes or EOF seen, whichever comes first
+ buf_len = buf.tell()
+ if buf_len >= size:
+ # Already have size bytes in our buffer? Extract and return.
+ buf.seek(0)
+ rv = buf.read(size)
+ self._rbuf = StringIO()
+ self._rbuf.write(buf.read())
+ return rv
+
+ self._rbuf = StringIO() # reset _rbuf. we consume it via buf.
+ while True:
+ left = size - buf_len
+ # recv() will malloc the amount of memory given as its
+ # parameter even though it often returns much less data
+ # than that. The returned data string is short lived
+ # as we copy it into a StringIO and free it. This avoids
+ # fragmentation issues on many platforms.
+ try:
+ data = self._sock.recv(left)
+ except OpenSSL.SSL.WantReadError:
+ continue
+ if not data:
+ break
+ n = len(data)
+ if n == size and not buf_len:
+ # Shortcut. Avoid buffer data copies when:
+ # - We have no data in our buffer.
+ # AND
+ # - Our call to recv returned exactly the
+ # number of bytes we were asked to read.
+ return data
+ if n == left:
+ buf.write(data)
+ del data # explicit free
+ break
+ assert n <= left, "recv(%d) returned %d bytes" % (left, n)
+ buf.write(data)
+ buf_len += n
+ del data # explicit free
+ #assert buf_len == buf.tell()
+ return buf.getvalue()
+
+ def readline(self, size=-1):
+ buf = self._rbuf
+ buf.seek(0, 2) # seek end
+ if buf.tell() > 0:
+ # check if we already have it in our buffer
+ buf.seek(0)
+ bline = buf.readline(size)
+ if bline.endswith('\n') or len(bline) == size:
+ self._rbuf = StringIO()
+ self._rbuf.write(buf.read())
+ return bline
+ del bline
+ if size < 0:
+ # Read until \n or EOF, whichever comes first
+ if self._rbufsize <= 1:
+ # Speed up unbuffered case
+ buf.seek(0)
+ buffers = [buf.read()]
+ self._rbuf = StringIO() # reset _rbuf. we consume it via buf.
+ data = None
+ recv = self._sock.recv
+ while True:
+ try:
+ while data != "\n":
+ data = recv(1)
+ if not data:
+ break
+ buffers.append(data)
+ except OpenSSL.SSL.WantReadError:
+ continue
+ break
+ return "".join(buffers)
+
+ buf.seek(0, 2) # seek end
+ self._rbuf = StringIO() # reset _rbuf. we consume it via buf.
+ while True:
+ try:
+ data = self._sock.recv(self._rbufsize)
+ except OpenSSL.SSL.WantReadError:
+ continue
+ if not data:
+ break
+ nl = data.find('\n')
+ if nl >= 0:
+ nl += 1
+ buf.write(data[:nl])
+ self._rbuf.write(data[nl:])
+ del data
+ break
+ buf.write(data)
+ return buf.getvalue()
+ else:
+ # Read until size bytes or \n or EOF seen, whichever comes first
+ buf.seek(0, 2) # seek end
+ buf_len = buf.tell()
+ if buf_len >= size:
+ buf.seek(0)
+ rv = buf.read(size)
+ self._rbuf = StringIO()
+ self._rbuf.write(buf.read())
+ return rv
+ self._rbuf = StringIO() # reset _rbuf. we consume it via buf.
+ while True:
+ try:
+ data = self._sock.recv(self._rbufsize)
+ except OpenSSL.SSL.WantReadError:
+ continue
+ if not data:
+ break
+ left = size - buf_len
+ # did we just receive a newline?
+ nl = data.find('\n', 0, left)
+ if nl >= 0:
+ nl += 1
+ # save the excess data to _rbuf
+ self._rbuf.write(data[nl:])
+ if buf_len:
+ buf.write(data[:nl])
+ break
+ else:
+ # Shortcut. Avoid data copy through buf when returning
+ # a substring of our first recv().
+ return data[:nl]
+ n = len(data)
+ if n == size and not buf_len:
+ # Shortcut. Avoid data copy through buf when
+ # returning exactly all of our first recv().
+ return data
+ if n >= left:
+ buf.write(data[:left])
+ self._rbuf.write(data[left:])
+ break
+ buf.write(data)
+ buf_len += n
+ #assert buf_len == buf.tell()
+ return buf.getvalue()
+
+
class WrappedSocket(object):
'''API-compatibility wrapper for Python OpenSSL's Connection-class.'''
@@ -106,8 +272,11 @@ class WrappedSocket(object):
self.connection = connection
self.socket = socket
+ def fileno(self):
+ return self.socket.fileno()
+
def makefile(self, mode, bufsize=-1):
- return _fileobject(self.connection, mode, bufsize)
+ return fileobject(self.connection, mode, bufsize)
def settimeout(self, timeout):
return self.socket.settimeout(timeout)
@@ -115,10 +284,14 @@ class WrappedSocket(object):
def sendall(self, data):
return self.connection.sendall(data)
+ def close(self):
+ return self.connection.shutdown()
+
def getpeercert(self, binary_form=False):
x509 = self.connection.get_peer_certificate()
+
if not x509:
- raise ssl.SSLError('')
+ return x509
if binary_form:
return OpenSSL.crypto.dump_certificate(
@@ -159,9 +332,13 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
cnx = OpenSSL.SSL.Connection(ctx, sock)
cnx.set_tlsext_host_name(server_hostname)
cnx.set_connect_state()
- try:
- cnx.do_handshake()
- except OpenSSL.SSL.Error as e:
- raise ssl.SSLError('bad handshake', e)
+ while True:
+ try:
+ cnx.do_handshake()
+ except OpenSSL.SSL.WantReadError:
+ continue
+ except OpenSSL.SSL.Error as e:
+ raise ssl.SSLError('bad handshake', e)
+ break
return WrappedSocket(cnx, sock)