diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/__init__.py | 3 | ||||
-rw-r--r-- | test/test_collections.py | 222 | ||||
-rw-r--r-- | test/test_connectionpool.py | 20 | ||||
-rw-r--r-- | test/test_no_ssl.py | 89 | ||||
-rw-r--r-- | test/test_poolmanager.py | 16 | ||||
-rw-r--r-- | test/test_response.py | 246 | ||||
-rw-r--r-- | test/test_util.py | 6 | ||||
-rw-r--r-- | test/with_dummyserver/test_connectionpool.py | 43 | ||||
-rw-r--r-- | test/with_dummyserver/test_https.py | 57 | ||||
-rw-r--r-- | test/with_dummyserver/test_no_ssl.py | 29 | ||||
-rw-r--r-- | test/with_dummyserver/test_poolmanager.py | 29 | ||||
-rw-r--r-- | test/with_dummyserver/test_socketlevel.py | 86 |
12 files changed, 791 insertions, 55 deletions
diff --git a/test/__init__.py b/test/__init__.py index d56a4d3..2fce71c 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -13,7 +13,8 @@ from urllib3.packages import six # Reset. SO suggests this hostname TARPIT_HOST = '10.255.255.1' -VALID_SOURCE_ADDRESSES = [('::1', 0), ('127.0.0.1', 0)] +# (Arguments for socket, is it IPv6 address?) +VALID_SOURCE_ADDRESSES = [(('::1', 0), True), (('127.0.0.1', 0), False)] # RFC 5737: 192.0.2.0/24 is for testing only. # RFC 3849: 2001:db8::/32 is for documentation only. INVALID_SOURCE_ADDRESSES = [('192.0.2.255', 0), ('2001:db8::1', 0)] diff --git a/test/test_collections.py b/test/test_collections.py index 4d173ac..0b36512 100644 --- a/test/test_collections.py +++ b/test/test_collections.py @@ -7,6 +7,8 @@ from urllib3._collections import ( from urllib3.packages import six xrange = six.moves.xrange +from nose.plugins.skip import SkipTest + class TestLRUContainer(unittest.TestCase): def test_maxsize(self): @@ -125,56 +127,216 @@ class TestLRUContainer(unittest.TestCase): self.assertRaises(NotImplementedError, d.__iter__) -class TestHTTPHeaderDict(unittest.TestCase): - def setUp(self): - self.d = HTTPHeaderDict(A='foo') - self.d.add('a', 'bar') +class NonMappingHeaderContainer(object): + def __init__(self, **kwargs): + self._data = {} + self._data.update(kwargs) - def test_overwriting_with_setitem_replaces(self): - d = HTTPHeaderDict() + def keys(self): + return self._data.keys() - d['A'] = 'foo' - self.assertEqual(d['a'], 'foo') + def __getitem__(self, key): + return self._data[key] - d['a'] = 'bar' - self.assertEqual(d['A'], 'bar') + +class TestHTTPHeaderDict(unittest.TestCase): + def setUp(self): + self.d = HTTPHeaderDict(Cookie='foo') + self.d.add('cookie', 'bar') + + def test_create_from_kwargs(self): + h = HTTPHeaderDict(ab=1, cd=2, ef=3, gh=4) + self.assertEqual(len(h), 4) + self.assertTrue('ab' in h) + + def test_create_from_dict(self): + h = HTTPHeaderDict(dict(ab=1, cd=2, ef=3, gh=4)) + self.assertEqual(len(h), 4) + self.assertTrue('ab' in h) + + def test_create_from_iterator(self): + teststr = 'urllib3ontherocks' + h = HTTPHeaderDict((c, c*5) for c in teststr) + self.assertEqual(len(h), len(set(teststr))) + + def test_create_from_list(self): + h = HTTPHeaderDict([('ab', 'A'), ('cd', 'B'), ('cookie', 'C'), ('cookie', 'D'), ('cookie', 'E')]) + self.assertEqual(len(h), 3) + self.assertTrue('ab' in h) + clist = h.getlist('cookie') + self.assertEqual(len(clist), 3) + self.assertEqual(clist[0], 'C') + self.assertEqual(clist[-1], 'E') + + def test_create_from_headerdict(self): + org = HTTPHeaderDict([('ab', 'A'), ('cd', 'B'), ('cookie', 'C'), ('cookie', 'D'), ('cookie', 'E')]) + h = HTTPHeaderDict(org) + self.assertEqual(len(h), 3) + self.assertTrue('ab' in h) + clist = h.getlist('cookie') + self.assertEqual(len(clist), 3) + self.assertEqual(clist[0], 'C') + self.assertEqual(clist[-1], 'E') + self.assertFalse(h is org) + self.assertEqual(h, org) + + def test_setitem(self): + self.d['Cookie'] = 'foo' + self.assertEqual(self.d['cookie'], 'foo') + self.d['cookie'] = 'with, comma' + self.assertEqual(self.d.getlist('cookie'), ['with, comma']) + + def test_update(self): + self.d.update(dict(Cookie='foo')) + self.assertEqual(self.d['cookie'], 'foo') + self.d.update(dict(cookie='with, comma')) + self.assertEqual(self.d.getlist('cookie'), ['with, comma']) + + def test_delitem(self): + del self.d['cookie'] + self.assertFalse('cookie' in self.d) + self.assertFalse('COOKIE' in self.d) + + def test_add_well_known_multiheader(self): + self.d.add('COOKIE', 'asdf') + self.assertEqual(self.d.getlist('cookie'), ['foo', 'bar', 'asdf']) + self.assertEqual(self.d['cookie'], 'foo, bar, asdf') + + def test_add_comma_separated_multiheader(self): + self.d.add('bar', 'foo') + self.d.add('BAR', 'bar') + self.d.add('Bar', 'asdf') + self.assertEqual(self.d.getlist('bar'), ['foo', 'bar', 'asdf']) + self.assertEqual(self.d['bar'], 'foo, bar, asdf') + + def test_extend_from_list(self): + self.d.extend([('set-cookie', '100'), ('set-cookie', '200'), ('set-cookie', '300')]) + self.assertEqual(self.d['set-cookie'], '100, 200, 300') + + def test_extend_from_dict(self): + self.d.extend(dict(cookie='asdf'), b='100') + self.assertEqual(self.d['cookie'], 'foo, bar, asdf') + self.assertEqual(self.d['b'], '100') + self.d.add('cookie', 'with, comma') + self.assertEqual(self.d.getlist('cookie'), ['foo', 'bar', 'asdf', 'with, comma']) + + def test_extend_from_container(self): + h = NonMappingHeaderContainer(Cookie='foo', e='foofoo') + self.d.extend(h) + self.assertEqual(self.d['cookie'], 'foo, bar, foo') + self.assertEqual(self.d['e'], 'foofoo') + self.assertEqual(len(self.d), 2) + + def test_extend_from_headerdict(self): + h = HTTPHeaderDict(Cookie='foo', e='foofoo') + self.d.extend(h) + self.assertEqual(self.d['cookie'], 'foo, bar, foo') + self.assertEqual(self.d['e'], 'foofoo') + self.assertEqual(len(self.d), 2) def test_copy(self): h = self.d.copy() self.assertTrue(self.d is not h) - self.assertEqual(self.d, h) - - def test_add(self): - d = HTTPHeaderDict() - - d['A'] = 'foo' - d.add('a', 'bar') - - self.assertEqual(d['a'], 'foo, bar') - self.assertEqual(d['A'], 'foo, bar') + self.assertEqual(self.d, h) def test_getlist(self): - self.assertEqual(self.d.getlist('a'), ['foo', 'bar']) - self.assertEqual(self.d.getlist('A'), ['foo', 'bar']) + self.assertEqual(self.d.getlist('cookie'), ['foo', 'bar']) + self.assertEqual(self.d.getlist('Cookie'), ['foo', 'bar']) self.assertEqual(self.d.getlist('b'), []) + self.d.add('b', 'asdf') + self.assertEqual(self.d.getlist('b'), ['asdf']) - def test_delitem(self): - del self.d['a'] - self.assertFalse('a' in self.d) - self.assertFalse('A' in self.d) + def test_getlist_after_copy(self): + self.assertEqual(self.d.getlist('cookie'), HTTPHeaderDict(self.d).getlist('cookie')) def test_equal(self): - b = HTTPHeaderDict({'a': 'foo, bar'}) + b = HTTPHeaderDict(cookie='foo, bar') + c = NonMappingHeaderContainer(cookie='foo, bar') self.assertEqual(self.d, b) - c = [('a', 'foo, bar')] - self.assertNotEqual(self.d, c) + self.assertEqual(self.d, c) + self.assertNotEqual(self.d, 2) + + def test_not_equal(self): + b = HTTPHeaderDict(cookie='foo, bar') + c = NonMappingHeaderContainer(cookie='foo, bar') + self.assertFalse(self.d != b) + self.assertFalse(self.d != c) + self.assertNotEqual(self.d, 2) + + def test_pop(self): + key = 'Cookie' + a = self.d[key] + b = self.d.pop(key) + self.assertEqual(a, b) + self.assertFalse(key in self.d) + self.assertRaises(KeyError, self.d.pop, key) + dummy = object() + self.assertTrue(dummy is self.d.pop(key, dummy)) + + def test_discard(self): + self.d.discard('cookie') + self.assertFalse('cookie' in self.d) + self.d.discard('cookie') def test_len(self): self.assertEqual(len(self.d), 1) + self.d.add('cookie', 'bla') + self.d.add('asdf', 'foo') + # len determined by unique fieldnames + self.assertEqual(len(self.d), 2) def test_repr(self): - rep = "HTTPHeaderDict({'A': 'foo, bar'})" + rep = "HTTPHeaderDict({'Cookie': 'foo, bar'})" self.assertEqual(repr(self.d), rep) + def test_items(self): + items = self.d.items() + self.assertEqual(len(items), 2) + self.assertEqual(items[0][0], 'Cookie') + self.assertEqual(items[0][1], 'foo') + self.assertEqual(items[1][0], 'Cookie') + self.assertEqual(items[1][1], 'bar') + + def test_dict_conversion(self): + # Also tested in connectionpool, needs to preserve case + hdict = {'Content-Length': '0', 'Content-type': 'text/plain', 'Server': 'TornadoServer/1.2.3'} + h = dict(HTTPHeaderDict(hdict).items()) + self.assertEqual(hdict, h) + + def test_string_enforcement(self): + # This currently throws AttributeError on key.lower(), should probably be something nicer + self.assertRaises(Exception, self.d.__setitem__, 3, 5) + self.assertRaises(Exception, self.d.add, 3, 4) + self.assertRaises(Exception, self.d.__delitem__, 3) + self.assertRaises(Exception, HTTPHeaderDict, {3: 3}) + + def test_from_httplib_py2(self): + if six.PY3: + raise SkipTest("python3 has a different internal header implementation") + msg = """ +Server: nginx +Content-Type: text/html; charset=windows-1251 +Connection: keep-alive +X-Some-Multiline: asdf + asdf + asdf +Set-Cookie: bb_lastvisit=1348253375; expires=Sat, 21-Sep-2013 18:49:35 GMT; path=/ +Set-Cookie: bb_lastactivity=0; expires=Sat, 21-Sep-2013 18:49:35 GMT; path=/ +www-authenticate: asdf +www-authenticate: bla + +""" + buffer = six.moves.StringIO(msg.lstrip().replace('\n', '\r\n')) + msg = six.moves.http_client.HTTPMessage(buffer) + d = HTTPHeaderDict.from_httplib(msg) + self.assertEqual(d['server'], 'nginx') + cookies = d.getlist('set-cookie') + self.assertEqual(len(cookies), 2) + self.assertTrue(cookies[0].startswith("bb_lastvisit")) + self.assertTrue(cookies[1].startswith("bb_lastactivity")) + self.assertEqual(d['x-some-multiline'].split(), ['asdf', 'asdf', 'asdf']) + self.assertEqual(d['www-authenticate'], 'asdf, bla') + self.assertEqual(d.getlist('www-authenticate'), ['asdf', 'bla']) + if __name__ == '__main__': unittest.main() diff --git a/test/test_connectionpool.py b/test/test_connectionpool.py index a6dbcf4..0718b0f 100644 --- a/test/test_connectionpool.py +++ b/test/test_connectionpool.py @@ -205,6 +205,26 @@ class TestConnectionPool(unittest.TestCase): def test_no_host(self): self.assertRaises(LocationValueError, HTTPConnectionPool, None) + def test_contextmanager(self): + with connection_from_url('http://google.com:80') as pool: + # Populate with some connections + conn1 = pool._get_conn() + conn2 = pool._get_conn() + conn3 = pool._get_conn() + pool._put_conn(conn1) + pool._put_conn(conn2) + + old_pool_queue = pool.pool + + self.assertEqual(pool.pool, None) + + self.assertRaises(ClosedPoolError, pool._get_conn) + + pool._put_conn(conn3) + + self.assertRaises(ClosedPoolError, pool._get_conn) + + self.assertRaises(Empty, old_pool_queue.get, block=False) if __name__ == '__main__': diff --git a/test/test_no_ssl.py b/test/test_no_ssl.py new file mode 100644 index 0000000..b5961b8 --- /dev/null +++ b/test/test_no_ssl.py @@ -0,0 +1,89 @@ +""" +Test what happens if Python was built without SSL + +* Everything that does not involve HTTPS should still work +* HTTPS requests must fail with an error that points at the ssl module +""" + +import sys +import unittest + + +class ImportBlocker(object): + """ + Block Imports + + To be placed on ``sys.meta_path``. This ensures that the modules + specified cannot be imported, even if they are a builtin. + """ + def __init__(self, *namestoblock): + self.namestoblock = namestoblock + + def find_module(self, fullname, path=None): + if fullname in self.namestoblock: + return self + return None + + def load_module(self, fullname): + raise ImportError('import of {0} is blocked'.format(fullname)) + + +class ModuleStash(object): + """ + Stashes away previously imported modules + + If we reimport a module the data from coverage is lost, so we reuse the old + modules + """ + + def __init__(self, namespace, modules=sys.modules): + self.namespace = namespace + self.modules = modules + self._data = {} + + def stash(self): + self._data[self.namespace] = self.modules.pop(self.namespace, None) + + for module in list(self.modules.keys()): + if module.startswith(self.namespace + '.'): + self._data[module] = self.modules.pop(module) + + def pop(self): + self.modules.pop(self.namespace, None) + + for module in list(self.modules.keys()): + if module.startswith(self.namespace + '.'): + self.modules.pop(module) + + self.modules.update(self._data) + + +ssl_blocker = ImportBlocker('ssl', '_ssl') +module_stash = ModuleStash('urllib3') + + +class TestWithoutSSL(unittest.TestCase): + def setUp(self): + sys.modules.pop('ssl', None) + sys.modules.pop('_ssl', None) + + module_stash.stash() + sys.meta_path.insert(0, ssl_blocker) + + def tearDown(self): + assert sys.meta_path.pop(0) == ssl_blocker + module_stash.pop() + + +class TestImportWithoutSSL(TestWithoutSSL): + def test_cannot_import_ssl(self): + # python26 has neither contextmanagers (for assertRaises) nor + # importlib. + # 'import' inside 'lambda' is invalid syntax. + def import_ssl(): + import ssl + + self.assertRaises(ImportError, import_ssl) + + def test_import_urllib3(self): + import urllib3 diff --git a/test/test_poolmanager.py b/test/test_poolmanager.py index 754ee8a..6195d51 100644 --- a/test/test_poolmanager.py +++ b/test/test_poolmanager.py @@ -71,6 +71,22 @@ class TestPoolManager(unittest.TestCase): self.assertRaises(LocationValueError, p.connection_from_url, 'http://@') self.assertRaises(LocationValueError, p.connection_from_url, None) + def test_contextmanager(self): + with PoolManager(1) as p: + conn_pool = p.connection_from_url('http://google.com') + self.assertEqual(len(p.pools), 1) + conn = conn_pool._get_conn() + + self.assertEqual(len(p.pools), 0) + + self.assertRaises(ClosedPoolError, conn_pool._get_conn) + + conn_pool._put_conn(conn) + + self.assertRaises(ClosedPoolError, conn_pool._get_conn) + + self.assertEqual(len(p.pools), 0) + if __name__ == '__main__': unittest.main() diff --git a/test/test_response.py b/test/test_response.py index 7d67c93..2e2be0e 100644 --- a/test/test_response.py +++ b/test/test_response.py @@ -2,8 +2,12 @@ import unittest from io import BytesIO, BufferedReader +try: + import http.client as httplib +except ImportError: + import httplib from urllib3.response import HTTPResponse -from urllib3.exceptions import DecodeError +from urllib3.exceptions import DecodeError, ResponseNotChunked from base64 import b64decode @@ -73,6 +77,15 @@ class TestResponse(unittest.TestCase): 'content-encoding': 'deflate' }) + def test_reference_read(self): + fp = BytesIO(b'foo') + r = HTTPResponse(fp, preload_content=False) + + self.assertEqual(r.read(1), b'f') + self.assertEqual(r.read(2), b'oo') + self.assertEqual(r.read(), b'') + self.assertEqual(r.read(), b'') + def test_decode_deflate(self): import zlib data = zlib.compress(b'foo') @@ -102,6 +115,9 @@ class TestResponse(unittest.TestCase): self.assertEqual(r.read(3), b'') self.assertEqual(r.read(1), b'f') self.assertEqual(r.read(2), b'oo') + self.assertEqual(r.read(), b'') + self.assertEqual(r.read(), b'') + def test_chunked_decoding_deflate2(self): import zlib @@ -116,6 +132,9 @@ class TestResponse(unittest.TestCase): self.assertEqual(r.read(1), b'') self.assertEqual(r.read(1), b'f') self.assertEqual(r.read(2), b'oo') + self.assertEqual(r.read(), b'') + self.assertEqual(r.read(), b'') + def test_chunked_decoding_gzip(self): import zlib @@ -130,6 +149,9 @@ class TestResponse(unittest.TestCase): self.assertEqual(r.read(11), b'') self.assertEqual(r.read(1), b'f') self.assertEqual(r.read(2), b'oo') + self.assertEqual(r.read(), b'') + self.assertEqual(r.read(), b'') + def test_body_blob(self): resp = HTTPResponse(b'foo') @@ -138,10 +160,6 @@ class TestResponse(unittest.TestCase): def test_io(self): import socket - try: - from http.client import HTTPResponse as OldHTTPResponse - except: - from httplib import HTTPResponse as OldHTTPResponse fp = BytesIO(b'foo') resp = HTTPResponse(fp, preload_content=False) @@ -156,7 +174,7 @@ class TestResponse(unittest.TestCase): # Try closing with an `httplib.HTTPResponse`, because it has an # `isclosed` method. - hlr = OldHTTPResponse(socket.socket()) + hlr = httplib.HTTPResponse(socket.socket()) resp2 = HTTPResponse(hlr, preload_content=False) self.assertEqual(resp2.closed, False) resp2.close() @@ -388,11 +406,227 @@ class TestResponse(unittest.TestCase): self.assertEqual(next(stream), b'o') self.assertRaises(StopIteration, next, stream) + def test_mock_transfer_encoding_chunked(self): + stream = [b"fo", b"o", b"bar"] + fp = MockChunkedEncodingResponse(stream) + r = httplib.HTTPResponse(MockSock) + r.fp = fp + resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) + + i = 0 + for c in resp.stream(): + self.assertEqual(c, stream[i]) + i += 1 + + def test_mock_gzipped_transfer_encoding_chunked_decoded(self): + """Show that we can decode the gizpped and chunked body.""" + def stream(): + # Set up a generator to chunk the gzipped body + import zlib + compress = zlib.compressobj(6, zlib.DEFLATED, 16 + zlib.MAX_WBITS) + data = compress.compress(b'foobar') + data += compress.flush() + for i in range(0, len(data), 2): + yield data[i:i+2] + + fp = MockChunkedEncodingResponse(list(stream())) + r = httplib.HTTPResponse(MockSock) + r.fp = fp + headers = {'transfer-encoding': 'chunked', 'content-encoding': 'gzip'} + resp = HTTPResponse(r, preload_content=False, headers=headers) + + data = b'' + for c in resp.stream(decode_content=True): + data += c + + self.assertEqual(b'foobar', data) + + def test_mock_transfer_encoding_chunked_custom_read(self): + stream = [b"foooo", b"bbbbaaaaar"] + fp = MockChunkedEncodingResponse(stream) + r = httplib.HTTPResponse(MockSock) + r.fp = fp + r.chunked = True + r.chunk_left = None + resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) + expected_response = [b'fo', b'oo', b'o', b'bb', b'bb', b'aa', b'aa', b'ar'] + response = list(resp.read_chunked(2)) + if getattr(self, "assertListEqual", False): + self.assertListEqual(expected_response, response) + else: + for index, item in enumerate(response): + v = expected_response[index] + self.assertEqual(item, v) + + def test_mock_transfer_encoding_chunked_unlmtd_read(self): + stream = [b"foooo", b"bbbbaaaaar"] + fp = MockChunkedEncodingResponse(stream) + r = httplib.HTTPResponse(MockSock) + r.fp = fp + r.chunked = True + r.chunk_left = None + resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) + if getattr(self, "assertListEqual", False): + self.assertListEqual(stream, list(resp.read_chunked())) + else: + for index, item in enumerate(resp.read_chunked()): + v = stream[index] + self.assertEqual(item, v) + + def test_read_not_chunked_response_as_chunks(self): + fp = BytesIO(b'foo') + resp = HTTPResponse(fp, preload_content=False) + r = resp.read_chunked() + self.assertRaises(ResponseNotChunked, next, r) + + def test_invalid_chunks(self): + stream = [b"foooo", b"bbbbaaaaar"] + fp = MockChunkedInvalidEncoding(stream) + r = httplib.HTTPResponse(MockSock) + r.fp = fp + r.chunked = True + r.chunk_left = None + resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) + self.assertRaises(httplib.IncompleteRead, next, resp.read_chunked()) + + def test_chunked_response_without_crlf_on_end(self): + stream = [b"foo", b"bar", b"baz"] + fp = MockChunkedEncodingWithoutCRLFOnEnd(stream) + r = httplib.HTTPResponse(MockSock) + r.fp = fp + r.chunked = True + r.chunk_left = None + resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) + if getattr(self, "assertListEqual", False): + self.assertListEqual(stream, list(resp.stream())) + else: + for index, item in enumerate(resp.stream()): + v = stream[index] + self.assertEqual(item, v) + + def test_chunked_response_with_extensions(self): + stream = [b"foo", b"bar"] + fp = MockChunkedEncodingWithExtensions(stream) + r = httplib.HTTPResponse(MockSock) + r.fp = fp + r.chunked = True + r.chunk_left = None + resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) + if getattr(self, "assertListEqual", False): + self.assertListEqual(stream, list(resp.stream())) + else: + for index, item in enumerate(resp.stream()): + v = stream[index] + self.assertEqual(item, v) + def test_get_case_insensitive_headers(self): headers = {'host': 'example.com'} r = HTTPResponse(headers=headers) self.assertEqual(r.headers.get('host'), 'example.com') self.assertEqual(r.headers.get('Host'), 'example.com') + +class MockChunkedEncodingResponse(object): + + def __init__(self, content): + """ + content: collection of str, each str is a chunk in response + """ + self.content = content + self.index = 0 # This class iterates over self.content. + self.closed = False + self.cur_chunk = b'' + self.chunks_exhausted = False + + @staticmethod + def _encode_chunk(chunk): + # In the general case, we can't decode the chunk to unicode + length = '%X\r\n' % len(chunk) + return length.encode() + chunk + b'\r\n' + + def _pop_new_chunk(self): + if self.chunks_exhausted: + return b"" + try: + chunk = self.content[self.index] + except IndexError: + chunk = b'' + self.chunks_exhausted = True + else: + self.index += 1 + chunk = self._encode_chunk(chunk) + if not isinstance(chunk, bytes): + chunk = chunk.encode() + return chunk + + def pop_current_chunk(self, amt=-1, till_crlf=False): + if amt > 0 and till_crlf: + raise ValueError("Can't specify amt and till_crlf.") + if len(self.cur_chunk) <= 0: + self.cur_chunk = self._pop_new_chunk() + if till_crlf: + try: + i = self.cur_chunk.index(b"\r\n") + except ValueError: + # No CRLF in current chunk -- probably caused by encoder. + self.cur_chunk = b"" + return b"" + else: + chunk_part = self.cur_chunk[:i+2] + self.cur_chunk = self.cur_chunk[i+2:] + return chunk_part + elif amt <= -1: + chunk_part = self.cur_chunk + self.cur_chunk = b'' + return chunk_part + else: + try: + chunk_part = self.cur_chunk[:amt] + except IndexError: + chunk_part = self.cur_chunk + self.cur_chunk = b'' + else: + self.cur_chunk = self.cur_chunk[amt:] + return chunk_part + + def readline(self): + return self.pop_current_chunk(till_crlf=True) + + def read(self, amt=-1): + return self.pop_current_chunk(amt) + + def flush(self): + # Python 3 wants this method. + pass + + def close(self): + self.closed = True + + +class MockChunkedInvalidEncoding(MockChunkedEncodingResponse): + + def _encode_chunk(self, chunk): + return 'ZZZ\r\n%s\r\n' % chunk.decode() + + +class MockChunkedEncodingWithoutCRLFOnEnd(MockChunkedEncodingResponse): + + def _encode_chunk(self, chunk): + return '%X\r\n%s%s' % (len(chunk), chunk.decode(), + "\r\n" if len(chunk) > 0 else "") + + +class MockChunkedEncodingWithExtensions(MockChunkedEncodingResponse): + + def _encode_chunk(self, chunk): + return '%X;asd=qwe\r\n%s\r\n' % (len(chunk), chunk.decode()) + + +class MockSock(object): + @classmethod + def makefile(cls, *args, **kwargs): + return + + if __name__ == '__main__': unittest.main() diff --git a/test/test_util.py b/test/test_util.py index c850d91..19ba57e 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -15,7 +15,10 @@ from urllib3.util.url import ( split_first, Url, ) -from urllib3.util.ssl_ import resolve_cert_reqs, ssl_wrap_socket +from urllib3.util.ssl_ import ( + resolve_cert_reqs, + ssl_wrap_socket, +) from urllib3.exceptions import ( LocationParseError, TimeoutStateError, @@ -94,6 +97,7 @@ class TestUtil(unittest.TestCase): parse_url_host_map = { 'http://google.com/mail': Url('http', host='google.com', path='/mail'), 'http://google.com/mail/': Url('http', host='google.com', path='/mail/'), + 'http://google.com/mail': Url('http', host='google.com', path='mail'), 'google.com/mail': Url(host='google.com', path='/mail'), 'http://google.com/': Url('http', host='google.com', path='/'), 'http://google.com': Url('http', host='google.com'), diff --git a/test/with_dummyserver/test_connectionpool.py b/test/with_dummyserver/test_connectionpool.py index cc0f011..d6cb162 100644 --- a/test/with_dummyserver/test_connectionpool.py +++ b/test/with_dummyserver/test_connectionpool.py @@ -4,6 +4,7 @@ import socket import sys import unittest import time +import warnings import mock @@ -35,6 +36,7 @@ from urllib3.util.timeout import Timeout import tornado from dummyserver.testcase import HTTPDummyServerTestCase +from dummyserver.server import NoIPv6Warning from nose.tools import timed @@ -597,7 +599,11 @@ class TestConnectionPool(HTTPDummyServerTestCase): self.assertRaises(MaxRetryError, pool.request, 'GET', '/test', retries=2) def test_source_address(self): - for addr in VALID_SOURCE_ADDRESSES: + for addr, is_ipv6 in VALID_SOURCE_ADDRESSES: + if is_ipv6 and not socket.has_ipv6: + warnings.warn("No IPv6 support: skipping.", + NoIPv6Warning) + continue pool = HTTPConnectionPool(self.host, self.port, source_address=addr, retries=False) r = pool.request('GET', '/source_address') @@ -612,13 +618,34 @@ class TestConnectionPool(HTTPDummyServerTestCase): self.assertRaises(ProtocolError, pool.request, 'GET', '/source_address') - @onlyPy3 - def test_httplib_headers_case_insensitive(self): - HEADERS = {'Content-Length': '0', 'Content-type': 'text/plain', - 'Server': 'TornadoServer/%s' % tornado.version} - r = self.pool.request('GET', '/specific_method', - fields={'method': 'GET'}) - self.assertEqual(HEADERS, dict(r.headers.items())) # to preserve case sensitivity + def test_stream_keepalive(self): + x = 2 + + for _ in range(x): + response = self.pool.request( + 'GET', + '/chunked', + headers={ + 'Connection': 'keep-alive', + }, + preload_content=False, + retries=False, + ) + for chunk in response.stream(): + self.assertEqual(chunk, b'123') + + self.assertEqual(self.pool.num_connections, 1) + self.assertEqual(self.pool.num_requests, x) + + def test_chunked_gzip(self): + response = self.pool.request( + 'GET', + '/chunked_gzip', + preload_content=False, + decode_content=True, + ) + + self.assertEqual(b'123' * 4, response.read()) class TestRetry(HTTPDummyServerTestCase): diff --git a/test/with_dummyserver/test_https.py b/test/with_dummyserver/test_https.py index 16ca589..992b8ef 100644 --- a/test/with_dummyserver/test_https.py +++ b/test/with_dummyserver/test_https.py @@ -30,10 +30,17 @@ from urllib3.exceptions import ( ConnectTimeoutError, InsecureRequestWarning, SystemTimeWarning, + InsecurePlatformWarning, ) +from urllib3.packages import six from urllib3.util.timeout import Timeout +ResourceWarning = getattr( + six.moves.builtins, + 'ResourceWarning', type('ResourceWarning', (), {})) + + log = logging.getLogger('urllib3.connectionpool') log.setLevel(logging.NOTSET) log.addHandler(logging.StreamHandler(sys.stdout)) @@ -64,7 +71,14 @@ class TestHTTPS(HTTPSDummyServerTestCase): with mock.patch('warnings.warn') as warn: r = https_pool.request('GET', '/') self.assertEqual(r.status, 200) - self.assertFalse(warn.called, warn.call_args_list) + + if sys.version_info >= (2, 7, 9): + self.assertFalse(warn.called, warn.call_args_list) + else: + self.assertTrue(warn.called) + call, = warn.call_args_list + error = call[0][1] + self.assertEqual(error, InsecurePlatformWarning) def test_invalid_common_name(self): https_pool = HTTPSConnectionPool('127.0.0.1', self.port, @@ -137,8 +151,11 @@ class TestHTTPS(HTTPSDummyServerTestCase): self.assertEqual(r.status, 200) self.assertTrue(warn.called) - call, = warn.call_args_list - category = call[0][1] + calls = warn.call_args_list + if sys.version_info >= (2, 7, 9): + category = calls[0][0][1] + else: + category = calls[1][0][1] self.assertEqual(category, InsecureRequestWarning) @requires_network @@ -202,6 +219,16 @@ class TestHTTPS(HTTPSDummyServerTestCase): '7A:F2:8A:D7:1E:07:33:67:DE' https_pool.request('GET', '/') + def test_assert_fingerprint_sha256(self): + https_pool = HTTPSConnectionPool('localhost', self.port, + cert_reqs='CERT_REQUIRED', + ca_certs=DEFAULT_CA) + + https_pool.assert_fingerprint = ('9A:29:9D:4F:47:85:1C:51:23:F5:9A:A3:' + '0F:5A:EF:96:F9:2E:3C:22:2E:FC:E8:BC:' + '0E:73:90:37:ED:3B:AA:AB') + https_pool.request('GET', '/') + def test_assert_invalid_fingerprint(self): https_pool = HTTPSConnectionPool('127.0.0.1', self.port, cert_reqs='CERT_REQUIRED', @@ -240,6 +267,15 @@ class TestHTTPS(HTTPSDummyServerTestCase): '7A:F2:8A:D7:1E:07:33:67:DE' https_pool.request('GET', '/') + def test_good_fingerprint_and_hostname_mismatch(self): + https_pool = HTTPSConnectionPool('127.0.0.1', self.port, + cert_reqs='CERT_REQUIRED', + ca_certs=DEFAULT_CA) + + https_pool.assert_fingerprint = 'CC:45:6A:90:82:F7FF:C0:8218:8e:' \ + '7A:F2:8A:D7:1E:07:33:67:DE' + https_pool.request('GET', '/') + @requires_network def test_https_timeout(self): timeout = Timeout(connect=0.001) @@ -332,10 +368,8 @@ class TestHTTPS(HTTPSDummyServerTestCase): def test_ssl_correct_system_time(self): self._pool.cert_reqs = 'CERT_REQUIRED' self._pool.ca_certs = DEFAULT_CA - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - self._pool.request('GET', '/') + w = self._request_without_resource_warnings('GET', '/') self.assertEqual([], w) def test_ssl_wrong_system_time(self): @@ -344,9 +378,7 @@ class TestHTTPS(HTTPSDummyServerTestCase): with mock.patch('urllib3.connection.datetime') as mock_date: mock_date.date.today.return_value = datetime.date(1970, 1, 1) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - self._pool.request('GET', '/') + w = self._request_without_resource_warnings('GET', '/') self.assertEqual(len(w), 1) warning = w[0] @@ -354,6 +386,13 @@ class TestHTTPS(HTTPSDummyServerTestCase): self.assertEqual(SystemTimeWarning, warning.category) self.assertTrue(str(RECENT_DATE) in warning.message.args[0]) + def _request_without_resource_warnings(self, method, url): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + self._pool.request(method, url) + + return [x for x in w if not isinstance(x.message, ResourceWarning)] + class TestHTTPS_TLSv1(HTTPSDummyServerTestCase): certs = DEFAULT_CERTS.copy() diff --git a/test/with_dummyserver/test_no_ssl.py b/test/with_dummyserver/test_no_ssl.py new file mode 100644 index 0000000..f266d49 --- /dev/null +++ b/test/with_dummyserver/test_no_ssl.py @@ -0,0 +1,29 @@ +""" +Test connections without the builtin ssl module + +Note: Import urllib3 inside the test functions to get the importblocker to work +""" +from ..test_no_ssl import TestWithoutSSL + +from dummyserver.testcase import ( + HTTPDummyServerTestCase, HTTPSDummyServerTestCase) + + +class TestHTTPWithoutSSL(HTTPDummyServerTestCase, TestWithoutSSL): + def test_simple(self): + import urllib3 + + pool = urllib3.HTTPConnectionPool(self.host, self.port) + r = pool.request('GET', '/') + self.assertEqual(r.status, 200, r.data) + + +class TestHTTPSWithoutSSL(HTTPSDummyServerTestCase, TestWithoutSSL): + def test_simple(self): + import urllib3 + + pool = urllib3.HTTPSConnectionPool(self.host, self.port) + try: + pool.request('GET', '/') + except urllib3.exceptions.SSLError as e: + self.assertTrue('SSL module is not available' in str(e)) diff --git a/test/with_dummyserver/test_poolmanager.py b/test/with_dummyserver/test_poolmanager.py index 52ff974..7e51c73 100644 --- a/test/with_dummyserver/test_poolmanager.py +++ b/test/with_dummyserver/test_poolmanager.py @@ -6,6 +6,7 @@ from dummyserver.testcase import (HTTPDummyServerTestCase, from urllib3.poolmanager import PoolManager from urllib3.connectionpool import port_by_scheme from urllib3.exceptions import MaxRetryError, SSLError +from urllib3.util.retry import Retry class TestPoolManager(HTTPDummyServerTestCase): @@ -78,6 +79,34 @@ class TestPoolManager(HTTPDummyServerTestCase): self.assertEqual(r._pool.host, self.host_alt) + def test_too_many_redirects(self): + http = PoolManager() + + try: + r = http.request('GET', '%s/redirect' % self.base_url, + fields={'target': '%s/redirect?target=%s/' % (self.base_url, self.base_url)}, + retries=1) + self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) + except MaxRetryError: + pass + + try: + r = http.request('GET', '%s/redirect' % self.base_url, + fields={'target': '%s/redirect?target=%s/' % (self.base_url, self.base_url)}, + retries=Retry(total=None, redirect=1)) + self.fail("Failed to raise MaxRetryError exception, returned %r" % r.status) + except MaxRetryError: + pass + + def test_raise_on_redirect(self): + http = PoolManager() + + r = http.request('GET', '%s/redirect' % self.base_url, + fields={'target': '%s/redirect?target=%s/' % (self.base_url, self.base_url)}, + retries=Retry(total=None, redirect=1, raise_on_redirect=False)) + + self.assertEqual(r.status, 303) + def test_missing_port(self): # Can a URL that lacks an explicit port like ':80' succeed, or # will all such URLs fail with an error? diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py index c1ef1be..6c99653 100644 --- a/test/with_dummyserver/test_socketlevel.py +++ b/test/with_dummyserver/test_socketlevel.py @@ -18,6 +18,8 @@ from dummyserver.testcase import SocketDummyServerTestCase from dummyserver.server import ( DEFAULT_CERTS, DEFAULT_CA, get_unreachable_address) +from .. import onlyPy3 + from nose.plugins.skip import SkipTest from threading import Event import socket @@ -44,6 +46,7 @@ class TestCookies(SocketDummyServerTestCase): pool = HTTPConnectionPool(self.host, self.port) r = pool.request('GET', '/', retries=0) self.assertEqual(r.headers, {'set-cookie': 'foo=1, bar=1'}) + self.assertEqual(r.headers.getlist('set-cookie'), ['foo=1', 'bar=1']) class TestSNI(SocketDummyServerTestCase): @@ -521,6 +524,43 @@ class TestSSL(SocketDummyServerTestCase): finally: timed_out.set() + def test_ssl_failed_fingerprint_verification(self): + def socket_handler(listener): + for i in range(2): + sock = listener.accept()[0] + ssl_sock = ssl.wrap_socket(sock, + server_side=True, + keyfile=DEFAULT_CERTS['keyfile'], + certfile=DEFAULT_CERTS['certfile'], + ca_certs=DEFAULT_CA) + + ssl_sock.send(b'HTTP/1.1 200 OK\r\n' + b'Content-Type: text/plain\r\n' + b'Content-Length: 5\r\n\r\n' + b'Hello') + + ssl_sock.close() + sock.close() + + self._start_server(socket_handler) + # GitHub's fingerprint. Valid, but not matching. + fingerprint = ('A0:C4:A7:46:00:ED:A7:2D:C0:BE:CB' + ':9A:8C:B6:07:CA:58:EE:74:5E') + + def request(): + try: + pool = HTTPSConnectionPool(self.host, self.port, + assert_fingerprint=fingerprint) + response = pool.urlopen('GET', '/', preload_content=False, + timeout=Timeout(connect=1, read=0.001)) + response.read() + finally: + pool.close() + + self.assertRaises(SSLError, request) + # Should not hang, see https://github.com/shazow/urllib3/issues/529 + self.assertRaises(SSLError, request) + def consume_socket(sock, chunks=65536): while not sock.recv(chunks).endswith(b'\r\n\r\n'): @@ -560,3 +600,49 @@ class TestErrorWrapping(SocketDummyServerTestCase): self._start_server(handler) pool = HTTPConnectionPool(self.host, self.port, retries=False) self.assertRaises(ProtocolError, pool.request, 'GET', '/') + +class TestHeaders(SocketDummyServerTestCase): + + @onlyPy3 + def test_httplib_headers_case_insensitive(self): + handler = create_response_handler( + b'HTTP/1.1 200 OK\r\n' + b'Content-Length: 0\r\n' + b'Content-type: text/plain\r\n' + b'\r\n' + ) + self._start_server(handler) + pool = HTTPConnectionPool(self.host, self.port, retries=False) + HEADERS = {'Content-Length': '0', 'Content-type': 'text/plain'} + r = pool.request('GET', '/') + self.assertEqual(HEADERS, dict(r.headers.items())) # to preserve case sensitivity + + +class TestHEAD(SocketDummyServerTestCase): + def test_chunked_head_response_does_not_hang(self): + handler = create_response_handler( + b'HTTP/1.1 200 OK\r\n' + b'Transfer-Encoding: chunked\r\n' + b'Content-type: text/plain\r\n' + b'\r\n' + ) + self._start_server(handler) + pool = HTTPConnectionPool(self.host, self.port, retries=False) + r = pool.request('HEAD', '/', timeout=1, preload_content=False) + + # stream will use the read_chunked method here. + self.assertEqual([], list(r.stream())) + + def test_empty_head_response_does_not_hang(self): + handler = create_response_handler( + b'HTTP/1.1 200 OK\r\n' + b'Content-Length: 256\r\n' + b'Content-type: text/plain\r\n' + b'\r\n' + ) + self._start_server(handler) + pool = HTTPConnectionPool(self.host, self.port, retries=False) + r = pool.request('HEAD', '/', timeout=1, preload_content=False) + + # stream will use the read method here. + self.assertEqual([], list(r.stream())) |