diff options
Diffstat (limited to 'urllib3/util.py')
-rw-r--r-- | urllib3/util.py | 142 |
1 files changed, 134 insertions, 8 deletions
diff --git a/urllib3/util.py b/urllib3/util.py index 8ec990b..544f9ed 100644 --- a/urllib3/util.py +++ b/urllib3/util.py @@ -1,5 +1,5 @@ # urllib3/util.py -# Copyright 2008-2012 Andrey Petrov and contributors (see CONTRIBUTORS.txt) +# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt) # # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -8,18 +8,32 @@ from base64 import b64encode from collections import namedtuple from socket import error as SocketError +from hashlib import md5, sha1 +from binascii import hexlify, unhexlify try: from select import poll, POLLIN -except ImportError: # `poll` doesn't exist on OSX and other platforms +except ImportError: # `poll` doesn't exist on OSX and other platforms poll = False try: from select import select - except ImportError: # `select` doesn't exist on AppEngine. + except ImportError: # `select` doesn't exist on AppEngine. select = False +try: # Test for SSL features + SSLContext = None + HAS_SNI = False + + import ssl + from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23 + from ssl import SSLContext # Modern SSL? + from ssl import HAS_SNI # Has SNI? +except ImportError: + pass + + from .packages import six -from .exceptions import LocationParseError +from .exceptions import LocationParseError, SSLError class Url(namedtuple('Url', ['scheme', 'auth', 'host', 'port', 'path', 'query', 'fragment'])): @@ -92,9 +106,9 @@ def parse_url(url): >>> parse_url('http://google.com/mail/') Url(scheme='http', host='google.com', port=None, path='/', ...) - >>> prase_url('google.com:80') + >>> parse_url('google.com:80') Url(scheme=None, host='google.com', port=80, path=None, ...) - >>> prase_url('/foo?bar') + >>> parse_url('/foo?bar') Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...) """ @@ -220,7 +234,7 @@ def make_headers(keep_alive=None, accept_encoding=None, user_agent=None, return headers -def is_connection_dropped(conn): +def is_connection_dropped(conn): # Platform-specific """ Returns True if the connection is dropped and should be closed. @@ -234,7 +248,7 @@ def is_connection_dropped(conn): if not sock: # Platform-specific: AppEngine return False - if not poll: # Platform-specific + if not poll: if not select: # Platform-specific: AppEngine return False @@ -250,3 +264,115 @@ def is_connection_dropped(conn): if fno == sock.fileno(): # Either data is buffered (bad), or the connection is dropped. return True + + +def resolve_cert_reqs(candidate): + """ + Resolves the argument to a numeric constant, which can be passed to + the wrap_socket function/method from the ssl module. + Defaults to :data:`ssl.CERT_NONE`. + If given a string it is assumed to be the name of the constant in the + :mod:`ssl` module or its abbrevation. + (So you can specify `REQUIRED` instead of `CERT_REQUIRED`. + If it's neither `None` nor a string we assume it is already the numeric + constant which can directly be passed to wrap_socket. + """ + if candidate is None: + return CERT_NONE + + if isinstance(candidate, str): + res = getattr(ssl, candidate, None) + if res is None: + res = getattr(ssl, 'CERT_' + candidate) + return res + + return candidate + + +def resolve_ssl_version(candidate): + """ + like resolve_cert_reqs + """ + if candidate is None: + return PROTOCOL_SSLv23 + + if isinstance(candidate, str): + res = getattr(ssl, candidate, None) + if res is None: + res = getattr(ssl, 'PROTOCOL_' + candidate) + return res + + return candidate + + +def assert_fingerprint(cert, fingerprint): + """ + Checks if given fingerprint matches the supplied certificate. + + :param cert: + Certificate as bytes object. + :param fingerprint: + Fingerprint as string of hexdigits, can be interspersed by colons. + """ + + # Maps the length of a digest to a possible hash function producing + # this digest. + hashfunc_map = { + 16: md5, + 20: sha1 + } + + fingerprint = fingerprint.replace(':', '').lower() + + digest_length, rest = divmod(len(fingerprint), 2) + + if rest or digest_length not in hashfunc_map: + raise SSLError('Fingerprint is of invalid length.') + + # We need encode() here for py32; works on py2 and p33. + fingerprint_bytes = unhexlify(fingerprint.encode()) + + hashfunc = hashfunc_map[digest_length] + + cert_digest = hashfunc(cert).digest() + + if not cert_digest == fingerprint_bytes: + raise SSLError('Fingerprints did not match. Expected "{0}", got "{1}".' + .format(hexlify(fingerprint_bytes), + hexlify(cert_digest))) + + +if SSLContext is not None: # Python 3.2+ + def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None, + ca_certs=None, server_hostname=None, + ssl_version=None): + """ + All arguments except `server_hostname` have the same meaning as for + :func:`ssl.wrap_socket` + + :param server_hostname: + Hostname of the expected certificate + """ + context = SSLContext(ssl_version) + context.verify_mode = cert_reqs + if ca_certs: + try: + context.load_verify_locations(ca_certs) + # Py32 raises IOError + # Py33 raises FileNotFoundError + except Exception as e: # Reraise as SSLError + raise SSLError(e) + if certfile: + # FIXME: This block needs a test. + context.load_cert_chain(certfile, keyfile) + if HAS_SNI: # Platform-specific: OpenSSL with enabled SNI + return context.wrap_socket(sock, server_hostname=server_hostname) + return context.wrap_socket(sock) + +else: # Python 3.1 and earlier + def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None, + ca_certs=None, server_hostname=None, + ssl_version=None): + return wrap_socket(sock, keyfile=keyfile, certfile=certfile, + ca_certs=ca_certs, cert_reqs=cert_reqs, + ssl_version=ssl_version) |