aboutsummaryrefslogtreecommitdiff
path: root/tests/test_transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r--tests/test_transport.py723
1 files changed, 438 insertions, 285 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 5fcc786..4b52c4f 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -20,12 +20,21 @@
Some unit tests for the ssh2 protocol in Transport.
"""
-import sys, time, threading, unittest
+from binascii import hexlify, unhexlify
import select
+import socket
+import sys
+import time
+import threading
+import unittest
+import random
+
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
- SSHException, BadAuthenticationType, InteractiveQuery, util
+ SSHException, BadAuthenticationType, InteractiveQuery, ChannelException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
-from paramiko import OPEN_SUCCEEDED
+from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST
+from paramiko.message import Message
from loop import LoopSocket
@@ -37,50 +46,16 @@ class NullServer (ServerInterface):
def get_allowed_auths(self, username):
if username == 'slowdive':
return 'publickey,password'
- if username == 'paranoid':
- if not self.paranoid_did_password and not self.paranoid_did_public_key:
- return 'publickey,password'
- elif self.paranoid_did_password:
- return 'publickey'
- else:
- return 'password'
- if username == 'commie':
- return 'keyboard-interactive'
return 'publickey'
def check_auth_password(self, username, password):
if (username == 'slowdive') and (password == 'pygmalion'):
return AUTH_SUCCESSFUL
- if (username == 'paranoid') and (password == 'paranoid'):
- # 2-part auth (even openssh doesn't support this)
- self.paranoid_did_password = True
- if self.paranoid_did_public_key:
- return AUTH_SUCCESSFUL
- return AUTH_PARTIALLY_SUCCESSFUL
- return AUTH_FAILED
-
- def check_auth_publickey(self, username, key):
- if (username == 'paranoid') and (key == self.paranoid_key):
- # 2-part auth
- self.paranoid_did_public_key = True
- if self.paranoid_did_password:
- return AUTH_SUCCESSFUL
- return AUTH_PARTIALLY_SUCCESSFUL
- return AUTH_FAILED
-
- def check_auth_interactive(self, username, submethods):
- if username == 'commie':
- self.username = username
- return InteractiveQuery('password', 'Please enter a password.', ('Password', False))
- return AUTH_FAILED
-
- def check_auth_interactive_response(self, responses):
- if self.username == 'commie':
- if (len(responses) == 1) and (responses[0] == 'cat'):
- return AUTH_SUCCESSFUL
return AUTH_FAILED
def check_channel_request(self, kind, chanid):
+ if kind == 'bogus':
+ return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
return OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
@@ -94,10 +69,34 @@ class NullServer (ServerInterface):
def check_global_request(self, kind, msg):
self._global_request = kind
return False
+
+ def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number):
+ self._x11_single_connection = single_connection
+ self._x11_auth_protocol = auth_protocol
+ self._x11_auth_cookie = auth_cookie
+ self._x11_screen_number = screen_number
+ return True
+
+ def check_port_forward_request(self, addr, port):
+ self._listen = socket.socket()
+ self._listen.bind(('127.0.0.1', 0))
+ self._listen.listen(1)
+ return self._listen.getsockname()[1]
+
+ def cancel_port_forward_request(self, addr, port):
+ self._listen.close()
+ self._listen = None
+
+ def check_channel_direct_tcpip_request(self, chanid, origin, destination):
+ self._tcpip_dest = destination
+ return OPEN_SUCCEEDED
class TransportTest (unittest.TestCase):
+ assertTrue = unittest.TestCase.failUnless # for Python 2.3 and below
+ assertFalse = unittest.TestCase.failIf # for Python 2.3 and below
+
def setUp(self):
self.socks = LoopSocket()
self.sockc = LoopSocket()
@@ -111,6 +110,26 @@ class TransportTest (unittest.TestCase):
self.socks.close()
self.sockc.close()
+ def setup_test_server(self, client_options=None, server_options=None):
+ host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
+ public_host_key = RSAKey(data=str(host_key))
+ self.ts.add_server_key(host_key)
+
+ if client_options is not None:
+ client_options(self.tc.get_security_options())
+ if server_options is not None:
+ server_options(self.ts.get_security_options())
+
+ event = threading.Event()
+ self.server = NullServer()
+ self.assert_(not event.isSet())
+ self.ts.start_server(event, self.server)
+ self.tc.connect(hostkey=public_host_key,
+ username='slowdive', password='pygmalion')
+ event.wait(1.0)
+ self.assert_(event.isSet())
+ self.assert_(self.ts.is_active())
+
def test_1_security_options(self):
o = self.tc.get_security_options()
self.assertEquals(type(o), SecurityOptions)
@@ -130,11 +149,11 @@ class TransportTest (unittest.TestCase):
def test_2_compute_key(self):
self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929L
- self.tc.H = util.unhexify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
+ self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
self.tc.session_id = self.tc.H
key = self.tc._compute_key('C', 32)
self.assertEquals('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
- util.hexify(key))
+ hexlify(key).upper())
def test_3_simple(self):
"""
@@ -168,193 +187,45 @@ class TransportTest (unittest.TestCase):
verify that the client can demand odd handshake settings, and can
renegotiate keys in mid-stream.
"""
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- options = self.tc.get_security_options()
- options.ciphers = ('aes256-cbc',)
- options.digests = ('hmac-md5-96',)
- self.tc.connect(hostkey=public_host_key,
- username='slowdive', password='pygmalion')
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
+ def force_algorithms(options):
+ options.ciphers = ('aes256-cbc',)
+ options.digests = ('hmac-md5-96',)
+ self.setup_test_server(client_options=force_algorithms)
self.assertEquals('aes256-cbc', self.tc.local_cipher)
self.assertEquals('aes256-cbc', self.tc.remote_cipher)
self.assertEquals(12, self.tc.packetizer.get_mac_size_out())
self.assertEquals(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024)
- self.assert_(self.tc.renegotiate_keys())
+ self.tc.renegotiate_keys()
self.ts.send_ignore(1024)
def test_5_keepalive(self):
"""
verify that the keepalive will be sent.
"""
- self.tc.set_hexdump(True)
-
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- self.tc.connect(hostkey=public_host_key,
- username='slowdive', password='pygmalion')
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
-
- self.assertEquals(None, getattr(server, '_global_request', None))
+ self.setup_test_server()
+ self.assertEquals(None, getattr(self.server, '_global_request', None))
self.tc.set_keepalive(1)
time.sleep(2)
- self.assertEquals('keepalive@lag.net', server._global_request)
-
- def test_6_bad_auth_type(self):
- """
- verify that we get the right exception when an unsupported auth
- type is requested.
- """
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- try:
- self.tc.connect(hostkey=public_host_key,
- username='unknown', password='error')
- self.assert_(False)
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertEquals(BadAuthenticationType, etype)
- self.assertEquals(['publickey'], evalue.allowed_types)
-
- def test_7_bad_password(self):
- """
- verify that a bad password gets the right exception, and that a retry
- with the right password works.
- """
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- self.tc.ultra_debug = True
- self.tc.connect(hostkey=public_host_key)
- try:
- self.tc.auth_password(username='slowdive', password='error')
- self.assert_(False)
- except:
- etype, evalue, etb = sys.exc_info()
- self.assertEquals(SSHException, etype)
- self.tc.auth_password(username='slowdive', password='pygmalion')
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
-
- def test_8_multipart_auth(self):
- """
- verify that multipart auth works.
- """
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- self.tc.ultra_debug = True
- self.tc.connect(hostkey=public_host_key)
- remain = self.tc.auth_password(username='paranoid', password='paranoid')
- self.assertEquals(['publickey'], remain)
- key = DSSKey.from_private_key_file('tests/test_dss.key')
- remain = self.tc.auth_publickey(username='paranoid', key=key)
- self.assertEquals([], remain)
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
-
- def test_9_interactive_auth(self):
- """
- verify keyboard-interactive auth works.
- """
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- self.tc.ultra_debug = True
- self.tc.connect(hostkey=public_host_key)
-
- def handler(title, instructions, prompts):
- self.got_title = title
- self.got_instructions = instructions
- self.got_prompts = prompts
- return ['cat']
- remain = self.tc.auth_interactive('commie', handler)
- self.assertEquals(self.got_title, 'password')
- self.assertEquals(self.got_prompts, [('Password', False)])
- self.assertEquals([], remain)
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
+ self.assertEquals('keepalive@lag.net', self.server._global_request)
- def test_A_interactive_auth_fallback(self):
- """
- verify that a password auth attempt will fallback to "interactive"
- if password auth isn't supported but interactive is.
- """
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- self.tc.ultra_debug = True
- self.tc.connect(hostkey=public_host_key)
- remain = self.tc.auth_password('commie', 'cat')
- self.assertEquals([], remain)
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
-
- def test_B_exec_command(self):
+ def test_6_exec_command(self):
"""
verify that exec_command() does something reasonable.
"""
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- self.tc.ultra_debug = True
- self.tc.connect(hostkey=public_host_key)
- self.tc.auth_password(username='slowdive', password='pygmalion')
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
+ self.setup_test_server()
chan = self.tc.open_session()
schan = self.ts.accept(1.0)
- self.assert_(not chan.exec_command('no'))
+ try:
+ chan.exec_command('no')
+ self.assert_(False)
+ except SSHException, x:
+ pass
chan = self.tc.open_session()
- self.assert_(chan.exec_command('yes'))
+ chan.exec_command('yes')
schan = self.ts.accept(1.0)
schan.send('Hello there.\n')
schan.send_stderr('This is on stderr.\n')
@@ -369,7 +240,7 @@ class TransportTest (unittest.TestCase):
# now try it with combined stdout/stderr
chan = self.tc.open_session()
- self.assert_(chan.exec_command('yes'))
+ chan.exec_command('yes')
schan = self.ts.accept(1.0)
schan.send('Hello there.\n')
schan.send_stderr('This is on stderr.\n')
@@ -381,26 +252,13 @@ class TransportTest (unittest.TestCase):
self.assertEquals('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline())
- def test_C_invoke_shell(self):
+ def test_7_invoke_shell(self):
"""
verify that invoke_shell() does something reasonable.
"""
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- self.tc.ultra_debug = True
- self.tc.connect(hostkey=public_host_key)
- self.tc.auth_password(username='slowdive', password='pygmalion')
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
-
+ self.setup_test_server()
chan = self.tc.open_session()
- self.assert_(chan.invoke_shell())
+ chan.invoke_shell()
schan = self.ts.accept(1.0)
chan.send('communist j. cat\n')
f = schan.makefile()
@@ -408,28 +266,28 @@ class TransportTest (unittest.TestCase):
chan.close()
self.assertEquals('', f.readline())
- def test_D_exit_status(self):
+ def test_8_channel_exception(self):
+ """
+ verify that ChannelException is thrown for a bad open-channel request.
+ """
+ self.setup_test_server()
+ try:
+ chan = self.tc.open_channel('bogus')
+ self.fail('expected exception')
+ except ChannelException, x:
+ self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
+
+ def test_9_exit_status(self):
"""
verify that get_exit_status() works.
"""
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- self.tc.ultra_debug = True
- self.tc.connect(hostkey=public_host_key)
- self.tc.auth_password(username='slowdive', password='pygmalion')
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
+ self.setup_test_server()
chan = self.tc.open_session()
schan = self.ts.accept(1.0)
- self.assert_(chan.exec_command('yes'))
+ chan.exec_command('yes')
schan.send('Hello there.\n')
+ self.assert_(not chan.exit_status_ready())
# trigger an EOF
schan.shutdown_read()
schan.shutdown_write()
@@ -439,29 +297,22 @@ class TransportTest (unittest.TestCase):
f = chan.makefile()
self.assertEquals('Hello there.\n', f.readline())
self.assertEquals('', f.readline())
+ count = 0
+ while not chan.exit_status_ready():
+ time.sleep(0.1)
+ count += 1
+ if count > 50:
+ raise Exception("timeout")
self.assertEquals(23, chan.recv_exit_status())
chan.close()
- def test_E_select(self):
+ def test_A_select(self):
"""
verify that select() on a channel works.
"""
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.assert_(not event.isSet())
- self.ts.start_server(event, server)
- self.tc.ultra_debug = True
- self.tc.connect(hostkey=public_host_key)
- self.tc.auth_password(username='slowdive', password='pygmalion')
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
-
+ self.setup_test_server()
chan = self.tc.open_session()
- self.assert_(chan.invoke_shell())
+ chan.invoke_shell()
schan = self.ts.accept(1.0)
# nothing should be ready
@@ -503,28 +354,21 @@ class TransportTest (unittest.TestCase):
self.assertEquals([], e)
self.assertEquals('', chan.recv(16))
+ # make sure the pipe is still open for now...
+ p = chan._pipe
+ self.assertEquals(False, p._closed)
chan.close()
+ # ...and now is closed.
+ self.assertEquals(True, p._closed)
- def test_F_renegotiate(self):
+ def test_B_renegotiate(self):
"""
verify that a transport can correctly renegotiate mid-stream.
"""
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- event = threading.Event()
- server = NullServer()
- self.ts.start_server(event, server)
- self.tc.connect(hostkey=public_host_key,
- username='slowdive', password='pygmalion')
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
-
+ self.setup_test_server()
self.tc.packetizer.REKEY_BYTES = 16384
-
chan = self.tc.open_session()
- self.assert_(chan.exec_command('yes'))
+ chan.exec_command('yes')
schan = self.ts.accept(1.0)
self.assertEquals(self.tc.H, self.tc.session_id)
@@ -541,26 +385,15 @@ class TransportTest (unittest.TestCase):
schan.close()
- def test_G_compression(self):
+ def test_C_compression(self):
"""
verify that zlib compression is basically working.
"""
- host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
- public_host_key = RSAKey(data=str(host_key))
- self.ts.add_server_key(host_key)
- self.ts.get_security_options().compression = ('zlib',)
- self.tc.get_security_options().compression = ('zlib',)
- event = threading.Event()
- server = NullServer()
- self.ts.start_server(event, server)
- self.tc.connect(hostkey=public_host_key,
- username='slowdive', password='pygmalion')
- event.wait(1.0)
- self.assert_(event.isSet())
- self.assert_(self.ts.is_active())
-
+ def force_compression(o):
+ o.compression = ('zlib',)
+ self.setup_test_server(force_compression, force_compression)
chan = self.tc.open_session()
- self.assert_(chan.exec_command('yes'))
+ chan.exec_command('yes')
schan = self.ts.accept(1.0)
bytes = self.tc.packetizer._Packetizer__sent_bytes
@@ -568,6 +401,326 @@ class TransportTest (unittest.TestCase):
bytes2 = self.tc.packetizer._Packetizer__sent_bytes
# tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :)
self.assert_(bytes2 - bytes < 1024)
+ self.assertEquals(52, bytes2 - bytes)
chan.close()
schan.close()
+
+ def test_D_x11(self):
+ """
+ verify that an x11 port can be requested and opened.
+ """
+ self.setup_test_server()
+ chan = self.tc.open_session()
+ chan.exec_command('yes')
+ schan = self.ts.accept(1.0)
+
+ requested = []
+ def handler(c, (addr, port)):
+ requested.append((addr, port))
+ self.tc._queue_incoming_channel(c)
+
+ self.assertEquals(None, getattr(self.server, '_x11_screen_number', None))
+ cookie = chan.request_x11(0, single_connection=True, handler=handler)
+ self.assertEquals(0, self.server._x11_screen_number)
+ self.assertEquals('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
+ self.assertEquals(cookie, self.server._x11_auth_cookie)
+ self.assertEquals(True, self.server._x11_single_connection)
+
+ x11_server = self.ts.open_x11_channel(('localhost', 6093))
+ x11_client = self.tc.accept()
+ self.assertEquals('localhost', requested[0][0])
+ self.assertEquals(6093, requested[0][1])
+
+ x11_server.send('hello')
+ self.assertEquals('hello', x11_client.recv(5))
+
+ x11_server.close()
+ x11_client.close()
+ chan.close()
+ schan.close()
+
+ def test_E_reverse_port_forwarding(self):
+ """
+ verify that a client can ask the server to open a reverse port for
+ forwarding.
+ """
+ self.setup_test_server()
+ chan = self.tc.open_session()
+ chan.exec_command('yes')
+ schan = self.ts.accept(1.0)
+
+ requested = []
+ def handler(c, (origin_addr, origin_port), (server_addr, server_port)):
+ requested.append((origin_addr, origin_port))
+ requested.append((server_addr, server_port))
+ self.tc._queue_incoming_channel(c)
+
+ port = self.tc.request_port_forward('127.0.0.1', 0, handler)
+ self.assertEquals(port, self.server._listen.getsockname()[1])
+
+ cs = socket.socket()
+ cs.connect(('127.0.0.1', port))
+ ss, _ = self.server._listen.accept()
+ sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername())
+ cch = self.tc.accept()
+
+ sch.send('hello')
+ self.assertEquals('hello', cch.recv(5))
+ sch.close()
+ cch.close()
+ ss.close()
+ cs.close()
+
+ # now cancel it.
+ self.tc.cancel_port_forward('127.0.0.1', port)
+ self.assertTrue(self.server._listen is None)
+
+ def test_F_port_forwarding(self):
+ """
+ verify that a client can forward new connections from a locally-
+ forwarded port.
+ """
+ self.setup_test_server()
+ chan = self.tc.open_session()
+ chan.exec_command('yes')
+ schan = self.ts.accept(1.0)
+
+ # open a port on the "server" that the client will ask to forward to.
+ greeting_server = socket.socket()
+ greeting_server.bind(('127.0.0.1', 0))
+ greeting_server.listen(1)
+ greeting_port = greeting_server.getsockname()[1]
+
+ cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000))
+ sch = self.ts.accept(1.0)
+ cch = socket.socket()
+ cch.connect(self.server._tcpip_dest)
+
+ ss, _ = greeting_server.accept()
+ ss.send('Hello!\n')
+ ss.close()
+ sch.send(cch.recv(8192))
+ sch.close()
+
+ self.assertEquals('Hello!\n', cs.recv(7))
+ cs.close()
+
+ def test_G_stderr_select(self):
+ """
+ verify that select() on a channel works even if only stderr is
+ receiving data.
+ """
+ self.setup_test_server()
+ chan = self.tc.open_session()
+ chan.invoke_shell()
+ schan = self.ts.accept(1.0)
+
+ # nothing should be ready
+ r, w, e = select.select([chan], [], [], 0.1)
+ self.assertEquals([], r)
+ self.assertEquals([], w)
+ self.assertEquals([], e)
+
+ schan.send_stderr('hello\n')
+
+ # something should be ready now (give it 1 second to appear)
+ for i in range(10):
+ r, w, e = select.select([chan], [], [], 0.1)
+ if chan in r:
+ break
+ time.sleep(0.1)
+ self.assertEquals([chan], r)
+ self.assertEquals([], w)
+ self.assertEquals([], e)
+
+ self.assertEquals('hello\n', chan.recv_stderr(6))
+
+ # and, should be dead again now
+ r, w, e = select.select([chan], [], [], 0.1)
+ self.assertEquals([], r)
+ self.assertEquals([], w)
+ self.assertEquals([], e)
+
+ schan.close()
+ chan.close()
+
+ def test_H_send_ready(self):
+ """
+ verify that send_ready() indicates when a send would not block.
+ """
+ self.setup_test_server()
+ chan = self.tc.open_session()
+ chan.invoke_shell()
+ schan = self.ts.accept(1.0)
+
+ self.assertEquals(chan.send_ready(), True)
+ total = 0
+ K = '*' * 1024
+ while total < 1024 * 1024:
+ chan.send(K)
+ total += len(K)
+ if not chan.send_ready():
+ break
+ self.assert_(total < 1024 * 1024)
+
+ schan.close()
+ chan.close()
+ self.assertEquals(chan.send_ready(), True)
+
+ def test_I_rekey_deadlock(self):
+ """
+ Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent
+
+ Note: When this test fails, it may leak threads.
+ """
+
+ # Test for an obscure deadlocking bug that can occur if we receive
+ # certain messages while initiating a key exchange.
+ #
+ # The deadlock occurs as follows:
+ #
+ # In the main thread:
+ # 1. The user's program calls Channel.send(), which sends
+ # MSG_CHANNEL_DATA to the remote host.
+ # 2. Packetizer discovers that REKEY_BYTES has been exceeded, and
+ # sets the __need_rekey flag.
+ #
+ # In the Transport thread:
+ # 3. Packetizer notices that the __need_rekey flag is set, and raises
+ # NeedRekeyException.
+ # 4. In response to NeedRekeyException, the transport thread sends
+ # MSG_KEXINIT to the remote host.
+ #
+ # On the remote host (using any SSH implementation):
+ # 5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent.
+ # 6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent.
+ #
+ # In the main thread:
+ # 7. The user's program calls Channel.send().
+ # 8. Channel.send acquires Channel.lock, then calls Transport._send_user_message().
+ # 9. Transport._send_user_message waits for Transport.clear_to_send
+ # to be set (i.e., it waits for re-keying to complete).
+ # Channel.lock is still held.
+ #
+ # In the Transport thread:
+ # 10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust
+ # is called to handle it.
+ # 11. Channel._window_adjust tries to acquire Channel.lock, but it
+ # blocks because the lock is already held by the main thread.
+ #
+ # The result is that the Transport thread never processes the remote
+ # host's MSG_KEXINIT packet, because it becomes deadlocked while
+ # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message.
+
+ # We set up two separate threads for sending and receiving packets,
+ # while the main thread acts as a watchdog timer. If the timer
+ # expires, a deadlock is assumed.
+
+ class SendThread(threading.Thread):
+ def __init__(self, chan, iterations, done_event):
+ threading.Thread.__init__(self, None, None, self.__class__.__name__)
+ self.setDaemon(True)
+ self.chan = chan
+ self.iterations = iterations
+ self.done_event = done_event
+ self.watchdog_event = threading.Event()
+ self.last = None
+
+ def run(self):
+ try:
+ for i in xrange(1, 1+self.iterations):
+ if self.done_event.isSet():
+ break
+ self.watchdog_event.set()
+ #print i, "SEND"
+ self.chan.send("x" * 2048)
+ finally:
+ self.done_event.set()
+ self.watchdog_event.set()
+
+ class ReceiveThread(threading.Thread):
+ def __init__(self, chan, done_event):
+ threading.Thread.__init__(self, None, None, self.__class__.__name__)
+ self.setDaemon(True)
+ self.chan = chan
+ self.done_event = done_event
+ self.watchdog_event = threading.Event()
+
+ def run(self):
+ try:
+ while not self.done_event.isSet():
+ if self.chan.recv_ready():
+ chan.recv(65536)
+ self.watchdog_event.set()
+ else:
+ if random.randint(0, 1):
+ time.sleep(random.randint(0, 500) / 1000.0)
+ finally:
+ self.done_event.set()
+ self.watchdog_event.set()
+
+ self.setup_test_server()
+ self.ts.packetizer.REKEY_BYTES = 2048
+
+ chan = self.tc.open_session()
+ chan.exec_command('yes')
+ schan = self.ts.accept(1.0)
+
+ # Monkey patch the client's Transport._handler_table so that the client
+ # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial
+ # MSG_KEXINIT. This is used to simulate the effect of network latency
+ # on a real MSG_CHANNEL_WINDOW_ADJUST message.
+ self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary
+ _negotiate_keys = self.tc._handler_table[MSG_KEXINIT]
+ def _negotiate_keys_wrapper(self, m):
+ if self.local_kex_init is None: # Remote side sent KEXINIT
+ # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
+ # before responding to the incoming MSG_KEXINIT.
+ m2 = Message()
+ m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
+ m2.add_int(chan.remote_chanid)
+ m2.add_int(1) # bytes to add
+ self._send_message(m2)
+ return _negotiate_keys(self, m)
+ self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper
+
+ # Parameters for the test
+ iterations = 500 # The deadlock does not happen every time, but it
+ # should after many iterations.
+ timeout = 5
+
+ # This event is set when the test is completed
+ done_event = threading.Event()
+
+ # Start the sending thread
+ st = SendThread(schan, iterations, done_event)
+ st.start()
+
+ # Start the receiving thread
+ rt = ReceiveThread(chan, done_event)
+ rt.start()
+
+ # Act as a watchdog timer, checking
+ deadlocked = False
+ while not deadlocked and not done_event.isSet():
+ for event in (st.watchdog_event, rt.watchdog_event):
+ event.wait(timeout)
+ if done_event.isSet():
+ break
+ if not event.isSet():
+ deadlocked = True
+ break
+ event.clear()
+
+ # Tell the threads to stop (if they haven't already stopped). Note
+ # that if one or more threads are deadlocked, they might hang around
+ # forever (until the process exits).
+ done_event.set()
+
+ # Assertion: We must not have detected a timeout.
+ self.assertFalse(deadlocked)
+
+ # Close the channels
+ schan.close()
+ chan.close()