aboutsummaryrefslogtreecommitdiff
path: root/test/test_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_util.py')
-rw-r--r--test/test_util.py121
1 files changed, 83 insertions, 38 deletions
diff --git a/test/test_util.py b/test/test_util.py
index 1811dbd..c850d91 100644
--- a/test/test_util.py
+++ b/test/test_util.py
@@ -2,8 +2,9 @@ import warnings
import logging
import unittest
import ssl
+from itertools import chain
-from mock import patch
+from mock import patch, Mock
from urllib3 import add_stderr_logger, disable_warnings
from urllib3.util.request import make_headers
@@ -14,14 +15,15 @@ from urllib3.util.url import (
split_first,
Url,
)
-from urllib3.util.ssl_ import resolve_cert_reqs
+from urllib3.util.ssl_ import resolve_cert_reqs, ssl_wrap_socket
from urllib3.exceptions import (
LocationParseError,
TimeoutStateError,
InsecureRequestWarning,
+ SSLError,
)
-from urllib3.util import is_fp_closed
+from urllib3.util import is_fp_closed, ssl_
from . import clear_warnings
@@ -89,45 +91,61 @@ class TestUtil(unittest.TestCase):
self.assertRaises(LocationParseError, get_host, location)
- def test_parse_url(self):
- url_host_map = {
- 'http://google.com/mail': Url('http', host='google.com', path='/mail'),
- 'http://google.com/mail/': Url('http', host='google.com', path='/mail/'),
- 'google.com/mail': Url(host='google.com', path='/mail'),
- 'http://google.com/': Url('http', host='google.com', path='/'),
- 'http://google.com': Url('http', host='google.com'),
- 'http://google.com?foo': Url('http', host='google.com', path='', query='foo'),
-
- # Path/query/fragment
- '': Url(),
- '/': Url(path='/'),
- '?': Url(path='', query=''),
- '#': Url(path='', fragment=''),
- '#?/!google.com/?foo#bar': Url(path='', fragment='?/!google.com/?foo#bar'),
- '/foo': Url(path='/foo'),
- '/foo?bar=baz': Url(path='/foo', query='bar=baz'),
- '/foo?bar=baz#banana?apple/orange': Url(path='/foo', query='bar=baz', fragment='banana?apple/orange'),
-
- # Port
- 'http://google.com/': Url('http', host='google.com', path='/'),
- 'http://google.com:80/': Url('http', host='google.com', port=80, path='/'),
- 'http://google.com:/': Url('http', host='google.com', path='/'),
- 'http://google.com:80': Url('http', host='google.com', port=80),
- 'http://google.com:': Url('http', host='google.com'),
-
- # Auth
- 'http://foo:bar@localhost/': Url('http', auth='foo:bar', host='localhost', path='/'),
- 'http://foo@localhost/': Url('http', auth='foo', host='localhost', path='/'),
- 'http://foo:bar@baz@localhost/': Url('http', auth='foo:bar@baz', host='localhost', path='/'),
- 'http://@': Url('http', host=None, auth='')
+ parse_url_host_map = {
+ 'http://google.com/mail': Url('http', host='google.com', path='/mail'),
+ 'http://google.com/mail/': Url('http', host='google.com', path='/mail/'),
+ 'google.com/mail': Url(host='google.com', path='/mail'),
+ 'http://google.com/': Url('http', host='google.com', path='/'),
+ 'http://google.com': Url('http', host='google.com'),
+ 'http://google.com?foo': Url('http', host='google.com', path='', query='foo'),
+
+ # Path/query/fragment
+ '': Url(),
+ '/': Url(path='/'),
+ '#?/!google.com/?foo#bar': Url(path='', fragment='?/!google.com/?foo#bar'),
+ '/foo': Url(path='/foo'),
+ '/foo?bar=baz': Url(path='/foo', query='bar=baz'),
+ '/foo?bar=baz#banana?apple/orange': Url(path='/foo', query='bar=baz', fragment='banana?apple/orange'),
+
+ # Port
+ 'http://google.com/': Url('http', host='google.com', path='/'),
+ 'http://google.com:80/': Url('http', host='google.com', port=80, path='/'),
+ 'http://google.com:80': Url('http', host='google.com', port=80),
+
+ # Auth
+ 'http://foo:bar@localhost/': Url('http', auth='foo:bar', host='localhost', path='/'),
+ 'http://foo@localhost/': Url('http', auth='foo', host='localhost', path='/'),
+ 'http://foo:bar@baz@localhost/': Url('http', auth='foo:bar@baz', host='localhost', path='/'),
+ 'http://@': Url('http', host=None, auth='')
+ }
+
+ non_round_tripping_parse_url_host_map = {
+ # Path/query/fragment
+ '?': Url(path='', query=''),
+ '#': Url(path='', fragment=''),
+
+ # Empty Port
+ 'http://google.com:': Url('http', host='google.com'),
+ 'http://google.com:/': Url('http', host='google.com', path='/'),
+
}
- for url, expected_url in url_host_map.items():
- returned_url = parse_url(url)
- self.assertEqual(returned_url, expected_url)
+
+ def test_parse_url(self):
+ for url, expected_Url in chain(self.parse_url_host_map.items(), self.non_round_tripping_parse_url_host_map.items()):
+ returned_Url = parse_url(url)
+ self.assertEqual(returned_Url, expected_Url)
+
+ def test_unparse_url(self):
+ for url, expected_Url in self.parse_url_host_map.items():
+ self.assertEqual(url, expected_Url.url)
def test_parse_url_invalid_IPv6(self):
self.assertRaises(ValueError, parse_url, '[::1')
+ def test_Url_str(self):
+ U = Url('http', host='google.com')
+ self.assertEqual(str(U), U.url)
+
def test_request_uri(self):
url_host_map = {
'http://google.com/mail': '/mail',
@@ -333,7 +351,7 @@ class TestUtil(unittest.TestCase):
return True
self.assertTrue(is_fp_closed(ClosedFile()))
-
+
def test_is_fp_closed_object_has_none_fp(self):
class NoneFpFile(object):
@property
@@ -355,3 +373,30 @@ class TestUtil(unittest.TestCase):
pass
self.assertRaises(ValueError, is_fp_closed, NotReallyAFile())
+
+ def test_ssl_wrap_socket_loads_the_cert_chain(self):
+ socket = object()
+ mock_context = Mock()
+ ssl_wrap_socket(ssl_context=mock_context, sock=socket,
+ certfile='/path/to/certfile')
+
+ mock_context.load_cert_chain.assert_called_once_with(
+ '/path/to/certfile', None)
+
+ def test_ssl_wrap_socket_loads_verify_locations(self):
+ socket = object()
+ mock_context = Mock()
+ ssl_wrap_socket(ssl_context=mock_context, ca_certs='/path/to/pem',
+ sock=socket)
+ mock_context.load_verify_locations.assert_called_once_with(
+ '/path/to/pem')
+
+ def test_ssl_wrap_socket_with_no_sni(self):
+ socket = object()
+ mock_context = Mock()
+ # Ugly preservation of original value
+ HAS_SNI = ssl_.HAS_SNI
+ ssl_.HAS_SNI = False
+ ssl_wrap_socket(ssl_context=mock_context, sock=socket)
+ mock_context.wrap_socket.assert_called_once_with(socket)
+ ssl_.HAS_SNI = HAS_SNI