From 653256249d44c67a0852d57a166948a9dc712ef4 Mon Sep 17 00:00:00 2001 From: SVN-Git Migration Date: Thu, 8 Oct 2015 13:41:28 -0700 Subject: Imported Upstream version 1.2.3 --- requests/__init__.py | 10 +- requests/adapters.py | 109 +++++++++++++--- requests/auth.py | 6 +- requests/cookies.py | 4 + requests/models.py | 47 ++++--- requests/packages/urllib3/__init__.py | 2 +- requests/packages/urllib3/_collections.py | 2 +- requests/packages/urllib3/connectionpool.py | 46 ++++--- requests/packages/urllib3/contrib/__init__.py | 0 requests/packages/urllib3/contrib/ntlmpool.py | 120 ++++++++++++++++++ requests/packages/urllib3/contrib/pyopenssl.py | 167 +++++++++++++++++++++++++ requests/packages/urllib3/exceptions.py | 28 +++-- requests/packages/urllib3/filepost.py | 2 +- requests/packages/urllib3/poolmanager.py | 14 ++- requests/packages/urllib3/request.py | 2 +- requests/packages/urllib3/response.py | 79 ++++++++---- requests/packages/urllib3/util.py | 50 +++++++- requests/sessions.py | 126 +++++++++---------- requests/status_codes.py | 1 + requests/structures.py | 89 +++++++++---- requests/utils.py | 51 +------- 21 files changed, 717 insertions(+), 238 deletions(-) create mode 100644 requests/packages/urllib3/contrib/__init__.py create mode 100644 requests/packages/urllib3/contrib/ntlmpool.py create mode 100644 requests/packages/urllib3/contrib/pyopenssl.py (limited to 'requests') diff --git a/requests/__init__.py b/requests/__init__.py index 1ea4aff..1af8d8e 100644 --- a/requests/__init__.py +++ b/requests/__init__.py @@ -42,12 +42,18 @@ is at . """ __title__ = 'requests' -__version__ = '1.2.0' -__build__ = 0x010200 +__version__ = '1.2.3' +__build__ = 0x010203 __author__ = 'Kenneth Reitz' __license__ = 'Apache 2.0' __copyright__ = 'Copyright 2013 Kenneth Reitz' +# Attempt to enable urllib3's SNI support, if possible +try: + from requests.packages.urllib3.contrib import pyopenssl + pyopenssl.inject_into_urllib3() +except ImportError: + pass from . import utils from .models import Request, Response, PreparedRequest diff --git a/requests/adapters.py b/requests/adapters.py index 5666e66..98b7317 100644 --- a/requests/adapters.py +++ b/requests/adapters.py @@ -25,6 +25,7 @@ from .cookies import extract_cookies_to_jar from .exceptions import ConnectionError, Timeout, SSLError from .auth import _basic_auth_str +DEFAULT_POOLBLOCK = False DEFAULT_POOLSIZE = 10 DEFAULT_RETRIES = 0 @@ -43,19 +44,41 @@ class BaseAdapter(object): class HTTPAdapter(BaseAdapter): - """Built-In HTTP Adapter for Urllib3.""" - __attrs__ = ['max_retries', 'config', '_pool_connections', '_pool_maxsize'] - - def __init__(self, pool_connections=DEFAULT_POOLSIZE, pool_maxsize=DEFAULT_POOLSIZE): - self.max_retries = DEFAULT_RETRIES + """The built-in HTTP Adapter for urllib3. + + Provides a general-case interface for Requests sessions to contact HTTP and + HTTPS urls by implementing the Transport Adapter interface. This class will + usually be created by the :class:`Session ` class under the + covers. + + :param pool_connections: The number of urllib3 connection pools to cache. + :param pool_maxsize: The maximum number of connections to save in the pool. + :param max_retries: The maximum number of retries each connection should attempt. + :param pool_block: Whether the connection pool should block for connections. + + Usage:: + + >>> import requests + >>> s = requests.Session() + >>> a = requests.adapters.HTTPAdapter() + >>> s.mount('http://', a) + """ + __attrs__ = ['max_retries', 'config', '_pool_connections', '_pool_maxsize', + '_pool_block'] + + def __init__(self, pool_connections=DEFAULT_POOLSIZE, + pool_maxsize=DEFAULT_POOLSIZE, max_retries=DEFAULT_RETRIES, + pool_block=DEFAULT_POOLBLOCK): + self.max_retries = max_retries self.config = {} super(HTTPAdapter, self).__init__() self._pool_connections = pool_connections self._pool_maxsize = pool_maxsize + self._pool_block = pool_block - self.init_poolmanager(pool_connections, pool_maxsize) + self.init_poolmanager(pool_connections, pool_maxsize, block=pool_block) def __getstate__(self): return dict((attr, getattr(self, attr, None)) for attr in @@ -65,16 +88,36 @@ class HTTPAdapter(BaseAdapter): for attr, value in state.items(): setattr(self, attr, value) - self.init_poolmanager(self._pool_connections, self._pool_maxsize) + self.init_poolmanager(self._pool_connections, self._pool_maxsize, + block=self._pool_block) + + def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK): + """Initializes a urllib3 PoolManager. This method should not be called + from user code, and is only exposed for use when subclassing the + :class:`HTTPAdapter `. - def init_poolmanager(self, connections, maxsize): + :param connections: The number of urllib3 connection pools to cache. + :param maxsize: The maximum number of connections to save in the pool. + :param block: Block when no free connections are available. + """ # save these values for pickling self._pool_connections = connections self._pool_maxsize = maxsize + self._pool_block = block - self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize) + self.poolmanager = PoolManager(num_pools=connections, maxsize=maxsize, + block=block) def cert_verify(self, conn, url, verify, cert): + """Verify a SSL certificate. This method should not be called from user + code, and is only exposed for use when subclassing the + :class:`HTTPAdapter `. + + :param conn: The urllib3 connection object associated with the cert. + :param url: The requested URL. + :param verify: Whether we should actually verify the certificate. + :param cert: The SSL certificate to verify. + """ if url.startswith('https') and verify: cert_loc = None @@ -103,6 +146,14 @@ class HTTPAdapter(BaseAdapter): conn.cert_file = cert def build_response(self, req, resp): + """Builds a :class:`Response ` object from a urllib3 + response. This should not be called from user code, and is only exposed + for use when subclassing the + :class:`HTTPAdapter ` + + :param req: The :class:`PreparedRequest ` used to generate the response. + :param resp: The urllib3 response object. + """ response = Response() # Fallback to None if there's no status_code, for whatever reason. @@ -131,7 +182,13 @@ class HTTPAdapter(BaseAdapter): return response def get_connection(self, url, proxies=None): - """Returns a connection for the given URL.""" + """Returns a urllib3 connection for the given URL. This should not be + called from user code, and is only exposed for use when subclassing the + :class:`HTTPAdapter `. + + :param url: The URL to connect to. + :param proxies: (optional) A Requests-style dictionary of proxies used on this request. + """ proxies = proxies or {} proxy = proxies.get(urlparse(url).scheme) @@ -144,7 +201,7 @@ class HTTPAdapter(BaseAdapter): return conn def close(self): - """Dispose of any internal state. + """Disposes of any internal state. Currently, this just closes the PoolManager, which closes pooled connections. @@ -155,7 +212,15 @@ class HTTPAdapter(BaseAdapter): """Obtain the url to use when making the final request. If the message is being sent through a proxy, the full URL has to be - used. Otherwise, we should only use the path portion of the URL.""" + used. Otherwise, we should only use the path portion of the URL. + + This shoudl not be called from user code, and is only exposed for use + when subclassing the + :class:`HTTPAdapter `. + + :param request: The :class:`PreparedRequest ` being sent. + :param proxies: A dictionary of schemes to proxy URLs. + """ proxies = proxies or {} proxy = proxies.get(urlparse(request.url).scheme) @@ -168,7 +233,15 @@ class HTTPAdapter(BaseAdapter): def add_headers(self, request, **kwargs): """Add any headers needed by the connection. Currently this adds a - Proxy-Authorization header.""" + Proxy-Authorization header. + + This should not be called from user code, and is only exposed for use + when subclassing the + :class:`HTTPAdapter `. + + :param request: The :class:`PreparedRequest ` to add headers to. + :param kwargs: The keyword arguments from the call to send(). + """ proxies = kwargs.get('proxies', {}) if proxies is None: @@ -186,7 +259,15 @@ class HTTPAdapter(BaseAdapter): password) def send(self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None): - """Sends PreparedRequest object. Returns Response object.""" + """Sends PreparedRequest object. Returns Response object. + + :param request: The :class:`PreparedRequest ` being sent. + :param stream: (optional) Whether to stream the request content. + :param timeout: (optional) The timeout on the request. + :param verify: (optional) Whether to verify SSL certificates. + :param vert: (optional) Any user-provided SSL certificate to be trusted. + :param proxies: (optional) The proxies dictionary to apply to the request. + """ conn = self.get_connection(request.url, proxies) diff --git a/requests/auth.py b/requests/auth.py index 805f240..fab05cf 100644 --- a/requests/auth.py +++ b/requests/auth.py @@ -8,6 +8,7 @@ This module contains the authentication handlers for Requests. """ import os +import re import time import hashlib import logging @@ -49,7 +50,7 @@ class HTTPBasicAuth(AuthBase): class HTTPProxyAuth(HTTPBasicAuth): - """Attaches HTTP Proxy Authenetication to a given Request object.""" + """Attaches HTTP Proxy Authentication to a given Request object.""" def __call__(self, r): r.headers['Proxy-Authorization'] = _basic_auth_str(self.username, self.password) return r @@ -151,7 +152,8 @@ class HTTPDigestAuth(AuthBase): if 'digest' in s_auth.lower() and num_401_calls < 2: setattr(self, 'num_401_calls', num_401_calls + 1) - self.chal = parse_dict_header(s_auth.replace('Digest ', '')) + pat = re.compile(r'digest ', flags=re.IGNORECASE) + self.chal = parse_dict_header(pat.sub('', s_auth, count=1)) # Consume content and release the original connection # to allow our new request to reuse the same one. diff --git a/requests/cookies.py b/requests/cookies.py index 1235711..d759d0a 100644 --- a/requests/cookies.py +++ b/requests/cookies.py @@ -69,6 +69,10 @@ class MockRequest(object): def unverifiable(self): return self.is_unverifiable() + @property + def origin_req_host(self): + return self.get_origin_req_host() + class MockResponse(object): """Wraps a `httplib.HTTPMessage` to mimic a `urllib.addinfourl`. diff --git a/requests/models.py b/requests/models.py index 6ed2b59..6cf2aaa 100644 --- a/requests/models.py +++ b/requests/models.py @@ -18,9 +18,10 @@ from .structures import CaseInsensitiveDict from .auth import HTTPBasicAuth from .cookies import cookiejar_from_dict, get_cookie_header from .packages.urllib3.filepost import encode_multipart_formdata +from .packages.urllib3.util import parse_url from .exceptions import HTTPError, RequestException, MissingSchema, InvalidURL from .utils import ( - stream_untransfer, guess_filename, get_auth_from_url, requote_uri, + guess_filename, get_auth_from_url, requote_uri, stream_decode_response_unicode, to_key_val_list, parse_header_links, iter_slices, guess_json_utf, super_len) from .compat import ( @@ -60,7 +61,7 @@ class RequestEncodingMixin(object): """Encode parameters in a piece of data. Will successfully encode parameters when passed as a dict or a list of - 2-tuples. Order is retained if data is a list of 2-tuples but abritrary + 2-tuples. Order is retained if data is a list of 2-tuples but arbitrary if parameters are supplied as a dict. """ @@ -99,11 +100,13 @@ class RequestEncodingMixin(object): files = to_key_val_list(files or {}) for field, val in fields: - if isinstance(val, list): - for v in val: - new_fields.append((field, builtin_str(v))) - else: - new_fields.append((field, builtin_str(val))) + if isinstance(val, basestring) or not hasattr(val, '__iter__'): + val = [val] + for v in val: + if v is not None: + new_fields.append( + (field.decode('utf-8') if isinstance(field, bytes) else field, + v.encode('utf-8') if isinstance(v, str) else v)) for (k, v) in files: # support for explicit filename @@ -282,16 +285,28 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin): pass # Support for unicode domain names and paths. - scheme, netloc, path, _params, query, fragment = urlparse(url) + scheme, auth, host, port, path, query, fragment = parse_url(url) - if not (scheme and netloc): + if not scheme: raise MissingSchema("Invalid URL %r: No schema supplied" % url) + if not host: + raise InvalidURL("Invalid URL %r: No host supplied" % url) + + # Only want to apply IDNA to the hostname try: - netloc = netloc.encode('idna').decode('utf-8') + host = host.encode('idna').decode('utf-8') except UnicodeError: raise InvalidURL('URL has an invalid label.') + # Carefully reconstruct the network location + netloc = auth or '' + if netloc: + netloc += '@' + netloc += host + if port: + netloc += ':' + str(port) + # Bare domains aren't valid URLs. if not path: path = '/' @@ -303,8 +318,6 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin): netloc = netloc.encode('utf-8') if isinstance(path, str): path = path.encode('utf-8') - if isinstance(_params, str): - _params = _params.encode('utf-8') if isinstance(query, str): query = query.encode('utf-8') if isinstance(fragment, str): @@ -317,7 +330,7 @@ class PreparedRequest(RequestEncodingMixin, RequestHooksMixin): else: query = enc_params - url = requote_uri(urlunparse([scheme, netloc, path, _params, query, fragment])) + url = requote_uri(urlunparse([scheme, netloc, path, None, query, fragment])) self.url = url def prepare_headers(self, headers): @@ -525,13 +538,13 @@ class Response(object): def generate(): while 1: - chunk = self.raw.read(chunk_size) + chunk = self.raw.read(chunk_size, decode_content=True) if not chunk: break yield chunk self._content_consumed = True - gen = stream_untransfer(generate(), self) + gen = generate() if decode_unicode: gen = stream_decode_response_unicode(gen, self) @@ -575,7 +588,7 @@ class Response(object): raise RuntimeError( 'The content for this response was already consumed') - if self.status_code is 0: + if self.status_code == 0: self._content = None else: self._content = bytes().join(self.iter_content(CONTENT_CHUNK_SIZE)) or bytes() @@ -641,7 +654,7 @@ class Response(object): def links(self): """Returns the parsed header links of the response, if any.""" - header = self.headers['link'] + header = self.headers.get('link') # l = MultiDict() l = {} diff --git a/requests/packages/urllib3/__init__.py b/requests/packages/urllib3/__init__.py index 55de87e..bff80b8 100644 --- a/requests/packages/urllib3/__init__.py +++ b/requests/packages/urllib3/__init__.py @@ -1,5 +1,5 @@ # urllib3/__init__.py -# Copyright 2008-2012 Andrey Petrov and contributors (see CONTRIBUTORS.txt) +# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt) # # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php diff --git a/requests/packages/urllib3/_collections.py b/requests/packages/urllib3/_collections.py index a052b1d..b35a736 100644 --- a/requests/packages/urllib3/_collections.py +++ b/requests/packages/urllib3/_collections.py @@ -1,5 +1,5 @@ # urllib3/_collections.py -# Copyright 2008-2012 Andrey Petrov and contributors (see CONTRIBUTORS.txt) +# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt) # # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php diff --git a/requests/packages/urllib3/connectionpool.py b/requests/packages/urllib3/connectionpool.py index 51c87f5..f3e9260 100644 --- a/requests/packages/urllib3/connectionpool.py +++ b/requests/packages/urllib3/connectionpool.py @@ -1,5 +1,5 @@ # urllib3/connectionpool.py -# Copyright 2008-2012 Andrey Petrov and contributors (see CONTRIBUTORS.txt) +# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt) # # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -9,7 +9,7 @@ import socket import errno from socket import error as SocketError, timeout as SocketTimeout -from .util import resolve_cert_reqs, resolve_ssl_version +from .util import resolve_cert_reqs, resolve_ssl_version, assert_fingerprint try: # Python 3 from http.client import HTTPConnection, HTTPException @@ -81,12 +81,15 @@ class VerifiedHTTPSConnection(HTTPSConnection): ssl_version = None def set_cert(self, key_file=None, cert_file=None, - cert_reqs=None, ca_certs=None): + cert_reqs=None, ca_certs=None, + assert_hostname=None, assert_fingerprint=None): self.key_file = key_file self.cert_file = cert_file self.cert_reqs = cert_reqs self.ca_certs = ca_certs + self.assert_hostname = assert_hostname + self.assert_fingerprint = assert_fingerprint def connect(self): # Add certificate verification @@ -104,8 +107,12 @@ class VerifiedHTTPSConnection(HTTPSConnection): ssl_version=resolved_ssl_version) if resolved_cert_reqs != ssl.CERT_NONE: - match_hostname(self.sock.getpeercert(), self.host) - + if self.assert_fingerprint: + assert_fingerprint(self.sock.getpeercert(binary_form=True), + self.assert_fingerprint) + else: + match_hostname(self.sock.getpeercert(), + self.assert_hostname or self.host) ## Pool objects @@ -439,12 +446,14 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods): except Empty as e: # Timed out by queue - raise TimeoutError(self, "Request timed out. (pool_timeout=%s)" % + raise TimeoutError(self, url, + "Request timed out. (pool_timeout=%s)" % pool_timeout) except SocketTimeout as e: # Timed out by socket - raise TimeoutError(self, "Request timed out. (timeout=%s)" % + raise TimeoutError(self, url, + "Request timed out. (timeout=%s)" % timeout) except BaseSSLError as e: @@ -502,9 +511,13 @@ class HTTPSConnectionPool(HTTPConnectionPool): :class:`.VerifiedHTTPSConnection` is used, which *can* verify certificates, instead of :class:`httplib.HTTPSConnection`. - The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``, and ``ssl_version`` - are only used if :mod:`ssl` is available and are fed into - :meth:`urllib3.util.ssl_wrap_socket` to upgrade the connection socket into an SSL socket. + :class:`.VerifiedHTTPSConnection` uses one of ``assert_fingerprint``, + ``assert_hostname`` and ``host`` in this order to verify connections. + + The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs`` and + ``ssl_version`` are only used if :mod:`ssl` is available and are fed into + :meth:`urllib3.util.ssl_wrap_socket` to upgrade the connection socket + into an SSL socket. """ scheme = 'https' @@ -512,8 +525,9 @@ class HTTPSConnectionPool(HTTPConnectionPool): def __init__(self, host, port=None, strict=False, timeout=None, maxsize=1, block=False, headers=None, - key_file=None, cert_file=None, - cert_reqs=None, ca_certs=None, ssl_version=None): + key_file=None, cert_file=None, cert_reqs=None, + ca_certs=None, ssl_version=None, + assert_hostname=None, assert_fingerprint=None): HTTPConnectionPool.__init__(self, host, port, strict, timeout, maxsize, @@ -523,6 +537,8 @@ class HTTPSConnectionPool(HTTPConnectionPool): self.cert_reqs = cert_reqs self.ca_certs = ca_certs self.ssl_version = ssl_version + self.assert_hostname = assert_hostname + self.assert_fingerprint = assert_fingerprint def _new_conn(self): """ @@ -532,7 +548,7 @@ class HTTPSConnectionPool(HTTPConnectionPool): log.info("Starting new HTTPS connection (%d): %s" % (self.num_connections, self.host)) - if not ssl: # Platform-specific: Python compiled without +ssl + if not ssl: # Platform-specific: Python compiled without +ssl if not HTTPSConnection or HTTPSConnection is object: raise SSLError("Can't connect to HTTPS URL because the SSL " "module is not available.") @@ -545,7 +561,9 @@ class HTTPSConnectionPool(HTTPConnectionPool): port=self.port, strict=self.strict) connection.set_cert(key_file=self.key_file, cert_file=self.cert_file, - cert_reqs=self.cert_reqs, ca_certs=self.ca_certs) + cert_reqs=self.cert_reqs, ca_certs=self.ca_certs, + assert_hostname=self.assert_hostname, + assert_fingerprint=self.assert_fingerprint) connection.ssl_version = self.ssl_version diff --git a/requests/packages/urllib3/contrib/__init__.py b/requests/packages/urllib3/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/requests/packages/urllib3/contrib/ntlmpool.py b/requests/packages/urllib3/contrib/ntlmpool.py new file mode 100644 index 0000000..277ee0b --- /dev/null +++ b/requests/packages/urllib3/contrib/ntlmpool.py @@ -0,0 +1,120 @@ +# urllib3/contrib/ntlmpool.py +# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt) +# +# This module is part of urllib3 and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +NTLM authenticating pool, contributed by erikcederstran + +Issue #10, see: http://code.google.com/p/urllib3/issues/detail?id=10 +""" + +try: + from http.client import HTTPSConnection +except ImportError: + from httplib import HTTPSConnection +from logging import getLogger +from ntlm import ntlm + +from urllib3 import HTTPSConnectionPool + + +log = getLogger(__name__) + + +class NTLMConnectionPool(HTTPSConnectionPool): + """ + Implements an NTLM authentication version of an urllib3 connection pool + """ + + scheme = 'https' + + def __init__(self, user, pw, authurl, *args, **kwargs): + """ + authurl is a random URL on the server that is protected by NTLM. + user is the Windows user, probably in the DOMAIN\username format. + pw is the password for the user. + """ + super(NTLMConnectionPool, self).__init__(*args, **kwargs) + self.authurl = authurl + self.rawuser = user + user_parts = user.split('\\', 1) + self.domain = user_parts[0].upper() + self.user = user_parts[1] + self.pw = pw + + def _new_conn(self): + # Performs the NTLM handshake that secures the connection. The socket + # must be kept open while requests are performed. + self.num_connections += 1 + log.debug('Starting NTLM HTTPS connection no. %d: https://%s%s' % + (self.num_connections, self.host, self.authurl)) + + headers = {} + headers['Connection'] = 'Keep-Alive' + req_header = 'Authorization' + resp_header = 'www-authenticate' + + conn = HTTPSConnection(host=self.host, port=self.port) + + # Send negotiation message + headers[req_header] = ( + 'NTLM %s' % ntlm.create_NTLM_NEGOTIATE_MESSAGE(self.rawuser)) + log.debug('Request headers: %s' % headers) + conn.request('GET', self.authurl, None, headers) + res = conn.getresponse() + reshdr = dict(res.getheaders()) + log.debug('Response status: %s %s' % (res.status, res.reason)) + log.debug('Response headers: %s' % reshdr) + log.debug('Response data: %s [...]' % res.read(100)) + + # Remove the reference to the socket, so that it can not be closed by + # the response object (we want to keep the socket open) + res.fp = None + + # Server should respond with a challenge message + auth_header_values = reshdr[resp_header].split(', ') + auth_header_value = None + for s in auth_header_values: + if s[:5] == 'NTLM ': + auth_header_value = s[5:] + if auth_header_value is None: + raise Exception('Unexpected %s response header: %s' % + (resp_header, reshdr[resp_header])) + + # Send authentication message + ServerChallenge, NegotiateFlags = \ + ntlm.parse_NTLM_CHALLENGE_MESSAGE(auth_header_value) + auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(ServerChallenge, + self.user, + self.domain, + self.pw, + NegotiateFlags) + headers[req_header] = 'NTLM %s' % auth_msg + log.debug('Request headers: %s' % headers) + conn.request('GET', self.authurl, None, headers) + res = conn.getresponse() + log.debug('Response status: %s %s' % (res.status, res.reason)) + log.debug('Response headers: %s' % dict(res.getheaders())) + log.debug('Response data: %s [...]' % res.read()[:100]) + if res.status != 200: + if res.status == 401: + raise Exception('Server rejected request: wrong ' + 'username or password') + raise Exception('Wrong server response: %s %s' % + (res.status, res.reason)) + + res.fp = None + log.debug('Connection established') + return conn + + def urlopen(self, method, url, body=None, headers=None, retries=3, + redirect=True, assert_same_host=True): + if headers is None: + headers = {} + headers['Connection'] = 'Keep-Alive' + return super(NTLMConnectionPool, self).urlopen(method, url, body, + headers, retries, + redirect, + assert_same_host) diff --git a/requests/packages/urllib3/contrib/pyopenssl.py b/requests/packages/urllib3/contrib/pyopenssl.py new file mode 100644 index 0000000..5c4c6d8 --- /dev/null +++ b/requests/packages/urllib3/contrib/pyopenssl.py @@ -0,0 +1,167 @@ +'''SSL with SNI-support for Python 2. + +This needs the following packages installed: + +* pyOpenSSL (tested with 0.13) +* ndg-httpsclient (tested with 0.3.2) +* pyasn1 (tested with 0.1.6) + +To activate it call :func:`~urllib3.contrib.pyopenssl.inject_into_urllib3`. +This can be done in a ``sitecustomize`` module, or at any other time before +your application begins using ``urllib3``, like this:: + + try: + import urllib3.contrib.pyopenssl + urllib3.contrib.pyopenssl.inject_into_urllib3() + except ImportError: + pass + +Now you can use :mod:`urllib3` as you normally would, and it will support SNI +when the required modules are installed. +''' + +from ndg.httpsclient.ssl_peer_verification import (ServerSSLCertVerification, + SUBJ_ALT_NAME_SUPPORT) +from ndg.httpsclient.subj_alt_name import SubjectAltName +import OpenSSL.SSL +from pyasn1.codec.der import decoder as der_decoder +from socket import _fileobject +import ssl + +from .. import connectionpool +from .. import util + +__all__ = ['inject_into_urllib3', 'extract_from_urllib3'] + +# SNI only *really* works if we can read the subjectAltName of certificates. +HAS_SNI = SUBJ_ALT_NAME_SUPPORT + +# Map from urllib3 to PyOpenSSL compatible parameter-values. +_openssl_versions = { + ssl.PROTOCOL_SSLv23: OpenSSL.SSL.SSLv23_METHOD, + ssl.PROTOCOL_SSLv3: OpenSSL.SSL.SSLv3_METHOD, + ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, +} +_openssl_verify = { + ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, + ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, + ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER + + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, +} + + +orig_util_HAS_SNI = util.HAS_SNI +orig_connectionpool_ssl_wrap_socket = connectionpool.ssl_wrap_socket + + +def inject_into_urllib3(): + 'Monkey-patch urllib3 with PyOpenSSL-backed SSL-support.' + + connectionpool.ssl_wrap_socket = ssl_wrap_socket + util.HAS_SNI = HAS_SNI + + +def extract_from_urllib3(): + 'Undo monkey-patching by :func:`inject_into_urllib3`.' + + connectionpool.ssl_wrap_socket = orig_connectionpool_ssl_wrap_socket + util.HAS_SNI = orig_util_HAS_SNI + + +### Note: This is a slightly bug-fixed version of same from ndg-httpsclient. +def get_subj_alt_name(peer_cert): + # Search through extensions + dns_name = [] + if not SUBJ_ALT_NAME_SUPPORT: + return dns_name + + general_names = SubjectAltName() + for i in range(peer_cert.get_extension_count()): + ext = peer_cert.get_extension(i) + ext_name = ext.get_short_name() + if ext_name != 'subjectAltName': + continue + + # PyOpenSSL returns extension data in ASN.1 encoded form + ext_dat = ext.get_data() + decoded_dat = der_decoder.decode(ext_dat, + asn1Spec=general_names) + + for name in decoded_dat: + if not isinstance(name, SubjectAltName): + continue + for entry in range(len(name)): + component = name.getComponentByPosition(entry) + if component.getName() != 'dNSName': + continue + dns_name.append(str(component.getComponent())) + + return dns_name + + +class WrappedSocket(object): + '''API-compatibility wrapper for Python OpenSSL's Connection-class.''' + + def __init__(self, connection, socket): + self.connection = connection + self.socket = socket + + def makefile(self, mode, bufsize=-1): + return _fileobject(self.connection, mode, bufsize) + + def settimeout(self, timeout): + return self.socket.settimeout(timeout) + + def sendall(self, data): + return self.connection.sendall(data) + + def getpeercert(self, binary_form=False): + x509 = self.connection.get_peer_certificate() + if not x509: + raise ssl.SSLError('') + + if binary_form: + return OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_ASN1, + x509) + + return { + 'subject': ( + (('commonName', x509.get_subject().CN),), + ), + 'subjectAltName': [ + ('DNS', value) + for value in get_subj_alt_name(x509) + ] + } + + +def _verify_callback(cnx, x509, err_no, err_depth, return_code): + return err_no == 0 + + +def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None, + ca_certs=None, server_hostname=None, + ssl_version=None): + ctx = OpenSSL.SSL.Context(_openssl_versions[ssl_version]) + if certfile: + ctx.use_certificate_file(certfile) + if keyfile: + ctx.use_privatekey_file(keyfile) + if cert_reqs != ssl.CERT_NONE: + ctx.set_verify(_openssl_verify[cert_reqs], _verify_callback) + if ca_certs: + try: + ctx.load_verify_locations(ca_certs, None) + except OpenSSL.SSL.Error as e: + raise ssl.SSLError('bad ca_certs: %r' % ca_certs, e) + + cnx = OpenSSL.SSL.Connection(ctx, sock) + cnx.set_tlsext_host_name(server_hostname) + cnx.set_connect_state() + try: + cnx.do_handshake() + except OpenSSL.SSL.Error as e: + raise ssl.SSLError('bad handshake', e) + + return WrappedSocket(cnx, sock) diff --git a/requests/packages/urllib3/exceptions.py b/requests/packages/urllib3/exceptions.py index c5eb962..2e2a259 100644 --- a/requests/packages/urllib3/exceptions.py +++ b/requests/packages/urllib3/exceptions.py @@ -1,5 +1,5 @@ # urllib3/exceptions.py -# Copyright 2008-2012 Andrey Petrov and contributors (see CONTRIBUTORS.txt) +# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt) # # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -20,7 +20,18 @@ class PoolError(HTTPError): def __reduce__(self): # For pickling purposes. - return self.__class__, (None, self.url) + return self.__class__, (None, None) + + +class RequestError(PoolError): + "Base exception for PoolErrors that have associated URLs." + def __init__(self, pool, url, message): + self.url = url + PoolError.__init__(self, pool, message) + + def __reduce__(self): + # For pickling purposes. + return self.__class__, (None, self.url, None) class SSLError(HTTPError): @@ -35,7 +46,7 @@ class DecodeError(HTTPError): ## Leaf Exceptions -class MaxRetryError(PoolError): +class MaxRetryError(RequestError): "Raised when the maximum number of retries is exceeded." def __init__(self, pool, url, reason=None): @@ -47,22 +58,19 @@ class MaxRetryError(PoolError): else: message += " (Caused by redirect)" - PoolError.__init__(self, pool, message) - self.url = url + RequestError.__init__(self, pool, url, message) -class HostChangedError(PoolError): +class HostChangedError(RequestError): "Raised when an existing pool gets a request for a foreign host." def __init__(self, pool, url, retries=3): message = "Tried to open a foreign host with url: %s" % url - PoolError.__init__(self, pool, message) - - self.url = url + RequestError.__init__(self, pool, url, message) self.retries = retries -class TimeoutError(PoolError): +class TimeoutError(RequestError): "Raised when a socket timeout occurs." pass diff --git a/requests/packages/urllib3/filepost.py b/requests/packages/urllib3/filepost.py index 8d900bd..470309a 100644 --- a/requests/packages/urllib3/filepost.py +++ b/requests/packages/urllib3/filepost.py @@ -93,6 +93,6 @@ def encode_multipart_formdata(fields, boundary=None): body.write(b('--%s--\r\n' % (boundary))) - content_type = b('multipart/form-data; boundary=%s' % boundary) + content_type = str('multipart/form-data; boundary=%s' % boundary) return body.getvalue(), content_type diff --git a/requests/packages/urllib3/poolmanager.py b/requests/packages/urllib3/poolmanager.py index 6e7377c..ce0c248 100644 --- a/requests/packages/urllib3/poolmanager.py +++ b/requests/packages/urllib3/poolmanager.py @@ -1,5 +1,5 @@ # urllib3/poolmanager.py -# Copyright 2008-2012 Andrey Petrov and contributors (see CONTRIBUTORS.txt) +# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt) # # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -23,6 +23,9 @@ pool_classes_by_scheme = { log = logging.getLogger(__name__) +SSL_KEYWORDS = ('key_file', 'cert_file', 'cert_reqs', 'ca_certs', + 'ssl_version') + class PoolManager(RequestMethods): """ @@ -67,7 +70,13 @@ class PoolManager(RequestMethods): to be overridden for customization. """ pool_cls = pool_classes_by_scheme[scheme] - return pool_cls(host, port, **self.connection_pool_kw) + kwargs = self.connection_pool_kw + if scheme == 'http': + kwargs = self.connection_pool_kw.copy() + for kw in SSL_KEYWORDS: + kwargs.pop(kw, None) + + return pool_cls(host, port, **kwargs) def clear(self): """ @@ -141,6 +150,7 @@ class PoolManager(RequestMethods): log.info("Redirecting %s -> %s" % (url, redirect_location)) kw['retries'] = kw.get('retries', 3) - 1 # Persist retries countdown + kw['redirect'] = redirect return self.urlopen(method, redirect_location, **kw) diff --git a/requests/packages/urllib3/request.py b/requests/packages/urllib3/request.py index 2b4704e..bf0256e 100644 --- a/requests/packages/urllib3/request.py +++ b/requests/packages/urllib3/request.py @@ -1,5 +1,5 @@ # urllib3/request.py -# Copyright 2008-2012 Andrey Petrov and contributors (see CONTRIBUTORS.txt) +# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt) # # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php diff --git a/requests/packages/urllib3/response.py b/requests/packages/urllib3/response.py index 0761dc0..2fa4078 100644 --- a/requests/packages/urllib3/response.py +++ b/requests/packages/urllib3/response.py @@ -4,29 +4,48 @@ # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import gzip + import logging import zlib -from io import BytesIO - from .exceptions import DecodeError -from .packages.six import string_types as basestring +from .packages.six import string_types as basestring, binary_type log = logging.getLogger(__name__) -def decode_gzip(data): - gzipper = gzip.GzipFile(fileobj=BytesIO(data)) - return gzipper.read() +class DeflateDecoder(object): + + def __init__(self): + self._first_try = True + self._data = binary_type() + self._obj = zlib.decompressobj() + def __getattr__(self, name): + return getattr(self._obj, name) -def decode_deflate(data): - try: - return zlib.decompress(data) - except zlib.error: - return zlib.decompress(data, -zlib.MAX_WBITS) + def decompress(self, data): + if not self._first_try: + return self._obj.decompress(data) + + self._data += data + try: + return self._obj.decompress(data) + except zlib.error: + self._first_try = False + self._obj = zlib.decompressobj(-zlib.MAX_WBITS) + try: + return self.decompress(self._data) + finally: + self._data = None + + +def _get_decoder(mode): + if mode == 'gzip': + return zlib.decompressobj(16 + zlib.MAX_WBITS) + + return DeflateDecoder() class HTTPResponse(object): @@ -52,10 +71,7 @@ class HTTPResponse(object): otherwise unused. """ - CONTENT_DECODERS = { - 'gzip': decode_gzip, - 'deflate': decode_deflate, - } + CONTENT_DECODERS = ['gzip', 'deflate'] def __init__(self, body='', headers=None, status=0, version=0, reason=None, strict=0, preload_content=True, decode_content=True, @@ -65,8 +81,9 @@ class HTTPResponse(object): self.version = version self.reason = reason self.strict = strict + self.decode_content = decode_content - self._decode_content = decode_content + self._decoder = None self._body = body if body and isinstance(body, basestring) else None self._fp = None self._original_response = original_response @@ -115,13 +132,13 @@ class HTTPResponse(object): parameters: ``decode_content`` and ``cache_content``. :param amt: - How much of the content to read. If specified, decoding and caching - is skipped because we can't decode partial content nor does it make - sense to cache partial content as the full response. + How much of the content to read. If specified, caching is skipped + because it doesn't make sense to cache partial content as the full + response. :param decode_content: If True, will attempt to decode the body based on the - 'content-encoding' header. (Overridden if ``amt`` is set.) + 'content-encoding' header. :param cache_content: If True, will save the returned data such that the same result is @@ -133,18 +150,24 @@ class HTTPResponse(object): # Note: content-encoding value should be case-insensitive, per RFC 2616 # Section 3.5 content_encoding = self.headers.get('content-encoding', '').lower() - decoder = self.CONTENT_DECODERS.get(content_encoding) + if self._decoder is None: + if content_encoding in self.CONTENT_DECODERS: + self._decoder = _get_decoder(content_encoding) if decode_content is None: - decode_content = self._decode_content + decode_content = self.decode_content if self._fp is None: return + flush_decoder = False + try: if amt is None: # cStringIO doesn't like amt=None data = self._fp.read() + flush_decoder = True else: + cache_content = False data = self._fp.read(amt) if amt != 0 and not data: # Platform-specific: Buggy versions of Python. # Close the connection when no data is returned @@ -155,15 +178,19 @@ class HTTPResponse(object): # properly close the connection in all cases. There is no harm # in redundantly calling close. self._fp.close() - return data + flush_decoder = True try: - if decode_content and decoder: - data = decoder(data) + if decode_content and self._decoder: + data = self._decoder.decompress(data) except (IOError, zlib.error): raise DecodeError("Received response with content-encoding: %s, but " "failed to decode it." % content_encoding) + if flush_decoder and self._decoder: + buf = self._decoder.decompress(binary_type()) + data += buf + self._decoder.flush() + if cache_content: self._body = data diff --git a/requests/packages/urllib3/util.py b/requests/packages/urllib3/util.py index b827bc4..544f9ed 100644 --- a/requests/packages/urllib3/util.py +++ b/requests/packages/urllib3/util.py @@ -1,5 +1,5 @@ # urllib3/util.py -# Copyright 2008-2012 Andrey Petrov and contributors (see CONTRIBUTORS.txt) +# Copyright 2008-2013 Andrey Petrov and contributors (see CONTRIBUTORS.txt) # # This module is part of urllib3 and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php @@ -8,6 +8,8 @@ from base64 import b64encode from collections import namedtuple from socket import error as SocketError +from hashlib import md5, sha1 +from binascii import hexlify, unhexlify try: from select import poll, POLLIN @@ -23,7 +25,7 @@ try: # Test for SSL features HAS_SNI = False import ssl - from ssl import wrap_socket, CERT_NONE, SSLError, PROTOCOL_SSLv23 + from ssl import wrap_socket, CERT_NONE, PROTOCOL_SSLv23 from ssl import SSLContext # Modern SSL? from ssl import HAS_SNI # Has SNI? except ImportError: @@ -31,7 +33,7 @@ except ImportError: from .packages import six -from .exceptions import LocationParseError +from .exceptions import LocationParseError, SSLError class Url(namedtuple('Url', ['scheme', 'auth', 'host', 'port', 'path', 'query', 'fragment'])): @@ -232,7 +234,7 @@ def make_headers(keep_alive=None, accept_encoding=None, user_agent=None, return headers -def is_connection_dropped(conn): +def is_connection_dropped(conn): # Platform-specific """ Returns True if the connection is dropped and should be closed. @@ -246,7 +248,7 @@ def is_connection_dropped(conn): if not sock: # Platform-specific: AppEngine return False - if not poll: # Platform-specific + if not poll: if not select: # Platform-specific: AppEngine return False @@ -302,6 +304,44 @@ def resolve_ssl_version(candidate): return candidate + +def assert_fingerprint(cert, fingerprint): + """ + Checks if given fingerprint matches the supplied certificate. + + :param cert: + Certificate as bytes object. + :param fingerprint: + Fingerprint as string of hexdigits, can be interspersed by colons. + """ + + # Maps the length of a digest to a possible hash function producing + # this digest. + hashfunc_map = { + 16: md5, + 20: sha1 + } + + fingerprint = fingerprint.replace(':', '').lower() + + digest_length, rest = divmod(len(fingerprint), 2) + + if rest or digest_length not in hashfunc_map: + raise SSLError('Fingerprint is of invalid length.') + + # We need encode() here for py32; works on py2 and p33. + fingerprint_bytes = unhexlify(fingerprint.encode()) + + hashfunc = hashfunc_map[digest_length] + + cert_digest = hashfunc(cert).digest() + + if not cert_digest == fingerprint_bytes: + raise SSLError('Fingerprints did not match. Expected "{0}", got "{1}".' + .format(hexlify(fingerprint_bytes), + hexlify(cert_digest))) + + if SSLContext is not None: # Python 3.2+ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None, ca_certs=None, server_hostname=None, diff --git a/requests/sessions.py b/requests/sessions.py index de0d9d6..f4aeeee 100644 --- a/requests/sessions.py +++ b/requests/sessions.py @@ -9,16 +9,17 @@ requests (cookies, auth, proxies). """ import os +from collections import Mapping from datetime import datetime -from .compat import cookielib -from .cookies import cookiejar_from_dict +from .compat import cookielib, OrderedDict, urljoin, urlparse +from .cookies import cookiejar_from_dict, extract_cookies_to_jar, RequestsCookieJar from .models import Request, PreparedRequest from .hooks import default_hooks, dispatch_hook -from .utils import from_key_val_list, default_headers +from .utils import to_key_val_list, default_headers from .exceptions import TooManyRedirects, InvalidSchema +from .structures import CaseInsensitiveDict -from .compat import urlparse, urljoin from .adapters import HTTPAdapter from .utils import requote_uri, get_environ_proxies, get_netrc_auth @@ -33,49 +34,35 @@ REDIRECT_STATI = ( DEFAULT_REDIRECT_LIMIT = 30 -def merge_kwargs(local_kwarg, default_kwarg): - """Merges kwarg dictionaries. - - If a local key in the dictionary is set to None, it will be removed. +def merge_setting(request_setting, session_setting, dict_class=OrderedDict): + """ + Determines appropriate setting for a given request, taking into account the + explicit setting on that request, and the setting in the session. If a + setting is a dictionary, they will be merged together using `dict_class` """ - if default_kwarg is None: - return local_kwarg - - if isinstance(local_kwarg, str): - return local_kwarg - - if local_kwarg is None: - return default_kwarg - - # Bypass if not a dictionary (e.g. timeout) - if not hasattr(default_kwarg, 'items'): - return local_kwarg + if session_setting is None: + return request_setting - default_kwarg = from_key_val_list(default_kwarg) - local_kwarg = from_key_val_list(local_kwarg) + if request_setting is None: + return session_setting - # Update new values in a case-insensitive way - def get_original_key(original_keys, new_key): - """ - Finds the key from original_keys that case-insensitive matches new_key. - """ - for original_key in original_keys: - if key.lower() == original_key.lower(): - return original_key - return new_key + # Bypass if not a dictionary (e.g. verify) + if not ( + isinstance(session_setting, Mapping) and + isinstance(request_setting, Mapping) + ): + return request_setting - kwargs = default_kwarg.copy() - original_keys = kwargs.keys() - for key, value in local_kwarg.items(): - kwargs[get_original_key(original_keys, key)] = value + merged_setting = dict_class(to_key_val_list(session_setting)) + merged_setting.update(to_key_val_list(request_setting)) # Remove keys that are set to None. - for (k, v) in local_kwarg.items(): + for (k, v) in request_setting.items(): if v is None: - del kwargs[k] + del merged_setting[k] - return kwargs + return merged_setting class SessionRedirectMixin(object): @@ -91,10 +78,6 @@ class SessionRedirectMixin(object): prepared_request.method = req.method prepared_request.url = req.url - cookiejar = cookiejar_from_dict({}) - cookiejar.update(self.cookies) - cookiejar.update(resp.cookies) - # ((resp.status_code is codes.see_other)) while (('location' in resp.headers and resp.status_code in REDIRECT_STATI)): @@ -116,9 +99,11 @@ class SessionRedirectMixin(object): # Facilitate non-RFC2616-compliant 'location' headers # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource') + # Compliant with RFC3986, we percent encode the url. if not urlparse(url).netloc: - # Compliant with RFC3986, we percent encode the url. url = urljoin(resp.url, requote_uri(url)) + else: + url = requote_uri(url) prepared_request.url = url @@ -129,7 +114,7 @@ class SessionRedirectMixin(object): # Do what the browsers do, despite standards... if (resp.status_code in (codes.moved, codes.found) and - prepared_request.method == 'POST'): + prepared_request.method not in ('GET', 'HEAD')): method = 'GET' prepared_request.method = method @@ -147,7 +132,7 @@ class SessionRedirectMixin(object): except KeyError: pass - prepared_request.prepare_cookies(cookiejar) + prepared_request.prepare_cookies(self.cookies) resp = self.send( prepared_request, @@ -159,13 +144,11 @@ class SessionRedirectMixin(object): allow_redirects=False, ) - cookiejar.update(resp.cookies) + extract_cookies_to_jar(self.cookies, prepared_request, resp.raw) i += 1 yield resp - resp.cookies.update(cookiejar) - class Session(SessionRedirectMixin): """A Requests session. @@ -218,7 +201,8 @@ class Session(SessionRedirectMixin): #: SSL certificate default. self.cert = None - #: Maximum number of redirects to follow. + #: Maximum number of redirects allowed. If the request exceeds this + #: limit, a :class:`TooManyRedirects` exception is raised. self.max_redirects = DEFAULT_REDIRECT_LIMIT #: Should we trust the environment? @@ -228,9 +212,9 @@ class Session(SessionRedirectMixin): self.cookies = cookiejar_from_dict({}) # Default connection adapters. - self.adapters = {} - self.mount('http://', HTTPAdapter()) + self.adapters = OrderedDict() self.mount('https://', HTTPAdapter()) + self.mount('http://', HTTPAdapter()) def __enter__(self): return self @@ -274,12 +258,8 @@ class Session(SessionRedirectMixin): :param allow_redirects: (optional) Boolean. Set to True by default. :param proxies: (optional) Dictionary mapping protocol to the URL of the proxy. - :param return_response: (optional) If False, an un-sent Request object - will returned. - :param config: (optional) A configuration dictionary. See - ``request.defaults`` for allowed keys and their default values. - :param prefetch: (optional) whether to immediately download the response - content. Defaults to ``True``. + :param stream: (optional) whether to immediately download the response + content. Defaults to ``False``. :param verify: (optional) if ``True``, the SSL cert will be verified. A CA_BUNDLE path can also be provided. :param cert: (optional) if String, path to ssl client cert file (.pem). @@ -294,7 +274,8 @@ class Session(SessionRedirectMixin): cookies = cookiejar_from_dict(cookies) # Merge with session cookies - merged_cookies = self.cookies.copy() + merged_cookies = RequestsCookieJar() + merged_cookies.update(self.cookies) merged_cookies.update(cookies) cookies = merged_cookies @@ -318,14 +299,14 @@ class Session(SessionRedirectMixin): verify = os.environ.get('CURL_CA_BUNDLE') # Merge all the kwargs. - params = merge_kwargs(params, self.params) - headers = merge_kwargs(headers, self.headers) - auth = merge_kwargs(auth, self.auth) - proxies = merge_kwargs(proxies, self.proxies) - hooks = merge_kwargs(hooks, self.hooks) - stream = merge_kwargs(stream, self.stream) - verify = merge_kwargs(verify, self.verify) - cert = merge_kwargs(cert, self.cert) + params = merge_setting(params, self.params) + headers = merge_setting(headers, self.headers, dict_class=CaseInsensitiveDict) + auth = merge_setting(auth, self.auth) + proxies = merge_setting(proxies, self.proxies) + hooks = merge_setting(hooks, self.hooks) + stream = merge_setting(stream, self.stream) + verify = merge_setting(verify, self.verify) + cert = merge_setting(cert, self.cert) # Create the Request. req = Request() @@ -353,9 +334,6 @@ class Session(SessionRedirectMixin): } resp = self.send(prep, **send_kwargs) - # Persist cookies. - self.cookies.update(resp.cookies) - return resp def get(self, url, **kwargs): @@ -464,6 +442,9 @@ class Session(SessionRedirectMixin): # Response manipulation hooks r = dispatch_hook('response', hooks, r, **kwargs) + # Persist cookies + extract_cookies_to_jar(self.cookies, request, r.raw) + # Redirect resolving generator. gen = self.resolve_redirects(r, request, stream=stream, timeout=timeout, verify=verify, cert=cert, @@ -498,8 +479,13 @@ class Session(SessionRedirectMixin): v.close() def mount(self, prefix, adapter): - """Registers a connection adapter to a prefix.""" + """Registers a connection adapter to a prefix. + + Adapters are sorted in descending order by key length.""" self.adapters[prefix] = adapter + keys_to_move = [k for k in self.adapters if len(k) < len(prefix)] + for key in keys_to_move: + self.adapters[key] = self.adapters.pop(key) def __getstate__(self): return dict((attr, getattr(self, attr, None)) for attr in self.__attrs__) diff --git a/requests/status_codes.py b/requests/status_codes.py index 08edab4..de38486 100644 --- a/requests/status_codes.py +++ b/requests/status_codes.py @@ -62,6 +62,7 @@ _codes = { 444: ('no_response', 'none'), 449: ('retry_with', 'retry'), 450: ('blocked_by_windows_parental_controls', 'parental_controls'), + 451: ('unavailable_for_legal_reasons', 'legal_reasons'), 499: ('client_closed_request',), # Server Error. diff --git a/requests/structures.py b/requests/structures.py index 05f5ac1..8d02ea6 100644 --- a/requests/structures.py +++ b/requests/structures.py @@ -9,6 +9,7 @@ Data structures that power Requests. """ import os +import collections from itertools import islice @@ -33,43 +34,79 @@ class IteratorProxy(object): return "".join(islice(self.i, None, n)) -class CaseInsensitiveDict(dict): - """Case-insensitive Dictionary +class CaseInsensitiveDict(collections.MutableMapping): + """ + A case-insensitive ``dict``-like object. + + Implements all methods and operations of + ``collections.MutableMapping`` as well as dict's ``copy``. Also + provides ``lower_items``. + + All keys are expected to be strings. The structure remembers the + case of the last key to be set, and ``iter(instance)``, + ``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()`` + will contain case-sensitive keys. However, querying and contains + testing is case insensitive: + + cid = CaseInsensitiveDict() + cid['Accept'] = 'application/json' + cid['aCCEPT'] == 'application/json' # True + list(cid) == ['Accept'] # True For example, ``headers['content-encoding']`` will return the - value of a ``'Content-Encoding'`` response header.""" + value of a ``'Content-Encoding'`` response header, regardless + of how the header name was originally stored. - @property - def lower_keys(self): - if not hasattr(self, '_lower_keys') or not self._lower_keys: - self._lower_keys = dict((k.lower(), k) for k in list(self.keys())) - return self._lower_keys + If the constructor, ``.update``, or equality comparison + operations are given keys that have equal ``.lower()``s, the + behavior is undefined. - def _clear_lower_keys(self): - if hasattr(self, '_lower_keys'): - self._lower_keys.clear() + """ + def __init__(self, data=None, **kwargs): + self._store = dict() + if data is None: + data = {} + self.update(data, **kwargs) def __setitem__(self, key, value): - dict.__setitem__(self, key, value) - self._clear_lower_keys() + # Use the lowercased key for lookups, but store the actual + # key alongside the value. + self._store[key.lower()] = (key, value) - def __delitem__(self, key): - dict.__delitem__(self, self.lower_keys.get(key.lower(), key)) - self._lower_keys.clear() + def __getitem__(self, key): + return self._store[key.lower()][1] - def __contains__(self, key): - return key.lower() in self.lower_keys + def __delitem__(self, key): + del self._store[key.lower()] - def __getitem__(self, key): - # We allow fall-through here, so values default to None - if key in self: - return dict.__getitem__(self, self.lower_keys[key.lower()]) + def __iter__(self): + return (casedkey for casedkey, mappedvalue in self._store.values()) - def get(self, key, default=None): - if key in self: - return self[key] + def __len__(self): + return len(self._store) + + def lower_items(self): + """Like iteritems(), but with all lowercase keys.""" + return ( + (lowerkey, keyval[1]) + for (lowerkey, keyval) + in self._store.items() + ) + + def __eq__(self, other): + if isinstance(other, collections.Mapping): + other = CaseInsensitiveDict(other) else: - return default + return NotImplemented + # Compare insensitively + return dict(self.lower_items()) == dict(other.lower_items()) + + # Copy is required + def copy(self): + return CaseInsensitiveDict(self._store.values()) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, dict(self.items())) class LookupDict(dict): diff --git a/requests/utils.py b/requests/utils.py index a2d434e..b21bf8f 100644 --- a/requests/utils.py +++ b/requests/utils.py @@ -11,11 +11,11 @@ that are also useful for external consumption. import cgi import codecs +import collections import os import platform import re import sys -import zlib from netrc import netrc, NetrcParseError from . import __version__ @@ -23,6 +23,7 @@ from . import certs from .compat import parse_http_list as _parse_list_header from .compat import quote, urlparse, bytes, str, OrderedDict, urlunparse from .cookies import RequestsCookieJar, cookiejar_from_dict +from .structures import CaseInsensitiveDict _hush_pyflakes = (RequestsCookieJar,) @@ -134,7 +135,7 @@ def to_key_val_list(value): if isinstance(value, (str, bytes, bool, int)): raise ValueError('cannot encode objects that are not 2-tuples') - if isinstance(value, dict): + if isinstance(value, collections.Mapping): value = value.items() return list(value) @@ -346,48 +347,6 @@ def get_unicode_from_response(r): return r.content -def stream_decompress(iterator, mode='gzip'): - """Stream decodes an iterator over compressed data - - :param iterator: An iterator over compressed data - :param mode: 'gzip' or 'deflate' - :return: An iterator over decompressed data - """ - - if mode not in ['gzip', 'deflate']: - raise ValueError('stream_decompress mode must be gzip or deflate') - - zlib_mode = 16 + zlib.MAX_WBITS if mode == 'gzip' else -zlib.MAX_WBITS - dec = zlib.decompressobj(zlib_mode) - try: - for chunk in iterator: - rv = dec.decompress(chunk) - if rv: - yield rv - except zlib.error: - # If there was an error decompressing, just return the raw chunk - yield chunk - # Continue to return the rest of the raw data - for chunk in iterator: - yield chunk - else: - # Make sure everything has been returned from the decompression object - buf = dec.decompress(bytes()) - rv = buf + dec.flush() - if rv: - yield rv - - -def stream_untransfer(gen, resp): - ce = resp.headers.get('content-encoding', '').lower() - if 'gzip' in ce: - gen = stream_decompress(gen, mode='gzip') - elif 'deflate' in ce: - gen = stream_decompress(gen, mode='deflate') - - return gen - - # The unreserved URI characters (RFC 3986) UNRESERVED_SET = frozenset( "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" @@ -491,11 +450,11 @@ def default_user_agent(): def default_headers(): - return { + return CaseInsensitiveDict({ 'User-Agent': default_user_agent(), 'Accept-Encoding': ', '.join(('gzip', 'deflate', 'compress')), 'Accept': '*/*' - } + }) def parse_header_links(value): -- cgit v1.2.3