aboutsummaryrefslogtreecommitdiff
path: root/urllib3/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'urllib3/util.py')
-rw-r--r--urllib3/util.py142
1 files changed, 134 insertions, 8 deletions
diff --git a/urllib3/util.py b/urllib3/util.py
index 8ec990b..544f9ed 100644
--- a/urllib3/util.py
+++ b/urllib3/util.py
@@ -1,5 +1,5 @@
# urllib3/util.py
-# Copyright 2008-2012 Andrey Petrov and contributors (see CONTRIBUTORS.txt)
+# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt)
#
# This module is part of urllib3 and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -8,18 +8,32 @@
from base64 import b64encode
from collections import namedtuple
from socket import error as SocketError
+from hashlib import md5, sha1
+from binascii import hexlify, unhexlify
try:
from select import poll, POLLIN
-except ImportError: # `poll` doesn't exist on OSX and other platforms
+except ImportError: # `poll` doesn't exist on OSX and other platforms
poll = False
try:
from select import select
- except ImportError: # `select` doesn't exist on AppEngine.
+ except ImportError: # `select` doesn't exist on AppEngine.
select = False
+try: # Test for SSL features
+ SSLContext = None
+ HAS_SNI = False
+
+ import ssl
+ from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23
+ from ssl import SSLContext # Modern SSL?
+ from ssl import HAS_SNI # Has SNI?
+except ImportError:
+ pass
+
+
from .packages import six
-from .exceptions import LocationParseError
+from .exceptions import LocationParseError, SSLError
class Url(namedtuple('Url', ['scheme', 'auth', 'host', 'port', 'path', 'query', 'fragment'])):
@@ -92,9 +106,9 @@ def parse_url(url):
>>> parse_url('http://google.com/mail/')
Url(scheme='http', host='google.com', port=None, path='/', ...)
- >>> prase_url('google.com:80')
+ >>> parse_url('google.com:80')
Url(scheme=None, host='google.com', port=80, path=None, ...)
- >>> prase_url('/foo?bar')
+ >>> parse_url('/foo?bar')
Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...)
"""
@@ -220,7 +234,7 @@ def make_headers(keep_alive=None, accept_encoding=None, user_agent=None,
return headers
-def is_connection_dropped(conn):
+def is_connection_dropped(conn): # Platform-specific
"""
Returns True if the connection is dropped and should be closed.
@@ -234,7 +248,7 @@ def is_connection_dropped(conn):
if not sock: # Platform-specific: AppEngine
return False
- if not poll: # Platform-specific
+ if not poll:
if not select: # Platform-specific: AppEngine
return False
@@ -250,3 +264,115 @@ def is_connection_dropped(conn):
if fno == sock.fileno():
# Either data is buffered (bad), or the connection is dropped.
return True
+
+
+def resolve_cert_reqs(candidate):
+ """
+ Resolves the argument to a numeric constant, which can be passed to
+ the wrap_socket function/method from the ssl module.
+ Defaults to :data:`ssl.CERT_NONE`.
+ If given a string it is assumed to be the name of the constant in the
+ :mod:`ssl` module or its abbrevation.
+ (So you can specify `REQUIRED` instead of `CERT_REQUIRED`.
+ If it's neither `None` nor a string we assume it is already the numeric
+ constant which can directly be passed to wrap_socket.
+ """
+ if candidate is None:
+ return CERT_NONE
+
+ if isinstance(candidate, str):
+ res = getattr(ssl, candidate, None)
+ if res is None:
+ res = getattr(ssl, 'CERT_' + candidate)
+ return res
+
+ return candidate
+
+
+def resolve_ssl_version(candidate):
+ """
+ like resolve_cert_reqs
+ """
+ if candidate is None:
+ return PROTOCOL_SSLv23
+
+ if isinstance(candidate, str):
+ res = getattr(ssl, candidate, None)
+ if res is None:
+ res = getattr(ssl, 'PROTOCOL_' + candidate)
+ return res
+
+ return candidate
+
+
+def assert_fingerprint(cert, fingerprint):
+ """
+ Checks if given fingerprint matches the supplied certificate.
+
+ :param cert:
+ Certificate as bytes object.
+ :param fingerprint:
+ Fingerprint as string of hexdigits, can be interspersed by colons.
+ """
+
+ # Maps the length of a digest to a possible hash function producing
+ # this digest.
+ hashfunc_map = {
+ 16: md5,
+ 20: sha1
+ }
+
+ fingerprint = fingerprint.replace(':', '').lower()
+
+ digest_length, rest = divmod(len(fingerprint), 2)
+
+ if rest or digest_length not in hashfunc_map:
+ raise SSLError('Fingerprint is of invalid length.')
+
+ # We need encode() here for py32; works on py2 and p33.
+ fingerprint_bytes = unhexlify(fingerprint.encode())
+
+ hashfunc = hashfunc_map[digest_length]
+
+ cert_digest = hashfunc(cert).digest()
+
+ if not cert_digest == fingerprint_bytes:
+ raise SSLError('Fingerprints did not match. Expected "{0}", got "{1}".'
+ .format(hexlify(fingerprint_bytes),
+ hexlify(cert_digest)))
+
+
+if SSLContext is not None: # Python 3.2+
+ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
+ ca_certs=None, server_hostname=None,
+ ssl_version=None):
+ """
+ All arguments except `server_hostname` have the same meaning as for
+ :func:`ssl.wrap_socket`
+
+ :param server_hostname:
+ Hostname of the expected certificate
+ """
+ context = SSLContext(ssl_version)
+ context.verify_mode = cert_reqs
+ if ca_certs:
+ try:
+ context.load_verify_locations(ca_certs)
+ # Py32 raises IOError
+ # Py33 raises FileNotFoundError
+ except Exception as e: # Reraise as SSLError
+ raise SSLError(e)
+ if certfile:
+ # FIXME: This block needs a test.
+ context.load_cert_chain(certfile, keyfile)
+ if HAS_SNI: # Platform-specific: OpenSSL with enabled SNI
+ return context.wrap_socket(sock, server_hostname=server_hostname)
+ return context.wrap_socket(sock)
+
+else: # Python 3.1 and earlier
+ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
+ ca_certs=None, server_hostname=None,
+ ssl_version=None):
+ return wrap_socket(sock, keyfile=keyfile, certfile=certfile,
+ ca_certs=ca_certs, cert_reqs=cert_reqs,
+ ssl_version=ssl_version)