Diffstat (limited to 'test/test_util.py')
-rw-r--r--  test/test_util.py  402
1 file changed, 402 insertions, 0 deletions
diff --git a/test/test_util.py b/test/test_util.py
new file mode 100644
index 0000000..c850d91
--- /dev/null
+++ b/test/test_util.py
@@ -0,0 +1,402 @@
+import warnings
+import logging
+import unittest
+import ssl
+from itertools import chain
+
+from mock import patch, Mock
+
+from urllib3 import add_stderr_logger, disable_warnings
+from urllib3.util.request import make_headers
+from urllib3.util.timeout import Timeout
+from urllib3.util.url import (
+ get_host,
+ parse_url,
+ split_first,
+ Url,
+)
+from urllib3.util.ssl_ import resolve_cert_reqs, ssl_wrap_socket
+from urllib3.exceptions import (
+ LocationParseError,
+ TimeoutStateError,
+ InsecureRequestWarning,
+ SSLError,
+)
+
+from urllib3.util import is_fp_closed, ssl_
+
+from . import clear_warnings
+
+# This number represents a time in seconds; it doesn't mean anything in
+# isolation. It is set to a high-ish value to avoid conflicts with the
+# smaller numbers used for timeouts.
+TIMEOUT_EPOCH = 1000
+
+class TestUtil(unittest.TestCase):
+ def test_get_host(self):
+ url_host_map = {
+ # Hosts
+ 'http://google.com/mail': ('http', 'google.com', None),
+ 'http://google.com/mail/': ('http', 'google.com', None),
+ 'google.com/mail': ('http', 'google.com', None),
+ 'http://google.com/': ('http', 'google.com', None),
+ 'http://google.com': ('http', 'google.com', None),
+ 'http://www.google.com': ('http', 'www.google.com', None),
+ 'http://mail.google.com': ('http', 'mail.google.com', None),
+ 'http://google.com:8000/mail/': ('http', 'google.com', 8000),
+ 'http://google.com:8000': ('http', 'google.com', 8000),
+ 'https://google.com': ('https', 'google.com', None),
+ 'https://google.com:8000': ('https', 'google.com', 8000),
+ 'http://user:password@127.0.0.1:1234': ('http', '127.0.0.1', 1234),
+ 'http://google.com/foo=http://bar:42/baz': ('http', 'google.com', None),
+ 'http://google.com?foo=http://bar:42/baz': ('http', 'google.com', None),
+ 'http://google.com#foo=http://bar:42/baz': ('http', 'google.com', None),
+
+ # IPv4
+ '173.194.35.7': ('http', '173.194.35.7', None),
+ 'http://173.194.35.7': ('http', '173.194.35.7', None),
+ 'http://173.194.35.7/test': ('http', '173.194.35.7', None),
+ 'http://173.194.35.7:80': ('http', '173.194.35.7', 80),
+ 'http://173.194.35.7:80/test': ('http', '173.194.35.7', 80),
+
+ # IPv6
+ '[2a00:1450:4001:c01::67]': ('http', '[2a00:1450:4001:c01::67]', None),
+ 'http://[2a00:1450:4001:c01::67]': ('http', '[2a00:1450:4001:c01::67]', None),
+ 'http://[2a00:1450:4001:c01::67]/test': ('http', '[2a00:1450:4001:c01::67]', None),
+ 'http://[2a00:1450:4001:c01::67]:80': ('http', '[2a00:1450:4001:c01::67]', 80),
+ 'http://[2a00:1450:4001:c01::67]:80/test': ('http', '[2a00:1450:4001:c01::67]', 80),
+
+ # More IPv6 from http://www.ietf.org/rfc/rfc2732.txt
+ 'http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8000/index.html': ('http', '[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]', 8000),
+ 'http://[1080:0:0:0:8:800:200C:417A]/index.html': ('http', '[1080:0:0:0:8:800:200C:417A]', None),
+ 'http://[3ffe:2a00:100:7031::1]': ('http', '[3ffe:2a00:100:7031::1]', None),
+ 'http://[1080::8:800:200C:417A]/foo': ('http', '[1080::8:800:200C:417A]', None),
+ 'http://[::192.9.5.5]/ipng': ('http', '[::192.9.5.5]', None),
+ 'http://[::FFFF:129.144.52.38]:42/index.html': ('http', '[::FFFF:129.144.52.38]', 42),
+ 'http://[2010:836B:4179::836B:4179]': ('http', '[2010:836B:4179::836B:4179]', None),
+ }
+ for url, expected_host in url_host_map.items():
+ returned_host = get_host(url)
+ self.assertEqual(returned_host, expected_host)
+
+ def test_invalid_host(self):
+ # TODO: Add more tests
+ invalid_host = [
+ 'http://google.com:foo',
+ 'http://::1/',
+ 'http://::1:80/',
+ ]
+
+ for location in invalid_host:
+ self.assertRaises(LocationParseError, get_host, location)
+
+
+ parse_url_host_map = {
+ 'http://google.com/mail': Url('http', host='google.com', path='/mail'),
+ 'http://google.com/mail/': Url('http', host='google.com', path='/mail/'),
+ 'google.com/mail': Url(host='google.com', path='/mail'),
+ 'http://google.com/': Url('http', host='google.com', path='/'),
+ 'http://google.com': Url('http', host='google.com'),
+ 'http://google.com?foo': Url('http', host='google.com', path='', query='foo'),
+
+ # Path/query/fragment
+ '': Url(),
+ '/': Url(path='/'),
+ '#?/!google.com/?foo#bar': Url(path='', fragment='?/!google.com/?foo#bar'),
+ '/foo': Url(path='/foo'),
+ '/foo?bar=baz': Url(path='/foo', query='bar=baz'),
+ '/foo?bar=baz#banana?apple/orange': Url(path='/foo', query='bar=baz', fragment='banana?apple/orange'),
+
+ # Port
+ 'http://google.com/': Url('http', host='google.com', path='/'),
+ 'http://google.com:80/': Url('http', host='google.com', port=80, path='/'),
+ 'http://google.com:80': Url('http', host='google.com', port=80),
+
+ # Auth
+ 'http://foo:bar@localhost/': Url('http', auth='foo:bar', host='localhost', path='/'),
+ 'http://foo@localhost/': Url('http', auth='foo', host='localhost', path='/'),
+ 'http://foo:bar@baz@localhost/': Url('http', auth='foo:bar@baz', host='localhost', path='/'),
+ 'http://@': Url('http', host=None, auth='')
+ }
+
+ non_round_tripping_parse_url_host_map = {
+ # Path/query/fragment
+ '?': Url(path='', query=''),
+ '#': Url(path='', fragment=''),
+
+ # Empty Port
+ 'http://google.com:': Url('http', host='google.com'),
+ 'http://google.com:/': Url('http', host='google.com', path='/'),
+
+ }
+
+ def test_parse_url(self):
+ for url, expected_Url in chain(self.parse_url_host_map.items(), self.non_round_tripping_parse_url_host_map.items()):
+ returned_Url = parse_url(url)
+ self.assertEqual(returned_Url, expected_Url)
+
+ def test_unparse_url(self):
+ for url, expected_Url in self.parse_url_host_map.items():
+ self.assertEqual(url, expected_Url.url)
+
+ def test_parse_url_invalid_IPv6(self):
+ self.assertRaises(ValueError, parse_url, '[::1')
+
+ def test_Url_str(self):
+ U = Url('http', host='google.com')
+ self.assertEqual(str(U), U.url)
+
+ def test_request_uri(self):
+ url_host_map = {
+ 'http://google.com/mail': '/mail',
+ 'http://google.com/mail/': '/mail/',
+ 'http://google.com/': '/',
+ 'http://google.com': '/',
+ '': '/',
+ '/': '/',
+ '?': '/?',
+ '#': '/',
+ '/foo?bar=baz': '/foo?bar=baz',
+ }
+ for url, expected_request_uri in url_host_map.items():
+ returned_url = parse_url(url)
+ self.assertEqual(returned_url.request_uri, expected_request_uri)
+
+ def test_netloc(self):
+ url_netloc_map = {
+ 'http://google.com/mail': 'google.com',
+ 'http://google.com:80/mail': 'google.com:80',
+ 'google.com/foobar': 'google.com',
+ 'google.com:12345': 'google.com:12345',
+ }
+
+ for url, expected_netloc in url_netloc_map.items():
+ self.assertEqual(parse_url(url).netloc, expected_netloc)
+
+ def test_make_headers(self):
+ self.assertEqual(
+ make_headers(accept_encoding=True),
+ {'accept-encoding': 'gzip,deflate'})
+
+ self.assertEqual(
+ make_headers(accept_encoding='foo,bar'),
+ {'accept-encoding': 'foo,bar'})
+
+ self.assertEqual(
+ make_headers(accept_encoding=['foo', 'bar']),
+ {'accept-encoding': 'foo,bar'})
+
+ self.assertEqual(
+ make_headers(accept_encoding=True, user_agent='banana'),
+ {'accept-encoding': 'gzip,deflate', 'user-agent': 'banana'})
+
+ self.assertEqual(
+ make_headers(user_agent='banana'),
+ {'user-agent': 'banana'})
+
+ self.assertEqual(
+ make_headers(keep_alive=True),
+ {'connection': 'keep-alive'})
+
+ self.assertEqual(
+ make_headers(basic_auth='foo:bar'),
+ {'authorization': 'Basic Zm9vOmJhcg=='})
+
+ self.assertEqual(
+ make_headers(proxy_basic_auth='foo:bar'),
+ {'proxy-authorization': 'Basic Zm9vOmJhcg=='})
+
+ self.assertEqual(
+ make_headers(disable_cache=True),
+ {'cache-control': 'no-cache'})
+
+ def test_split_first(self):
+ test_cases = {
+ ('abcd', 'b'): ('a', 'cd', 'b'),
+ ('abcd', 'cb'): ('a', 'cd', 'b'),
+ ('abcd', ''): ('abcd', '', None),
+ ('abcd', 'a'): ('', 'bcd', 'a'),
+ ('abcd', 'ab'): ('', 'bcd', 'a'),
+ }
+ for input, expected in test_cases.items():
+ output = split_first(*input)
+ self.assertEqual(output, expected)
+
+ def test_add_stderr_logger(self):
+ handler = add_stderr_logger(level=logging.INFO) # Don't actually print debug
+ logger = logging.getLogger('urllib3')
+ self.assertTrue(handler in logger.handlers)
+
+ logger.debug('Testing add_stderr_logger')
+ logger.removeHandler(handler)
+
+ def test_disable_warnings(self):
+ with warnings.catch_warnings(record=True) as w:
+ clear_warnings()
+ warnings.warn('This is a test.', InsecureRequestWarning)
+ self.assertEqual(len(w), 1)
+ disable_warnings()
+ warnings.warn('This is a test.', InsecureRequestWarning)
+ self.assertEqual(len(w), 1)
+
+ def _make_time_pass(self, seconds, timeout, time_mock):
+ """ Make some time pass for the timeout object """
+ time_mock.return_value = TIMEOUT_EPOCH
+ timeout.start_connect()
+ time_mock.return_value = TIMEOUT_EPOCH + seconds
+ return timeout
+
+ def test_invalid_timeouts(self):
+ try:
+ Timeout(total=-1)
+ self.fail("negative value should throw exception")
+ except ValueError as e:
+ self.assertTrue('less than' in str(e))
+ try:
+ Timeout(connect=2, total=-1)
+ self.fail("negative value should throw exception")
+ except ValueError as e:
+ self.assertTrue('less than' in str(e))
+
+ try:
+ Timeout(read=-1)
+ self.fail("negative value should throw exception")
+ except ValueError as e:
+ self.assertTrue('less than' in str(e))
+
+        # Booleans are also allowed by socket.settimeout and are converted
+        # to the equivalent float (1.0 for True, 0.0 for False).
+ Timeout(connect=False, read=True)
+
+ try:
+ Timeout(read="foo")
+ self.fail("string value should not be allowed")
+ except ValueError as e:
+ self.assertTrue('int or float' in str(e))
+
+
+ @patch('urllib3.util.timeout.current_time')
+ def test_timeout(self, current_time):
+ timeout = Timeout(total=3)
+
+ # make 'no time' elapse
+ timeout = self._make_time_pass(seconds=0, timeout=timeout,
+ time_mock=current_time)
+ self.assertEqual(timeout.read_timeout, 3)
+ self.assertEqual(timeout.connect_timeout, 3)
+
+ timeout = Timeout(total=3, connect=2)
+ self.assertEqual(timeout.connect_timeout, 2)
+
+ timeout = Timeout()
+ self.assertEqual(timeout.connect_timeout, Timeout.DEFAULT_TIMEOUT)
+
+ # Connect takes 5 seconds, leaving 5 seconds for read
+ timeout = Timeout(total=10, read=7)
+ timeout = self._make_time_pass(seconds=5, timeout=timeout,
+ time_mock=current_time)
+ self.assertEqual(timeout.read_timeout, 5)
+
+ # Connect takes 2 seconds, read timeout still 7 seconds
+ timeout = Timeout(total=10, read=7)
+ timeout = self._make_time_pass(seconds=2, timeout=timeout,
+ time_mock=current_time)
+ self.assertEqual(timeout.read_timeout, 7)
+
+ timeout = Timeout(total=10, read=7)
+ self.assertEqual(timeout.read_timeout, 7)
+
+ timeout = Timeout(total=None, read=None, connect=None)
+ self.assertEqual(timeout.connect_timeout, None)
+ self.assertEqual(timeout.read_timeout, None)
+ self.assertEqual(timeout.total, None)
+
+ timeout = Timeout(5)
+ self.assertEqual(timeout.total, 5)
+
+
+ def test_timeout_str(self):
+ timeout = Timeout(connect=1, read=2, total=3)
+ self.assertEqual(str(timeout), "Timeout(connect=1, read=2, total=3)")
+ timeout = Timeout(connect=1, read=None, total=3)
+ self.assertEqual(str(timeout), "Timeout(connect=1, read=None, total=3)")
+
+
+ @patch('urllib3.util.timeout.current_time')
+ def test_timeout_elapsed(self, current_time):
+ current_time.return_value = TIMEOUT_EPOCH
+ timeout = Timeout(total=3)
+ self.assertRaises(TimeoutStateError, timeout.get_connect_duration)
+
+ timeout.start_connect()
+ self.assertRaises(TimeoutStateError, timeout.start_connect)
+
+ current_time.return_value = TIMEOUT_EPOCH + 2
+ self.assertEqual(timeout.get_connect_duration(), 2)
+ current_time.return_value = TIMEOUT_EPOCH + 37
+ self.assertEqual(timeout.get_connect_duration(), 37)
+
+ def test_resolve_cert_reqs(self):
+ self.assertEqual(resolve_cert_reqs(None), ssl.CERT_NONE)
+ self.assertEqual(resolve_cert_reqs(ssl.CERT_NONE), ssl.CERT_NONE)
+
+ self.assertEqual(resolve_cert_reqs(ssl.CERT_REQUIRED), ssl.CERT_REQUIRED)
+ self.assertEqual(resolve_cert_reqs('REQUIRED'), ssl.CERT_REQUIRED)
+ self.assertEqual(resolve_cert_reqs('CERT_REQUIRED'), ssl.CERT_REQUIRED)
+
+ def test_is_fp_closed_object_supports_closed(self):
+ class ClosedFile(object):
+ @property
+ def closed(self):
+ return True
+
+ self.assertTrue(is_fp_closed(ClosedFile()))
+
+ def test_is_fp_closed_object_has_none_fp(self):
+ class NoneFpFile(object):
+ @property
+ def fp(self):
+ return None
+
+ self.assertTrue(is_fp_closed(NoneFpFile()))
+
+ def test_is_fp_closed_object_has_fp(self):
+ class FpFile(object):
+ @property
+ def fp(self):
+ return True
+
+ self.assertTrue(not is_fp_closed(FpFile()))
+
+ def test_is_fp_closed_object_has_neither_fp_nor_closed(self):
+ class NotReallyAFile(object):
+ pass
+
+ self.assertRaises(ValueError, is_fp_closed, NotReallyAFile())
+
+ def test_ssl_wrap_socket_loads_the_cert_chain(self):
+ socket = object()
+ mock_context = Mock()
+ ssl_wrap_socket(ssl_context=mock_context, sock=socket,
+ certfile='/path/to/certfile')
+
+ mock_context.load_cert_chain.assert_called_once_with(
+ '/path/to/certfile', None)
+
+ def test_ssl_wrap_socket_loads_verify_locations(self):
+ socket = object()
+ mock_context = Mock()
+ ssl_wrap_socket(ssl_context=mock_context, ca_certs='/path/to/pem',
+ sock=socket)
+ mock_context.load_verify_locations.assert_called_once_with(
+ '/path/to/pem')
+
+ def test_ssl_wrap_socket_with_no_sni(self):
+ socket = object()
+ mock_context = Mock()
+ # Ugly preservation of original value
+ HAS_SNI = ssl_.HAS_SNI
+ ssl_.HAS_SNI = False
+ ssl_wrap_socket(ssl_context=mock_context, sock=socket)
+ mock_context.wrap_socket.assert_called_once_with(socket)
+ ssl_.HAS_SNI = HAS_SNI
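
The final test patches the module-level ssl_.HAS_SNI flag and restores it by hand
("Ugly preservation of original value"). As a minimal sketch only, not part of the
patch above, the same check could restore the flag in a try/finally so it is reset
even when ssl_wrap_socket or the assertion raises:

    def test_ssl_wrap_socket_with_no_sni(self):
        socket = object()
        mock_context = Mock()
        # Preserve the module-level flag and force the no-SNI code path.
        HAS_SNI = ssl_.HAS_SNI
        ssl_.HAS_SNI = False
        try:
            ssl_wrap_socket(ssl_context=mock_context, sock=socket)
            mock_context.wrap_socket.assert_called_once_with(socket)
        finally:
            # Restore the original value even if the test body fails.
            ssl_.HAS_SNI = HAS_SNI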