diff options
author | Jeremy T. Bouse <jbouse@debian.org> | 2009-11-27 16:20:12 -0500 |
---|---|---|
committer | Jeremy T. Bouse <jbouse@debian.org> | 2009-11-27 16:20:12 -0500 |
commit | ed280d5ac360e2af796e9bd973d7b4df89f0c449 (patch) | |
tree | ce892d6ce9dad8c0ecbc9cbe73f8095195bef0b4 /tests/test_transport.py | |
parent | 176c6caf4ea7918e1698438634b237fab8456471 (diff) | |
download | python-paramiko-ed280d5ac360e2af796e9bd973d7b4df89f0c449.tar python-paramiko-ed280d5ac360e2af796e9bd973d7b4df89f0c449.tar.gz |
Imported Upstream version 1.7.4upstream/1.7.4
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r-- | tests/test_transport.py | 723 |
1 files changed, 438 insertions, 285 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py index 5fcc786..4b52c4f 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -20,12 +20,21 @@ Some unit tests for the ssh2 protocol in Transport. """ -import sys, time, threading, unittest +from binascii import hexlify, unhexlify import select +import socket +import sys +import time +import threading +import unittest +import random + from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ - SSHException, BadAuthenticationType, InteractiveQuery, util + SSHException, BadAuthenticationType, InteractiveQuery, ChannelException from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL -from paramiko import OPEN_SUCCEEDED +from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED +from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST +from paramiko.message import Message from loop import LoopSocket @@ -37,50 +46,16 @@ class NullServer (ServerInterface): def get_allowed_auths(self, username): if username == 'slowdive': return 'publickey,password' - if username == 'paranoid': - if not self.paranoid_did_password and not self.paranoid_did_public_key: - return 'publickey,password' - elif self.paranoid_did_password: - return 'publickey' - else: - return 'password' - if username == 'commie': - return 'keyboard-interactive' return 'publickey' def check_auth_password(self, username, password): if (username == 'slowdive') and (password == 'pygmalion'): return AUTH_SUCCESSFUL - if (username == 'paranoid') and (password == 'paranoid'): - # 2-part auth (even openssh doesn't support this) - self.paranoid_did_password = True - if self.paranoid_did_public_key: - return AUTH_SUCCESSFUL - return AUTH_PARTIALLY_SUCCESSFUL - return AUTH_FAILED - - def check_auth_publickey(self, username, key): - if (username == 'paranoid') and (key == self.paranoid_key): - # 2-part auth - self.paranoid_did_public_key = True - if self.paranoid_did_password: - return AUTH_SUCCESSFUL - return AUTH_PARTIALLY_SUCCESSFUL - return AUTH_FAILED - - def check_auth_interactive(self, username, submethods): - if username == 'commie': - self.username = username - return InteractiveQuery('password', 'Please enter a password.', ('Password', False)) - return AUTH_FAILED - - def check_auth_interactive_response(self, responses): - if self.username == 'commie': - if (len(responses) == 1) and (responses[0] == 'cat'): - return AUTH_SUCCESSFUL return AUTH_FAILED def check_channel_request(self, kind, chanid): + if kind == 'bogus': + return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED return OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): @@ -94,10 +69,34 @@ class NullServer (ServerInterface): def check_global_request(self, kind, msg): self._global_request = kind return False + + def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): + self._x11_single_connection = single_connection + self._x11_auth_protocol = auth_protocol + self._x11_auth_cookie = auth_cookie + self._x11_screen_number = screen_number + return True + + def check_port_forward_request(self, addr, port): + self._listen = socket.socket() + self._listen.bind(('127.0.0.1', 0)) + self._listen.listen(1) + return self._listen.getsockname()[1] + + def cancel_port_forward_request(self, addr, port): + self._listen.close() + self._listen = None + + def check_channel_direct_tcpip_request(self, chanid, origin, destination): + self._tcpip_dest = destination + return OPEN_SUCCEEDED class TransportTest (unittest.TestCase): + assertTrue = unittest.TestCase.failUnless # for Python 2.3 and below + assertFalse = unittest.TestCase.failIf # for Python 2.3 and below + def setUp(self): self.socks = LoopSocket() self.sockc = LoopSocket() @@ -111,6 +110,26 @@ class TransportTest (unittest.TestCase): self.socks.close() self.sockc.close() + def setup_test_server(self, client_options=None, server_options=None): + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + + if client_options is not None: + client_options(self.tc.get_security_options()) + if server_options is not None: + server_options(self.ts.get_security_options()) + + event = threading.Event() + self.server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, self.server) + self.tc.connect(hostkey=public_host_key, + username='slowdive', password='pygmalion') + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + def test_1_security_options(self): o = self.tc.get_security_options() self.assertEquals(type(o), SecurityOptions) @@ -130,11 +149,11 @@ class TransportTest (unittest.TestCase): def test_2_compute_key(self): self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929L - self.tc.H = util.unhexify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3') + self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3') self.tc.session_id = self.tc.H key = self.tc._compute_key('C', 32) self.assertEquals('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995', - util.hexify(key)) + hexlify(key).upper()) def test_3_simple(self): """ @@ -168,193 +187,45 @@ class TransportTest (unittest.TestCase): verify that the client can demand odd handshake settings, and can renegotiate keys in mid-stream. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - options = self.tc.get_security_options() - options.ciphers = ('aes256-cbc',) - options.digests = ('hmac-md5-96',) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + def force_algorithms(options): + options.ciphers = ('aes256-cbc',) + options.digests = ('hmac-md5-96',) + self.setup_test_server(client_options=force_algorithms) self.assertEquals('aes256-cbc', self.tc.local_cipher) self.assertEquals('aes256-cbc', self.tc.remote_cipher) self.assertEquals(12, self.tc.packetizer.get_mac_size_out()) self.assertEquals(12, self.tc.packetizer.get_mac_size_in()) self.tc.send_ignore(1024) - self.assert_(self.tc.renegotiate_keys()) + self.tc.renegotiate_keys() self.ts.send_ignore(1024) def test_5_keepalive(self): """ verify that the keepalive will be sent. """ - self.tc.set_hexdump(True) - - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - self.assertEquals(None, getattr(server, '_global_request', None)) + self.setup_test_server() + self.assertEquals(None, getattr(self.server, '_global_request', None)) self.tc.set_keepalive(1) time.sleep(2) - self.assertEquals('keepalive@lag.net', server._global_request) - - def test_6_bad_auth_type(self): - """ - verify that we get the right exception when an unsupported auth - type is requested. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - try: - self.tc.connect(hostkey=public_host_key, - username='unknown', password='error') - self.assert_(False) - except: - etype, evalue, etb = sys.exc_info() - self.assertEquals(BadAuthenticationType, etype) - self.assertEquals(['publickey'], evalue.allowed_types) - - def test_7_bad_password(self): - """ - verify that a bad password gets the right exception, and that a retry - with the right password works. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - try: - self.tc.auth_password(username='slowdive', password='error') - self.assert_(False) - except: - etype, evalue, etb = sys.exc_info() - self.assertEquals(SSHException, etype) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_8_multipart_auth(self): - """ - verify that multipart auth works. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - remain = self.tc.auth_password(username='paranoid', password='paranoid') - self.assertEquals(['publickey'], remain) - key = DSSKey.from_private_key_file('tests/test_dss.key') - remain = self.tc.auth_publickey(username='paranoid', key=key) - self.assertEquals([], remain) - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_9_interactive_auth(self): - """ - verify keyboard-interactive auth works. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - - def handler(title, instructions, prompts): - self.got_title = title - self.got_instructions = instructions - self.got_prompts = prompts - return ['cat'] - remain = self.tc.auth_interactive('commie', handler) - self.assertEquals(self.got_title, 'password') - self.assertEquals(self.got_prompts, [('Password', False)]) - self.assertEquals([], remain) - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + self.assertEquals('keepalive@lag.net', self.server._global_request) - def test_A_interactive_auth_fallback(self): - """ - verify that a password auth attempt will fallback to "interactive" - if password auth isn't supported but interactive is. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - remain = self.tc.auth_password('commie', 'cat') - self.assertEquals([], remain) - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_B_exec_command(self): + def test_6_exec_command(self): """ verify that exec_command() does something reasonable. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) - self.assert_(not chan.exec_command('no')) + try: + chan.exec_command('no') + self.assert_(False) + except SSHException, x: + pass chan = self.tc.open_session() - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan = self.ts.accept(1.0) schan.send('Hello there.\n') schan.send_stderr('This is on stderr.\n') @@ -369,7 +240,7 @@ class TransportTest (unittest.TestCase): # now try it with combined stdout/stderr chan = self.tc.open_session() - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan = self.ts.accept(1.0) schan.send('Hello there.\n') schan.send_stderr('This is on stderr.\n') @@ -381,26 +252,13 @@ class TransportTest (unittest.TestCase): self.assertEquals('This is on stderr.\n', f.readline()) self.assertEquals('', f.readline()) - def test_C_invoke_shell(self): + def test_7_invoke_shell(self): """ verify that invoke_shell() does something reasonable. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() chan = self.tc.open_session() - self.assert_(chan.invoke_shell()) + chan.invoke_shell() schan = self.ts.accept(1.0) chan.send('communist j. cat\n') f = schan.makefile() @@ -408,28 +266,28 @@ class TransportTest (unittest.TestCase): chan.close() self.assertEquals('', f.readline()) - def test_D_exit_status(self): + def test_8_channel_exception(self): + """ + verify that ChannelException is thrown for a bad open-channel request. + """ + self.setup_test_server() + try: + chan = self.tc.open_channel('bogus') + self.fail('expected exception') + except ChannelException, x: + self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) + + def test_9_exit_status(self): """ verify that get_exit_status() works. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan.send('Hello there.\n') + self.assert_(not chan.exit_status_ready()) # trigger an EOF schan.shutdown_read() schan.shutdown_write() @@ -439,29 +297,22 @@ class TransportTest (unittest.TestCase): f = chan.makefile() self.assertEquals('Hello there.\n', f.readline()) self.assertEquals('', f.readline()) + count = 0 + while not chan.exit_status_ready(): + time.sleep(0.1) + count += 1 + if count > 50: + raise Exception("timeout") self.assertEquals(23, chan.recv_exit_status()) chan.close() - def test_E_select(self): + def test_A_select(self): """ verify that select() on a channel works. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() chan = self.tc.open_session() - self.assert_(chan.invoke_shell()) + chan.invoke_shell() schan = self.ts.accept(1.0) # nothing should be ready @@ -503,28 +354,21 @@ class TransportTest (unittest.TestCase): self.assertEquals([], e) self.assertEquals('', chan.recv(16)) + # make sure the pipe is still open for now... + p = chan._pipe + self.assertEquals(False, p._closed) chan.close() + # ...and now is closed. + self.assertEquals(True, p._closed) - def test_F_renegotiate(self): + def test_B_renegotiate(self): """ verify that a transport can correctly renegotiate mid-stream. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() self.tc.packetizer.REKEY_BYTES = 16384 - chan = self.tc.open_session() - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan = self.ts.accept(1.0) self.assertEquals(self.tc.H, self.tc.session_id) @@ -541,26 +385,15 @@ class TransportTest (unittest.TestCase): schan.close() - def test_G_compression(self): + def test_C_compression(self): """ verify that zlib compression is basically working. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - self.ts.get_security_options().compression = ('zlib',) - self.tc.get_security_options().compression = ('zlib',) - event = threading.Event() - server = NullServer() - self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + def force_compression(o): + o.compression = ('zlib',) + self.setup_test_server(force_compression, force_compression) chan = self.tc.open_session() - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan = self.ts.accept(1.0) bytes = self.tc.packetizer._Packetizer__sent_bytes @@ -568,6 +401,326 @@ class TransportTest (unittest.TestCase): bytes2 = self.tc.packetizer._Packetizer__sent_bytes # tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :) self.assert_(bytes2 - bytes < 1024) + self.assertEquals(52, bytes2 - bytes) chan.close() schan.close() + + def test_D_x11(self): + """ + verify that an x11 port can be requested and opened. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + requested = [] + def handler(c, (addr, port)): + requested.append((addr, port)) + self.tc._queue_incoming_channel(c) + + self.assertEquals(None, getattr(self.server, '_x11_screen_number', None)) + cookie = chan.request_x11(0, single_connection=True, handler=handler) + self.assertEquals(0, self.server._x11_screen_number) + self.assertEquals('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol) + self.assertEquals(cookie, self.server._x11_auth_cookie) + self.assertEquals(True, self.server._x11_single_connection) + + x11_server = self.ts.open_x11_channel(('localhost', 6093)) + x11_client = self.tc.accept() + self.assertEquals('localhost', requested[0][0]) + self.assertEquals(6093, requested[0][1]) + + x11_server.send('hello') + self.assertEquals('hello', x11_client.recv(5)) + + x11_server.close() + x11_client.close() + chan.close() + schan.close() + + def test_E_reverse_port_forwarding(self): + """ + verify that a client can ask the server to open a reverse port for + forwarding. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + requested = [] + def handler(c, (origin_addr, origin_port), (server_addr, server_port)): + requested.append((origin_addr, origin_port)) + requested.append((server_addr, server_port)) + self.tc._queue_incoming_channel(c) + + port = self.tc.request_port_forward('127.0.0.1', 0, handler) + self.assertEquals(port, self.server._listen.getsockname()[1]) + + cs = socket.socket() + cs.connect(('127.0.0.1', port)) + ss, _ = self.server._listen.accept() + sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername()) + cch = self.tc.accept() + + sch.send('hello') + self.assertEquals('hello', cch.recv(5)) + sch.close() + cch.close() + ss.close() + cs.close() + + # now cancel it. + self.tc.cancel_port_forward('127.0.0.1', port) + self.assertTrue(self.server._listen is None) + + def test_F_port_forwarding(self): + """ + verify that a client can forward new connections from a locally- + forwarded port. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + # open a port on the "server" that the client will ask to forward to. + greeting_server = socket.socket() + greeting_server.bind(('127.0.0.1', 0)) + greeting_server.listen(1) + greeting_port = greeting_server.getsockname()[1] + + cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000)) + sch = self.ts.accept(1.0) + cch = socket.socket() + cch.connect(self.server._tcpip_dest) + + ss, _ = greeting_server.accept() + ss.send('Hello!\n') + ss.close() + sch.send(cch.recv(8192)) + sch.close() + + self.assertEquals('Hello!\n', cs.recv(7)) + cs.close() + + def test_G_stderr_select(self): + """ + verify that select() on a channel works even if only stderr is + receiving data. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.invoke_shell() + schan = self.ts.accept(1.0) + + # nothing should be ready + r, w, e = select.select([chan], [], [], 0.1) + self.assertEquals([], r) + self.assertEquals([], w) + self.assertEquals([], e) + + schan.send_stderr('hello\n') + + # something should be ready now (give it 1 second to appear) + for i in range(10): + r, w, e = select.select([chan], [], [], 0.1) + if chan in r: + break + time.sleep(0.1) + self.assertEquals([chan], r) + self.assertEquals([], w) + self.assertEquals([], e) + + self.assertEquals('hello\n', chan.recv_stderr(6)) + + # and, should be dead again now + r, w, e = select.select([chan], [], [], 0.1) + self.assertEquals([], r) + self.assertEquals([], w) + self.assertEquals([], e) + + schan.close() + chan.close() + + def test_H_send_ready(self): + """ + verify that send_ready() indicates when a send would not block. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.invoke_shell() + schan = self.ts.accept(1.0) + + self.assertEquals(chan.send_ready(), True) + total = 0 + K = '*' * 1024 + while total < 1024 * 1024: + chan.send(K) + total += len(K) + if not chan.send_ready(): + break + self.assert_(total < 1024 * 1024) + + schan.close() + chan.close() + self.assertEquals(chan.send_ready(), True) + + def test_I_rekey_deadlock(self): + """ + Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent + + Note: When this test fails, it may leak threads. + """ + + # Test for an obscure deadlocking bug that can occur if we receive + # certain messages while initiating a key exchange. + # + # The deadlock occurs as follows: + # + # In the main thread: + # 1. The user's program calls Channel.send(), which sends + # MSG_CHANNEL_DATA to the remote host. + # 2. Packetizer discovers that REKEY_BYTES has been exceeded, and + # sets the __need_rekey flag. + # + # In the Transport thread: + # 3. Packetizer notices that the __need_rekey flag is set, and raises + # NeedRekeyException. + # 4. In response to NeedRekeyException, the transport thread sends + # MSG_KEXINIT to the remote host. + # + # On the remote host (using any SSH implementation): + # 5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent. + # 6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent. + # + # In the main thread: + # 7. The user's program calls Channel.send(). + # 8. Channel.send acquires Channel.lock, then calls Transport._send_user_message(). + # 9. Transport._send_user_message waits for Transport.clear_to_send + # to be set (i.e., it waits for re-keying to complete). + # Channel.lock is still held. + # + # In the Transport thread: + # 10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust + # is called to handle it. + # 11. Channel._window_adjust tries to acquire Channel.lock, but it + # blocks because the lock is already held by the main thread. + # + # The result is that the Transport thread never processes the remote + # host's MSG_KEXINIT packet, because it becomes deadlocked while + # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message. + + # We set up two separate threads for sending and receiving packets, + # while the main thread acts as a watchdog timer. If the timer + # expires, a deadlock is assumed. + + class SendThread(threading.Thread): + def __init__(self, chan, iterations, done_event): + threading.Thread.__init__(self, None, None, self.__class__.__name__) + self.setDaemon(True) + self.chan = chan + self.iterations = iterations + self.done_event = done_event + self.watchdog_event = threading.Event() + self.last = None + + def run(self): + try: + for i in xrange(1, 1+self.iterations): + if self.done_event.isSet(): + break + self.watchdog_event.set() + #print i, "SEND" + self.chan.send("x" * 2048) + finally: + self.done_event.set() + self.watchdog_event.set() + + class ReceiveThread(threading.Thread): + def __init__(self, chan, done_event): + threading.Thread.__init__(self, None, None, self.__class__.__name__) + self.setDaemon(True) + self.chan = chan + self.done_event = done_event + self.watchdog_event = threading.Event() + + def run(self): + try: + while not self.done_event.isSet(): + if self.chan.recv_ready(): + chan.recv(65536) + self.watchdog_event.set() + else: + if random.randint(0, 1): + time.sleep(random.randint(0, 500) / 1000.0) + finally: + self.done_event.set() + self.watchdog_event.set() + + self.setup_test_server() + self.ts.packetizer.REKEY_BYTES = 2048 + + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + # Monkey patch the client's Transport._handler_table so that the client + # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial + # MSG_KEXINIT. This is used to simulate the effect of network latency + # on a real MSG_CHANNEL_WINDOW_ADJUST message. + self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary + _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] + def _negotiate_keys_wrapper(self, m): + if self.local_kex_init is None: # Remote side sent KEXINIT + # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it + # before responding to the incoming MSG_KEXINIT. + m2 = Message() + m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) + m2.add_int(chan.remote_chanid) + m2.add_int(1) # bytes to add + self._send_message(m2) + return _negotiate_keys(self, m) + self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper + + # Parameters for the test + iterations = 500 # The deadlock does not happen every time, but it + # should after many iterations. + timeout = 5 + + # This event is set when the test is completed + done_event = threading.Event() + + # Start the sending thread + st = SendThread(schan, iterations, done_event) + st.start() + + # Start the receiving thread + rt = ReceiveThread(chan, done_event) + rt.start() + + # Act as a watchdog timer, checking + deadlocked = False + while not deadlocked and not done_event.isSet(): + for event in (st.watchdog_event, rt.watchdog_event): + event.wait(timeout) + if done_event.isSet(): + break + if not event.isSet(): + deadlocked = True + break + event.clear() + + # Tell the threads to stop (if they haven't already stopped). Note + # that if one or more threads are deadlocked, they might hang around + # forever (until the process exits). + done_event.set() + + # Assertion: We must not have detected a timeout. + self.assertFalse(deadlocked) + + # Close the channels + schan.close() + chan.close() |