From ed280d5ac360e2af796e9bd973d7b4df89f0c449 Mon Sep 17 00:00:00 2001 From: "Jeremy T. Bouse" Date: Fri, 27 Nov 2009 16:20:12 -0500 Subject: Imported Upstream version 1.7.4 --- tests/loop.py | 4 +- tests/stub_sftp.py | 32 +- tests/test_auth.py | 231 ++++++++++++++ tests/test_buffered_pipe.py | 95 ++++++ tests/test_client.py | 214 +++++++++++++ tests/test_file.py | 4 +- tests/test_hostkeys.py | 117 +++++++ tests/test_kex.py | 101 +++++-- tests/test_message.py | 4 +- tests/test_packetizer.py | 2 +- tests/test_pkey.py | 61 +++- tests/test_sftp.py | 339 +++++++++------------ tests/test_sftp_big.py | 385 +++++++++++++++++++++++ tests/test_transport.py | 723 +++++++++++++++++++++++++++----------------- tests/test_util.py | 96 +++++- 15 files changed, 1866 insertions(+), 542 deletions(-) create mode 100644 tests/test_auth.py create mode 100644 tests/test_buffered_pipe.py create mode 100644 tests/test_client.py mode change 100644 => 100755 tests/test_file.py create mode 100644 tests/test_hostkeys.py mode change 100644 => 100755 tests/test_sftp.py create mode 100644 tests/test_sftp_big.py (limited to 'tests') diff --git a/tests/loop.py b/tests/loop.py index ad5f7ca..fb6ffae 100644 --- a/tests/loop.py +++ b/tests/loop.py @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # diff --git a/tests/stub_sftp.py b/tests/stub_sftp.py index 4b8b9c3..ac292ff 100644 --- a/tests/stub_sftp.py +++ b/tests/stub_sftp.py @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # @@ -48,6 +46,7 @@ class StubSFTPHandle (SFTPHandle): # use the stored filename try: SFTPServer.set_file_attr(self.filename, attr) + return SFTP_OK except OSError, e: return SFTPServer.convert_errno(e.errno) @@ -90,23 +89,38 @@ class StubSFTPServer (SFTPServerInterface): def open(self, path, flags, attr): path = self._realpath(path) try: - fd = os.open(path, flags) + binary_flag = getattr(os, 'O_BINARY', 0) + flags |= binary_flag + mode = getattr(attr, 'st_mode', None) + if mode is not None: + fd = os.open(path, flags, mode) + else: + # os.open() defaults to 0777 which is + # an odd default mode for files + fd = os.open(path, flags, 0666) except OSError, e: return SFTPServer.convert_errno(e.errno) if (flags & os.O_CREAT) and (attr is not None): + attr._flags &= ~attr.FLAG_PERMISSIONS SFTPServer.set_file_attr(path, attr) if flags & os.O_WRONLY: - fstr = 'w' + if flags & os.O_APPEND: + fstr = 'ab' + else: + fstr = 'wb' elif flags & os.O_RDWR: - fstr = 'r+' + if flags & os.O_APPEND: + fstr = 'a+b' + else: + fstr = 'r+b' else: # O_RDONLY (== 0) - fstr = 'r' + fstr = 'rb' try: f = os.fdopen(fd, fstr) except OSError, e: return SFTPServer.convert_errno(e.errno) - fobj = StubSFTPHandle() + fobj = StubSFTPHandle(flags) fobj.filename = path fobj.readfile = f fobj.writefile = f @@ -171,7 +185,7 @@ class StubSFTPServer (SFTPServerInterface): target_path = '' try: os.symlink(target_path, path) - except: + except OSError, e: return SFTPServer.convert_errno(e.errno) return SFTP_OK diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..fadd8ca --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,231 @@ +# Copyright (C) 2008 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Some unit tests for authenticating over a Transport. +""" + +import sys +import threading +import unittest + +from paramiko import Transport, ServerInterface, RSAKey, DSSKey, \ + SSHException, BadAuthenticationType, InteractiveQuery, ChannelException, \ + AuthenticationException +from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL +from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED +from loop import LoopSocket + + +class NullServer (ServerInterface): + paranoid_did_password = False + paranoid_did_public_key = False + paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key') + + def get_allowed_auths(self, username): + if username == 'slowdive': + return 'publickey,password' + if username == 'paranoid': + if not self.paranoid_did_password and not self.paranoid_did_public_key: + return 'publickey,password' + elif self.paranoid_did_password: + return 'publickey' + else: + return 'password' + if username == 'commie': + return 'keyboard-interactive' + if username == 'utf8': + return 'password' + if username == 'non-utf8': + return 'password' + return 'publickey' + + def check_auth_password(self, username, password): + if (username == 'slowdive') and (password == 'pygmalion'): + return AUTH_SUCCESSFUL + if (username == 'paranoid') and (password == 'paranoid'): + # 2-part auth (even openssh doesn't support this) + self.paranoid_did_password = True + if self.paranoid_did_public_key: + return AUTH_SUCCESSFUL + return AUTH_PARTIALLY_SUCCESSFUL + if (username == 'utf8') and (password == u'\u2022'): + return AUTH_SUCCESSFUL + if (username == 'non-utf8') and (password == '\xff'): + return AUTH_SUCCESSFUL + if username == 'bad-server': + raise Exception("Ack!") + return AUTH_FAILED + + def check_auth_publickey(self, username, key): + if (username == 'paranoid') and (key == self.paranoid_key): + # 2-part auth + self.paranoid_did_public_key = True + if self.paranoid_did_password: + return AUTH_SUCCESSFUL + return AUTH_PARTIALLY_SUCCESSFUL + return AUTH_FAILED + + def check_auth_interactive(self, username, submethods): + if username == 'commie': + self.username = username + return InteractiveQuery('password', 'Please enter a password.', ('Password', False)) + return AUTH_FAILED + + def check_auth_interactive_response(self, responses): + if self.username == 'commie': + if (len(responses) == 1) and (responses[0] == 'cat'): + return AUTH_SUCCESSFUL + return AUTH_FAILED + + +class AuthTest (unittest.TestCase): + + def setUp(self): + self.socks = LoopSocket() + self.sockc = LoopSocket() + self.sockc.link(self.socks) + self.tc = Transport(self.sockc) + self.ts = Transport(self.socks) + + def tearDown(self): + self.tc.close() + self.ts.close() + self.socks.close() + self.sockc.close() + + def start_server(self): + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + self.public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + self.event = threading.Event() + self.server = NullServer() + self.assert_(not self.event.isSet()) + self.ts.start_server(self.event, self.server) + + def verify_finished(self): + self.event.wait(1.0) + self.assert_(self.event.isSet()) + self.assert_(self.ts.is_active()) + + def test_1_bad_auth_type(self): + """ + verify that we get the right exception when an unsupported auth + type is requested. + """ + self.start_server() + try: + self.tc.connect(hostkey=self.public_host_key, + username='unknown', password='error') + self.assert_(False) + except: + etype, evalue, etb = sys.exc_info() + self.assertEquals(BadAuthenticationType, etype) + self.assertEquals(['publickey'], evalue.allowed_types) + + def test_2_bad_password(self): + """ + verify that a bad password gets the right exception, and that a retry + with the right password works. + """ + self.start_server() + self.tc.connect(hostkey=self.public_host_key) + try: + self.tc.auth_password(username='slowdive', password='error') + self.assert_(False) + except: + etype, evalue, etb = sys.exc_info() + self.assert_(issubclass(etype, AuthenticationException)) + self.tc.auth_password(username='slowdive', password='pygmalion') + self.verify_finished() + + def test_3_multipart_auth(self): + """ + verify that multipart auth works. + """ + self.start_server() + self.tc.connect(hostkey=self.public_host_key) + remain = self.tc.auth_password(username='paranoid', password='paranoid') + self.assertEquals(['publickey'], remain) + key = DSSKey.from_private_key_file('tests/test_dss.key') + remain = self.tc.auth_publickey(username='paranoid', key=key) + self.assertEquals([], remain) + self.verify_finished() + + def test_4_interactive_auth(self): + """ + verify keyboard-interactive auth works. + """ + self.start_server() + self.tc.connect(hostkey=self.public_host_key) + + def handler(title, instructions, prompts): + self.got_title = title + self.got_instructions = instructions + self.got_prompts = prompts + return ['cat'] + remain = self.tc.auth_interactive('commie', handler) + self.assertEquals(self.got_title, 'password') + self.assertEquals(self.got_prompts, [('Password', False)]) + self.assertEquals([], remain) + self.verify_finished() + + def test_5_interactive_auth_fallback(self): + """ + verify that a password auth attempt will fallback to "interactive" + if password auth isn't supported but interactive is. + """ + self.start_server() + self.tc.connect(hostkey=self.public_host_key) + remain = self.tc.auth_password('commie', 'cat') + self.assertEquals([], remain) + self.verify_finished() + + def test_6_auth_utf8(self): + """ + verify that utf-8 encoding happens in authentication. + """ + self.start_server() + self.tc.connect(hostkey=self.public_host_key) + remain = self.tc.auth_password('utf8', u'\u2022') + self.assertEquals([], remain) + self.verify_finished() + + def test_7_auth_non_utf8(self): + """ + verify that non-utf-8 encoded passwords can be used for broken + servers. + """ + self.start_server() + self.tc.connect(hostkey=self.public_host_key) + remain = self.tc.auth_password('non-utf8', '\xff') + self.assertEquals([], remain) + self.verify_finished() + + def test_8_auth_gets_disconnected(self): + """ + verify that we catch a server disconnecting during auth, and report + it as an auth failure. + """ + self.start_server() + self.tc.connect(hostkey=self.public_host_key) + try: + remain = self.tc.auth_password('bad-server', 'hello') + except: + etype, evalue, etb = sys.exc_info() + self.assert_(issubclass(etype, AuthenticationException)) diff --git a/tests/test_buffered_pipe.py b/tests/test_buffered_pipe.py new file mode 100644 index 0000000..f96edb8 --- /dev/null +++ b/tests/test_buffered_pipe.py @@ -0,0 +1,95 @@ +# Copyright (C) 2006-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Some unit tests for BufferedPipe. +""" + +import threading +import time +import unittest +from paramiko.buffered_pipe import BufferedPipe, PipeTimeout +from paramiko import pipe + + +def delay_thread(pipe): + pipe.feed('a') + time.sleep(0.5) + pipe.feed('b') + pipe.close() + + +def close_thread(pipe): + time.sleep(0.2) + pipe.close() + + +class BufferedPipeTest (unittest.TestCase): + + assertTrue = unittest.TestCase.failUnless # for Python 2.3 and below + assertFalse = unittest.TestCase.failIf # for Python 2.3 and below + + def test_1_buffered_pipe(self): + p = BufferedPipe() + self.assert_(not p.read_ready()) + p.feed('hello.') + self.assert_(p.read_ready()) + data = p.read(6) + self.assertEquals('hello.', data) + + p.feed('plus/minus') + self.assertEquals('plu', p.read(3)) + self.assertEquals('s/m', p.read(3)) + self.assertEquals('inus', p.read(4)) + + p.close() + self.assert_(not p.read_ready()) + self.assertEquals('', p.read(1)) + + def test_2_delay(self): + p = BufferedPipe() + self.assert_(not p.read_ready()) + threading.Thread(target=delay_thread, args=(p,)).start() + self.assertEquals('a', p.read(1, 0.1)) + try: + p.read(1, 0.1) + self.assert_(False) + except PipeTimeout: + pass + self.assertEquals('b', p.read(1, 1.0)) + self.assertEquals('', p.read(1)) + + def test_3_close_while_reading(self): + p = BufferedPipe() + threading.Thread(target=close_thread, args=(p,)).start() + data = p.read(1, 1.0) + self.assertEquals('', data) + + def test_4_or_pipe(self): + p = pipe.make_pipe() + p1, p2 = pipe.make_or_pipe(p) + self.assertFalse(p._set) + p1.set() + self.assertTrue(p._set) + p2.set() + self.assertTrue(p._set) + p1.clear() + self.assertTrue(p._set) + p2.clear() + self.assertFalse(p._set) + diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..59cd67c --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,214 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Some unit tests for SSHClient. +""" + +import socket +import threading +import time +import unittest +import weakref +from binascii import hexlify + +import paramiko + + +class NullServer (paramiko.ServerInterface): + + def get_allowed_auths(self, username): + if username == 'slowdive': + return 'publickey,password' + return 'publickey' + + def check_auth_password(self, username, password): + if (username == 'slowdive') and (password == 'pygmalion'): + return paramiko.AUTH_SUCCESSFUL + return paramiko.AUTH_FAILED + + def check_auth_publickey(self, username, key): + if (key.get_name() == 'ssh-dss') and (hexlify(key.get_fingerprint()) == '4478f0b9a23cc5182009ff755bc1d26c'): + return paramiko.AUTH_SUCCESSFUL + return paramiko.AUTH_FAILED + + def check_channel_request(self, kind, chanid): + return paramiko.OPEN_SUCCEEDED + + def check_channel_exec_request(self, channel, command): + if command != 'yes': + return False + return True + + +class SSHClientTest (unittest.TestCase): + + def setUp(self): + self.sockl = socket.socket() + self.sockl.bind(('localhost', 0)) + self.sockl.listen(1) + self.addr, self.port = self.sockl.getsockname() + self.event = threading.Event() + thread = threading.Thread(target=self._run) + thread.start() + + def tearDown(self): + if hasattr(self, 'tc'): + self.tc.close() + self.ts.close() + self.socks.close() + self.sockl.close() + + def _run(self): + self.socks, addr = self.sockl.accept() + self.ts = paramiko.Transport(self.socks) + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + self.ts.add_server_key(host_key) + server = NullServer() + self.ts.start_server(self.event, server) + + + def test_1_client(self): + """ + verify that the SSHClient stuff works too. + """ + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = paramiko.RSAKey(data=str(host_key)) + + self.tc = paramiko.SSHClient() + self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key) + self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') + + self.event.wait(1.0) + self.assert_(self.event.isSet()) + self.assert_(self.ts.is_active()) + self.assertEquals('slowdive', self.ts.get_username()) + self.assertEquals(True, self.ts.is_authenticated()) + + stdin, stdout, stderr = self.tc.exec_command('yes') + schan = self.ts.accept(1.0) + + schan.send('Hello there.\n') + schan.send_stderr('This is on stderr.\n') + schan.close() + + self.assertEquals('Hello there.\n', stdout.readline()) + self.assertEquals('', stdout.readline()) + self.assertEquals('This is on stderr.\n', stderr.readline()) + self.assertEquals('', stderr.readline()) + + stdin.close() + stdout.close() + stderr.close() + + def test_2_client_dsa(self): + """ + verify that SSHClient works with a DSA key. + """ + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = paramiko.RSAKey(data=str(host_key)) + + self.tc = paramiko.SSHClient() + self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key) + self.tc.connect(self.addr, self.port, username='slowdive', key_filename='tests/test_dss.key') + + self.event.wait(1.0) + self.assert_(self.event.isSet()) + self.assert_(self.ts.is_active()) + self.assertEquals('slowdive', self.ts.get_username()) + self.assertEquals(True, self.ts.is_authenticated()) + + stdin, stdout, stderr = self.tc.exec_command('yes') + schan = self.ts.accept(1.0) + + schan.send('Hello there.\n') + schan.send_stderr('This is on stderr.\n') + schan.close() + + self.assertEquals('Hello there.\n', stdout.readline()) + self.assertEquals('', stdout.readline()) + self.assertEquals('This is on stderr.\n', stderr.readline()) + self.assertEquals('', stderr.readline()) + + stdin.close() + stdout.close() + stderr.close() + + def test_3_multiple_key_files(self): + """ + verify that SSHClient accepts and tries multiple key files. + """ + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = paramiko.RSAKey(data=str(host_key)) + + self.tc = paramiko.SSHClient() + self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key) + self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ]) + + self.event.wait(1.0) + self.assert_(self.event.isSet()) + self.assert_(self.ts.is_active()) + self.assertEquals('slowdive', self.ts.get_username()) + self.assertEquals(True, self.ts.is_authenticated()) + + def test_4_auto_add_policy(self): + """ + verify that SSHClient's AutoAddPolicy works. + """ + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = paramiko.RSAKey(data=str(host_key)) + + self.tc = paramiko.SSHClient() + self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.assertEquals(0, len(self.tc.get_host_keys())) + self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') + + self.event.wait(1.0) + self.assert_(self.event.isSet()) + self.assert_(self.ts.is_active()) + self.assertEquals('slowdive', self.ts.get_username()) + self.assertEquals(True, self.ts.is_authenticated()) + self.assertEquals(1, len(self.tc.get_host_keys())) + self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa']) + + def test_5_cleanup(self): + """ + verify that when an SSHClient is collected, its transport (and the + transport's packetizer) is closed. + """ + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = paramiko.RSAKey(data=str(host_key)) + + self.tc = paramiko.SSHClient() + self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.assertEquals(0, len(self.tc.get_host_keys())) + self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') + + self.event.wait(1.0) + self.assert_(self.event.isSet()) + self.assert_(self.ts.is_active()) + + p = weakref.ref(self.tc._transport.packetizer) + self.assert_(p() is not None) + del self.tc + # hrm, sometimes p isn't cleared right away. why is that? + st = time.time() + while (time.time() - st < 5.0) and (p() is not None): + time.sleep(0.1) + self.assert_(p() is None) + diff --git a/tests/test_file.py b/tests/test_file.py old mode 100644 new mode 100755 index 250821c..d66babf --- a/tests/test_file.py +++ b/tests/test_file.py @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py new file mode 100644 index 0000000..28521ba --- /dev/null +++ b/tests/test_hostkeys.py @@ -0,0 +1,117 @@ +# Copyright (C) 2006-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Some unit tests for HostKeys. +""" + +import base64 +from binascii import hexlify +import os +import unittest +import paramiko + + +test_hosts_file = """\ +secure.example.com ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA1PD6U2/TVxET6lkpKhOk5r\ +9q/kAYG6sP9f5zuUYP8i7FOFp/6ncCEbbtg/lB+A3iidyxoSWl+9jtoyyDOOVX4UIDV9G11Ml8om3\ +D+jrpI9cycZHqilK0HmxDeCuxbwyMuaCygU9gS2qoRvNLWZk70OpIKSSpBo0Wl3/XUmz9uhc= +happy.example.com ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31M\ +BGQ3GQ/Fc7SX6gkpXkwcZryoi4kNFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW\ +5ymME3bQ4J/k1IKxCtz/bAlAqFgKoc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M= +""" + +keyblob = """\ +AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31MBGQ3GQ/Fc7SX6gkpXkwcZryoi4k\ +NFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW5ymME3bQ4J/k1IKxCtz/bAlAqFgK\ +oc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M=""" + +keyblob_dss = """\ +AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/\ +h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF60\ +8EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIE\ +AkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+\ +Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\ +4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO\ +1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o\ +0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=""" + + +class HostKeysTest (unittest.TestCase): + + def setUp(self): + f = open('hostfile.temp', 'w') + f.write(test_hosts_file) + f.close() + + def tearDown(self): + os.unlink('hostfile.temp') + + def test_1_load(self): + hostdict = paramiko.HostKeys('hostfile.temp') + self.assertEquals(2, len(hostdict)) + self.assertEquals(1, len(hostdict.values()[0])) + self.assertEquals(1, len(hostdict.values()[1])) + fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() + self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) + + def test_2_add(self): + hostdict = paramiko.HostKeys('hostfile.temp') + hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c=' + key = paramiko.RSAKey(data=base64.decodestring(keyblob)) + hostdict.add(hh, 'ssh-rsa', key) + self.assertEquals(3, len(hostdict)) + x = hostdict['foo.example.com'] + fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() + self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) + self.assert_(hostdict.check('foo.example.com', key)) + + def test_3_dict(self): + hostdict = paramiko.HostKeys('hostfile.temp') + self.assert_('secure.example.com' in hostdict) + self.assert_('not.example.com' not in hostdict) + self.assert_(hostdict.has_key('secure.example.com')) + self.assert_(not hostdict.has_key('not.example.com')) + x = hostdict.get('secure.example.com', None) + self.assert_(x is not None) + fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() + self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) + i = 0 + for key in hostdict: + i += 1 + self.assertEquals(2, i) + + def test_4_dict_set(self): + hostdict = paramiko.HostKeys('hostfile.temp') + key = paramiko.RSAKey(data=base64.decodestring(keyblob)) + key_dss = paramiko.DSSKey(data=base64.decodestring(keyblob_dss)) + hostdict['secure.example.com'] = { + 'ssh-rsa': key, + 'ssh-dss': key_dss + } + hostdict['fake.example.com'] = {} + hostdict['fake.example.com']['ssh-rsa'] = key + + self.assertEquals(3, len(hostdict)) + self.assertEquals(2, len(hostdict.values()[0])) + self.assertEquals(1, len(hostdict.values()[1])) + self.assertEquals(1, len(hostdict.values()[2])) + fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() + self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) + fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper() + self.assertEquals('4478F0B9A23CC5182009FF755BC1D26C', fp) diff --git a/tests/test_kex.py b/tests/test_kex.py index 2680853..f304275 100644 --- a/tests/test_kex.py +++ b/tests/test_kex.py @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # @@ -22,6 +20,7 @@ Some unit tests for the key exchange protocols. """ +from binascii import hexlify import unittest import paramiko.util from paramiko.kex_group1 import KexGroup1 @@ -35,18 +34,21 @@ class FakeRandpool (object): def get_bytes(self, n): return chr(0xcc) * n + class FakeKey (object): def __str__(self): return 'fake-key' def sign_ssh_data(self, randpool, H): return 'fake-sig' + class FakeModulusPack (object): P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFFL G = 2 def get_modulus(self, min, ask, max): return self.G, self.P + class FakeTransport (object): randpool = FakeRandpool() local_version = 'SSH-2.0-paramiko_1.0' @@ -56,7 +58,7 @@ class FakeTransport (object): def _send_message(self, m): self._message = m - def _expect_packet(self, t): + def _expect_packet(self, *t): self._expect = t def _set_K_H(self, K, H): self._K = K @@ -89,8 +91,8 @@ class KexTest (unittest.TestCase): kex = KexGroup1(transport) kex.start_kex() x = '1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' - self.assertEquals(x, paramiko.util.hexify(str(transport._message))) - self.assertEquals(paramiko.kex_group1._MSG_KEXDH_REPLY, transport._expect) + self.assertEquals(x, hexlify(str(transport._message)).upper()) + self.assertEquals((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect) # fake "reply" msg = Message() @@ -101,7 +103,7 @@ class KexTest (unittest.TestCase): kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg) H = '03079780F3D3AD0B3C6DB30C8D21685F367A86D2' self.assertEquals(self.K, transport._K) - self.assertEquals(H, paramiko.util.hexify(transport._H)) + self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify) self.assert_(transport._activated) @@ -110,7 +112,7 @@ class KexTest (unittest.TestCase): transport.server_mode = True kex = KexGroup1(transport) kex.start_kex() - self.assertEquals(paramiko.kex_group1._MSG_KEXDH_INIT, transport._expect) + self.assertEquals((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect) msg = Message() msg.add_mpint(69) @@ -119,8 +121,8 @@ class KexTest (unittest.TestCase): H = 'B16BF34DD10945EDE84E9C1EF24A14BFDC843389' x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' self.assertEquals(self.K, transport._K) - self.assertEquals(H, paramiko.util.hexify(transport._H)) - self.assertEquals(x, paramiko.util.hexify(str(transport._message))) + self.assertEquals(H, hexlify(transport._H).upper()) + self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assert_(transport._activated) def test_3_gex_client(self): @@ -129,8 +131,8 @@ class KexTest (unittest.TestCase): kex = KexGex(transport) kex.start_kex() x = '22000004000000080000002000' - self.assertEquals(x, paramiko.util.hexify(str(transport._message))) - self.assertEquals(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, transport._expect) + self.assertEquals(x, hexlify(str(transport._message)).upper()) + self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) msg = Message() msg.add_mpint(FakeModulusPack.P) @@ -138,8 +140,8 @@ class KexTest (unittest.TestCase): msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' - self.assertEquals(x, paramiko.util.hexify(str(transport._message))) - self.assertEquals(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, transport._expect) + self.assertEquals(x, hexlify(str(transport._message)).upper()) + self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) msg = Message() msg.add_string('fake-host-key') @@ -149,16 +151,46 @@ class KexTest (unittest.TestCase): kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) H = 'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0' self.assertEquals(self.K, transport._K) - self.assertEquals(H, paramiko.util.hexify(transport._H)) + self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify) self.assert_(transport._activated) - def test_4_gex_server(self): + def test_4_gex_old_client(self): + transport = FakeTransport() + transport.server_mode = False + kex = KexGex(transport) + kex.start_kex(_test_old_style=True) + x = '1E00000800' + self.assertEquals(x, hexlify(str(transport._message)).upper()) + self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) + + msg = Message() + msg.add_mpint(FakeModulusPack.P) + msg.add_mpint(FakeModulusPack.G) + msg.rewind() + kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) + x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' + self.assertEquals(x, hexlify(str(transport._message)).upper()) + self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) + + msg = Message() + msg.add_string('fake-host-key') + msg.add_mpint(69) + msg.add_string('fake-sig') + msg.rewind() + kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) + H = '807F87B269EF7AC5EC7E75676808776A27D5864C' + self.assertEquals(self.K, transport._K) + self.assertEquals(H, hexlify(transport._H).upper()) + self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify) + self.assert_(transport._activated) + + def test_5_gex_server(self): transport = FakeTransport() transport.server_mode = True kex = KexGex(transport) kex.start_kex() - self.assertEquals(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, transport._expect) + self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) msg = Message() msg.add_int(1024) @@ -167,8 +199,8 @@ class KexTest (unittest.TestCase): msg.rewind() kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' - self.assertEquals(x, paramiko.util.hexify(str(transport._message))) - self.assertEquals(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, transport._expect) + self.assertEquals(x, hexlify(str(transport._message)).upper()) + self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) msg = Message() msg.add_mpint(12345) @@ -178,6 +210,33 @@ class KexTest (unittest.TestCase): H = 'CE754197C21BF3452863B4F44D0B3951F12516EF' x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' self.assertEquals(K, transport._K) - self.assertEquals(H, paramiko.util.hexify(transport._H)) - self.assertEquals(x, paramiko.util.hexify(str(transport._message))) + self.assertEquals(H, hexlify(transport._H).upper()) + self.assertEquals(x, hexlify(str(transport._message)).upper()) + self.assert_(transport._activated) + + def test_6_gex_server_with_old_client(self): + transport = FakeTransport() + transport.server_mode = True + kex = KexGex(transport) + kex.start_kex() + self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) + + msg = Message() + msg.add_int(2048) + msg.rewind() + kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) + x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' + self.assertEquals(x, hexlify(str(transport._message)).upper()) + self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) + + msg = Message() + msg.add_mpint(12345) + msg.rewind() + kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) + K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L + H = 'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B' + x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' + self.assertEquals(K, transport._K) + self.assertEquals(H, hexlify(transport._H).upper()) + self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assert_(transport._activated) diff --git a/tests/test_message.py b/tests/test_message.py index 441e3ce..e930f71 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # diff --git a/tests/test_packetizer.py b/tests/test_packetizer.py index 8c992bd..cb6248f 100644 --- a/tests/test_packetizer.py +++ b/tests/test_packetizer.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # diff --git a/tests/test_pkey.py b/tests/test_pkey.py index e56edb1..e591ab1 100644 --- a/tests/test_pkey.py +++ b/tests/test_pkey.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # @@ -20,6 +20,8 @@ Some unit tests for public/private key objects. """ +from binascii import hexlify, unhexlify +import StringIO import unittest from paramiko import RSAKey, DSSKey, Message, util, randpool @@ -30,6 +32,39 @@ FINGER_RSA = '1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5' FINGER_DSS = '1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c' SIGNED_RSA = '20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8' +RSA_PRIVATE_OUT = """\ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKCAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAM +s6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZ +v3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4cCASMC +ggCAEiI6plhqipt4P05L3PYr0pHZq2VPEbE4k9eI/gRKo/c1VJxY3DJnc1cenKsk +trQRtW3OxCEufqsX5PNec6VyKkW+Ox6beJjMKm4KF8ZDpKi9Nw6MdX3P6Gele9D9 ++ieyhVFljrnAqcXsgChTBOYlL2imqCs3qRGAJ3cMBIAx3VsCQQD3pIFVYW398kE0 +n0e1icEpkbDRV4c5iZVhu8xKy2yyfy6f6lClSb2+Ub9uns7F3+b5v0pYSHbE9+/r +OpRq83AfAkEA2rMZlr8SnMXgnyka2LuggA9QgMYy18hyao1dUxySubNDa9N+q2QR +mwDisTUgRFHKIlDHoQmzPbXAmYZX1YlDmQJBAPCRLS5epV0XOAc7pL762OaNhzHC +veAfQKgVhKBt105PqaKpGyQ5AXcNlWQlPeTK4GBTbMrKDPna6RBkyrEJvV8CQBK+ +5O+p+kfztCrmRCE0p1tvBuZ3Y3GU1ptrM+KNa6mEZN1bRV8l1Z+SXJLYqv6Kquz/ +nBUeFq2Em3rfoSDugiMCQDyG3cxD5dKX3IgkhLyBWls/FLDk4x/DQ+NUTu0F1Cu6 +JJye+5ARLkL0EweMXf0tmIYfWItDLsWB0fKg/56h0js= +-----END RSA PRIVATE KEY----- +""" + +DSS_PRIVATE_OUT = """\ +-----BEGIN DSA PRIVATE KEY----- +MIIBvgIBAAKCAIEA54GmA2d9HOv+3CYBBG7ZfBYCncIW2tWe6Dqzp+DCP+guNhtW +2MDLqmX+HQQoJbHat/Uh63I2xPFaueID0jod4OPrlfUXIOSDqDy28Kdo0Hxen9RS +G7Me4awwiKlHEHHD0sXrTwSplyPUTfK2S2hbkHk5yOuQSjPfEbsL6ukiNi8CFQDw +z4UnmsGiSNu5iqjn3uTzwUpshwKCAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25c +PzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq +1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+X +FDxlqZo8Y+wCggCARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lY +ukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+N +wacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgECFGI9 +QPSch9pT9XHqn+1rZ4bK+QGA +-----END DSA PRIVATE KEY----- +""" + class KeyTest (unittest.TestCase): @@ -42,23 +77,30 @@ class KeyTest (unittest.TestCase): def test_1_generate_key_bytes(self): from Crypto.Hash import MD5 key = util.generate_key_bytes(MD5, '\x01\x02\x03\x04', 'happy birthday', 30) - exp = util.unhexify('61E1F272F4C1C4561586BD322498C0E924672780F47BB37DDA7D54019E64') + exp = unhexlify('61E1F272F4C1C4561586BD322498C0E924672780F47BB37DDA7D54019E64') self.assertEquals(exp, key) def test_2_load_rsa(self): key = RSAKey.from_private_key_file('tests/test_rsa.key') self.assertEquals('ssh-rsa', key.get_name()) exp_rsa = FINGER_RSA.split()[1].replace(':', '') - my_rsa = util.hexify(key.get_fingerprint()).lower() + my_rsa = hexlify(key.get_fingerprint()) self.assertEquals(exp_rsa, my_rsa) self.assertEquals(PUB_RSA.split()[1], key.get_base64()) self.assertEquals(1024, key.get_bits()) + s = StringIO.StringIO() + key.write_private_key(s) + self.assertEquals(RSA_PRIVATE_OUT, s.getvalue()) + s.seek(0) + key2 = RSAKey.from_private_key(s) + self.assertEquals(key, key2) + def test_3_load_rsa_password(self): key = RSAKey.from_private_key_file('tests/test_rsa_password.key', 'television') self.assertEquals('ssh-rsa', key.get_name()) exp_rsa = FINGER_RSA.split()[1].replace(':', '') - my_rsa = util.hexify(key.get_fingerprint()).lower() + my_rsa = hexlify(key.get_fingerprint()) self.assertEquals(exp_rsa, my_rsa) self.assertEquals(PUB_RSA.split()[1], key.get_base64()) self.assertEquals(1024, key.get_bits()) @@ -67,16 +109,23 @@ class KeyTest (unittest.TestCase): key = DSSKey.from_private_key_file('tests/test_dss.key') self.assertEquals('ssh-dss', key.get_name()) exp_dss = FINGER_DSS.split()[1].replace(':', '') - my_dss = util.hexify(key.get_fingerprint()).lower() + my_dss = hexlify(key.get_fingerprint()) self.assertEquals(exp_dss, my_dss) self.assertEquals(PUB_DSS.split()[1], key.get_base64()) self.assertEquals(1024, key.get_bits()) + s = StringIO.StringIO() + key.write_private_key(s) + self.assertEquals(DSS_PRIVATE_OUT, s.getvalue()) + s.seek(0) + key2 = DSSKey.from_private_key(s) + self.assertEquals(key, key2) + def test_5_load_dss_password(self): key = DSSKey.from_private_key_file('tests/test_dss_password.key', 'television') self.assertEquals('ssh-dss', key.get_name()) exp_dss = FINGER_DSS.split()[1].replace(':', '') - my_dss = util.hexify(key.get_fingerprint()).lower() + my_dss = hexlify(key.get_fingerprint()) self.assertEquals(exp_dss, my_dss) self.assertEquals(PUB_DSS.split()[1], key.get_base64()) self.assertEquals(1024, key.get_bits()) diff --git a/tests/test_sftp.py b/tests/test_sftp.py old mode 100644 new mode 100755 index 993899a..edc0599 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # @@ -23,9 +23,11 @@ a real actual sftp server is contacted, and a new folder is created there to do test file operations in (so no existing files will be harmed). """ +from binascii import hexlify import logging import os import random +import struct import sys import threading import time @@ -69,6 +71,11 @@ tc = None g_big_file_test = True +def get_sftp(): + global sftp + return sftp + + class SFTPTest (unittest.TestCase): def init(hostname, username, keyfile, passwd): @@ -273,28 +280,85 @@ class SFTPTest (unittest.TestCase): def test_8_setstat(self): """ - verify that the setstat functions (chown, chmod, utime) work. + verify that the setstat functions (chown, chmod, utime, truncate) work. """ f = sftp.open(FOLDER + '/special', 'w') try: + f.write('x' * 1024) f.close() stat = sftp.stat(FOLDER + '/special') sftp.chmod(FOLDER + '/special', (stat.st_mode & ~0777) | 0600) - self.assertEqual(sftp.stat(FOLDER + '/special').st_mode & 0777, 0600) + stat = sftp.stat(FOLDER + '/special') + expected_mode = 0600 + if sys.platform == 'win32': + # chmod not really functional on windows + expected_mode = 0666 + if sys.platform == 'cygwin': + # even worse. + expected_mode = 0644 + self.assertEqual(stat.st_mode & 0777, expected_mode) + self.assertEqual(stat.st_size, 1024) mtime = stat.st_mtime - 3600 atime = stat.st_atime - 1800 sftp.utime(FOLDER + '/special', (atime, mtime)) - nstat = sftp.stat(FOLDER + '/special') - self.assertEqual(nstat.st_mtime, mtime) - self.assertEqual(nstat.st_atime, atime) + stat = sftp.stat(FOLDER + '/special') + self.assertEqual(stat.st_mtime, mtime) + if sys.platform not in ('win32', 'cygwin'): + self.assertEqual(stat.st_atime, atime) # can't really test chown, since we'd have to know a valid uid. + + sftp.truncate(FOLDER + '/special', 512) + stat = sftp.stat(FOLDER + '/special') + self.assertEqual(stat.st_size, 512) finally: sftp.remove(FOLDER + '/special') - def test_9_readline_seek(self): + def test_9_fsetstat(self): + """ + verify that the fsetstat functions (chown, chmod, utime, truncate) + work on open files. + """ + f = sftp.open(FOLDER + '/special', 'w') + try: + f.write('x' * 1024) + f.close() + + f = sftp.open(FOLDER + '/special', 'r+') + stat = f.stat() + f.chmod((stat.st_mode & ~0777) | 0600) + stat = f.stat() + + expected_mode = 0600 + if sys.platform == 'win32': + # chmod not really functional on windows + expected_mode = 0666 + if sys.platform == 'cygwin': + # even worse. + expected_mode = 0644 + self.assertEqual(stat.st_mode & 0777, expected_mode) + self.assertEqual(stat.st_size, 1024) + + mtime = stat.st_mtime - 3600 + atime = stat.st_atime - 1800 + f.utime((atime, mtime)) + stat = f.stat() + self.assertEqual(stat.st_mtime, mtime) + if sys.platform not in ('win32', 'cygwin'): + self.assertEqual(stat.st_atime, atime) + + # can't really test chown, since we'd have to know a valid uid. + + f.truncate(512) + stat = f.stat() + self.assertEqual(stat.st_size, 512) + f.close() + finally: + sftp.remove(FOLDER + '/special') + + def test_A_readline_seek(self): """ create a text file and write a bunch of text into it. then count the lines in the file, and seek around to retreive particular lines. this should @@ -324,7 +388,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove(FOLDER + '/duck.txt') - def test_A_write_seek(self): + def test_B_write_seek(self): """ create a text file, seek back and change part of it, and verify that the changes worked. @@ -344,10 +408,14 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove(FOLDER + '/testing.txt') - def test_B_symlink(self): + def test_C_symlink(self): """ create a symlink and then check that lstat doesn't follow it. """ + if not hasattr(os, "symlink"): + # skip symlink tests on windows + return + f = sftp.open(FOLDER + '/original.txt', 'w') try: f.write('original\n') @@ -387,7 +455,7 @@ class SFTPTest (unittest.TestCase): except: pass - def test_C_flush_seek(self): + def test_D_flush_seek(self): """ verify that buffered writes are automatically flushed on seek. """ @@ -409,183 +477,7 @@ class SFTPTest (unittest.TestCase): except: pass - def test_D_lots_of_files(self): - """ - create a bunch of files over the same session. - """ - global g_big_file_test - if not g_big_file_test: - return - numfiles = 100 - try: - for i in range(numfiles): - f = sftp.open('%s/file%d.txt' % (FOLDER, i), 'w', 1) - f.write('this is file #%d.\n' % i) - f.close() - sftp.chmod('%s/file%d.txt' % (FOLDER, i), 0660) - - # now make sure every file is there, by creating a list of filenmes - # and reading them in random order. - numlist = range(numfiles) - while len(numlist) > 0: - r = numlist[random.randint(0, len(numlist) - 1)] - f = sftp.open('%s/file%d.txt' % (FOLDER, r)) - self.assertEqual(f.readline(), 'this is file #%d.\n' % r) - f.close() - numlist.remove(r) - finally: - for i in range(numfiles): - try: - sftp.remove('%s/file%d.txt' % (FOLDER, i)) - except: - pass - - def test_E_big_file(self): - """ - write a 1MB file with no buffering. - """ - global g_big_file_test - if not g_big_file_test: - return - kblob = (1024 * 'x') - start = time.time() - try: - f = sftp.open('%s/hongry.txt' % FOLDER, 'w') - for n in range(1024): - f.write(kblob) - if n % 128 == 0: - sys.stderr.write('.') - f.close() - sys.stderr.write(' ') - - self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) - end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - - start = time.time() - f = sftp.open('%s/hongry.txt' % FOLDER, 'r') - for n in range(1024): - data = f.read(1024) - self.assertEqual(data, kblob) - f.close() - - end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - finally: - sftp.remove('%s/hongry.txt' % FOLDER) - - def test_F_big_file_pipelined(self): - """ - write a 1MB file, with no linefeeds, using pipelining. - """ - global g_big_file_test - if not g_big_file_test: - return - kblob = (1024 * 'x') - start = time.time() - try: - f = sftp.open('%s/hongry.txt' % FOLDER, 'w') - f.set_pipelined(True) - for n in range(1024): - f.write(kblob) - if n % 128 == 0: - sys.stderr.write('.') - f.close() - sys.stderr.write(' ') - - self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) - end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - - start = time.time() - f = sftp.open('%s/hongry.txt' % FOLDER, 'r') - f.prefetch() - for n in range(1024): - data = f.read(1024) - self.assertEqual(data, kblob) - f.close() - - end = time.time() - sys.stderr.write('%ds ' % round(end - start)) - finally: - sftp.remove('%s/hongry.txt' % FOLDER) - - def test_G_lots_of_prefetching(self): - """ - prefetch a 1MB file a bunch of times, discarding the file object - without using it, to verify that paramiko doesn't get confused. - """ - global g_big_file_test - if not g_big_file_test: - return - kblob = (1024 * 'x') - try: - f = sftp.open('%s/hongry.txt' % FOLDER, 'w') - f.set_pipelined(True) - for n in range(1024): - f.write(kblob) - if n % 128 == 0: - sys.stderr.write('.') - f.close() - sys.stderr.write(' ') - - self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) - - for i in range(10): - f = sftp.open('%s/hongry.txt' % FOLDER, 'r') - f.prefetch() - f = sftp.open('%s/hongry.txt' % FOLDER, 'r') - f.prefetch() - for n in range(1024): - data = f.read(1024) - self.assertEqual(data, kblob) - if n % 128 == 0: - sys.stderr.write('.') - f.close() - sys.stderr.write(' ') - finally: - sftp.remove('%s/hongry.txt' % FOLDER) - - def test_H_big_file_big_buffer(self): - """ - write a 1MB file, with no linefeeds, and a big buffer. - """ - global g_big_file_test - if not g_big_file_test: - return - mblob = (1024 * 1024 * 'x') - try: - f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) - f.write(mblob) - f.close() - - self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) - finally: - sftp.remove('%s/hongry.txt' % FOLDER) - - def test_I_big_file_renegotiate(self): - """ - write a 1MB file, forcing key renegotiation in the middle. - """ - global g_big_file_test - if not g_big_file_test: - return - t = sftp.sock.get_transport() - t.packetizer.REKEY_BYTES = 512 * 1024 - k32blob = (32 * 1024 * 'x') - try: - f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) - for i in xrange(32): - f.write(k32blob) - f.close() - - self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) - self.assertNotEquals(t.H, t.session_id) - finally: - sftp.remove('%s/hongry.txt' % FOLDER) - t.packetizer.REKEY_BYTES = pow(2, 30) - - def test_J_realpath(self): + def test_E_realpath(self): """ test that realpath is returning something non-empty and not an error. @@ -596,7 +488,7 @@ class SFTPTest (unittest.TestCase): self.assert_(len(f) > 0) self.assertEquals(os.path.join(pwd, FOLDER), f) - def test_K_mkdir(self): + def test_F_mkdir(self): """ verify that mkdir/rmdir work. """ @@ -619,7 +511,7 @@ class SFTPTest (unittest.TestCase): except IOError: pass - def test_L_chdir(self): + def test_G_chdir(self): """ verify that chdir/getcwd work. """ @@ -656,7 +548,7 @@ class SFTPTest (unittest.TestCase): except: pass - def test_M_get_put(self): + def test_H_get_put(self): """ verify that get/put work. """ @@ -665,27 +557,33 @@ class SFTPTest (unittest.TestCase): localname = os.tempnam() text = 'All I wanted was a plastic bunny rabbit.\n' - f = open(localname, 'w') + f = open(localname, 'wb') f.write(text) f.close() - sftp.put(localname, FOLDER + '/bunny.txt') + saved_progress = [] + def progress_callback(x, y): + saved_progress.append((x, y)) + sftp.put(localname, FOLDER + '/bunny.txt', progress_callback) f = sftp.open(FOLDER + '/bunny.txt', 'r') self.assertEquals(text, f.read(128)) f.close() + self.assertEquals((41, 41), saved_progress[-1]) os.unlink(localname) localname = os.tempnam() - sftp.get(FOLDER + '/bunny.txt', localname) + saved_progress = [] + sftp.get(FOLDER + '/bunny.txt', localname, progress_callback) - f = open(localname, 'r') + f = open(localname, 'rb') self.assertEquals(text, f.read(128)) f.close() + self.assertEquals((41, 41), saved_progress[-1]) os.unlink(localname) sftp.unlink(FOLDER + '/bunny.txt') - def test_N_check(self): + def test_I_check(self): """ verify that file.check() works against our own server. (it's an sftp extension that we support, and may be the only ones who @@ -698,16 +596,17 @@ class SFTPTest (unittest.TestCase): try: f = sftp.open(FOLDER + '/kitty.txt', 'r') sum = f.check('sha1') - self.assertEquals('91059CFC6615941378D413CB5ADAF4C5EB293402', paramiko.util.hexify(sum)) + self.assertEquals('91059CFC6615941378D413CB5ADAF4C5EB293402', hexlify(sum).upper()) sum = f.check('md5', 0, 512) - self.assertEquals('93DE4788FCA28D471516963A1FE3856A', paramiko.util.hexify(sum)) + self.assertEquals('93DE4788FCA28D471516963A1FE3856A', hexlify(sum).upper()) sum = f.check('md5', 0, 0, 510) self.assertEquals('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6', - paramiko.util.hexify(sum)) + hexlify(sum).upper()) + f.close() finally: sftp.unlink(FOLDER + '/kitty.txt') - def test_O_x_flag(self): + def test_J_x_flag(self): """ verify that the 'x' flag works when opening a file. """ @@ -723,7 +622,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.unlink(FOLDER + '/unusual.txt') - def test_P_utf8(self): + def test_K_utf8(self): """ verify that unicode strings are encoded into utf8 correctly. """ @@ -738,3 +637,43 @@ class SFTPTest (unittest.TestCase): self.fail('exception ' + e) sftp.unlink(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65') + def test_L_bad_readv(self): + """ + verify that readv at the end of the file doesn't essplode. + """ + f = sftp.open(FOLDER + '/zero', 'w') + f.close() + try: + f = sftp.open(FOLDER + '/zero', 'r') + data = f.readv([(0, 12)]) + f.close() + + f = sftp.open(FOLDER + '/zero', 'r') + f.prefetch() + data = f.read(100) + f.close() + finally: + sftp.unlink(FOLDER + '/zero') + + def XXX_test_M_seek_append(self): + """ + verify that seek does't affect writes during append. + + does not work except through paramiko. :( openssh fails. + """ + f = sftp.open(FOLDER + '/append.txt', 'a') + try: + f.write('first line\nsecond line\n') + f.seek(11, f.SEEK_SET) + f.write('third line\n') + f.close() + + f = sftp.open(FOLDER + '/append.txt', 'r') + self.assertEqual(f.stat().st_size, 34) + self.assertEqual(f.readline(), 'first line\n') + self.assertEqual(f.readline(), 'second line\n') + self.assertEqual(f.readline(), 'third line\n') + f.close() + finally: + sftp.remove(FOLDER + '/append.txt') + diff --git a/tests/test_sftp_big.py b/tests/test_sftp_big.py new file mode 100644 index 0000000..c182762 --- /dev/null +++ b/tests/test_sftp_big.py @@ -0,0 +1,385 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +some unit tests to make sure sftp works well with large files. + +a real actual sftp server is contacted, and a new folder is created there to +do test file operations in (so no existing files will be harmed). +""" + +import logging +import os +import random +import struct +import sys +import threading +import time +import unittest + +import paramiko +from stub_sftp import StubServer, StubSFTPServer +from loop import LoopSocket +from test_sftp import get_sftp + +FOLDER = os.environ.get('TEST_FOLDER', 'temp-testing000') + + +class BigSFTPTest (unittest.TestCase): + + def setUp(self): + global FOLDER + sftp = get_sftp() + for i in xrange(1000): + FOLDER = FOLDER[:-3] + '%03d' % i + try: + sftp.mkdir(FOLDER) + break + except (IOError, OSError): + pass + + def tearDown(self): + sftp = get_sftp() + sftp.rmdir(FOLDER) + + def test_1_lots_of_files(self): + """ + create a bunch of files over the same session. + """ + sftp = get_sftp() + numfiles = 100 + try: + for i in range(numfiles): + f = sftp.open('%s/file%d.txt' % (FOLDER, i), 'w', 1) + f.write('this is file #%d.\n' % i) + f.close() + sftp.chmod('%s/file%d.txt' % (FOLDER, i), 0660) + + # now make sure every file is there, by creating a list of filenmes + # and reading them in random order. + numlist = range(numfiles) + while len(numlist) > 0: + r = numlist[random.randint(0, len(numlist) - 1)] + f = sftp.open('%s/file%d.txt' % (FOLDER, r)) + self.assertEqual(f.readline(), 'this is file #%d.\n' % r) + f.close() + numlist.remove(r) + finally: + for i in range(numfiles): + try: + sftp.remove('%s/file%d.txt' % (FOLDER, i)) + except: + pass + + def test_2_big_file(self): + """ + write a 1MB file with no buffering. + """ + sftp = get_sftp() + kblob = (1024 * 'x') + start = time.time() + try: + f = sftp.open('%s/hongry.txt' % FOLDER, 'w') + for n in range(1024): + f.write(kblob) + if n % 128 == 0: + sys.stderr.write('.') + f.close() + sys.stderr.write(' ') + + self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) + + start = time.time() + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + for n in range(1024): + data = f.read(1024) + self.assertEqual(data, kblob) + f.close() + + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) + finally: + sftp.remove('%s/hongry.txt' % FOLDER) + + def test_3_big_file_pipelined(self): + """ + write a 1MB file, with no linefeeds, using pipelining. + """ + sftp = get_sftp() + kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) + start = time.time() + try: + f = sftp.open('%s/hongry.txt' % FOLDER, 'w') + f.set_pipelined(True) + for n in range(1024): + f.write(kblob) + if n % 128 == 0: + sys.stderr.write('.') + f.close() + sys.stderr.write(' ') + + self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) + + start = time.time() + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + f.prefetch() + + # read on odd boundaries to make sure the bytes aren't getting scrambled + n = 0 + k2blob = kblob + kblob + chunk = 629 + size = 1024 * 1024 + while n < size: + if n + chunk > size: + chunk = size - n + data = f.read(chunk) + offset = n % 1024 + self.assertEqual(data, k2blob[offset:offset + chunk]) + n += chunk + f.close() + + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) + finally: + sftp.remove('%s/hongry.txt' % FOLDER) + + def test_4_prefetch_seek(self): + sftp = get_sftp() + kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) + try: + f = sftp.open('%s/hongry.txt' % FOLDER, 'w') + f.set_pipelined(True) + for n in range(1024): + f.write(kblob) + if n % 128 == 0: + sys.stderr.write('.') + f.close() + sys.stderr.write(' ') + + self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + + start = time.time() + k2blob = kblob + kblob + chunk = 793 + for i in xrange(10): + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + f.prefetch() + base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) + offsets = [base_offset + j * chunk for j in xrange(100)] + # randomly seek around and read them out + for j in xrange(100): + offset = offsets[random.randint(0, len(offsets) - 1)] + offsets.remove(offset) + f.seek(offset) + data = f.read(chunk) + n_offset = offset % 1024 + self.assertEqual(data, k2blob[n_offset:n_offset + chunk]) + offset += chunk + f.close() + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) + finally: + sftp.remove('%s/hongry.txt' % FOLDER) + + def test_5_readv_seek(self): + sftp = get_sftp() + kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) + try: + f = sftp.open('%s/hongry.txt' % FOLDER, 'w') + f.set_pipelined(True) + for n in range(1024): + f.write(kblob) + if n % 128 == 0: + sys.stderr.write('.') + f.close() + sys.stderr.write(' ') + + self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + + start = time.time() + k2blob = kblob + kblob + chunk = 793 + for i in xrange(10): + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) + # make a bunch of offsets and put them in random order + offsets = [base_offset + j * chunk for j in xrange(100)] + readv_list = [] + for j in xrange(100): + o = offsets[random.randint(0, len(offsets) - 1)] + offsets.remove(o) + readv_list.append((o, chunk)) + ret = f.readv(readv_list) + for i in xrange(len(readv_list)): + offset = readv_list[i][0] + n_offset = offset % 1024 + self.assertEqual(ret.next(), k2blob[n_offset:n_offset + chunk]) + f.close() + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) + finally: + sftp.remove('%s/hongry.txt' % FOLDER) + + def test_6_lots_of_prefetching(self): + """ + prefetch a 1MB file a bunch of times, discarding the file object + without using it, to verify that paramiko doesn't get confused. + """ + sftp = get_sftp() + kblob = (1024 * 'x') + try: + f = sftp.open('%s/hongry.txt' % FOLDER, 'w') + f.set_pipelined(True) + for n in range(1024): + f.write(kblob) + if n % 128 == 0: + sys.stderr.write('.') + f.close() + sys.stderr.write(' ') + + self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + + for i in range(10): + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + f.prefetch() + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + f.prefetch() + for n in range(1024): + data = f.read(1024) + self.assertEqual(data, kblob) + if n % 128 == 0: + sys.stderr.write('.') + f.close() + sys.stderr.write(' ') + finally: + sftp.remove('%s/hongry.txt' % FOLDER) + + def test_7_prefetch_readv(self): + """ + verify that prefetch and readv don't conflict with each other. + """ + sftp = get_sftp() + kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) + try: + f = sftp.open('%s/hongry.txt' % FOLDER, 'w') + f.set_pipelined(True) + for n in range(1024): + f.write(kblob) + if n % 128 == 0: + sys.stderr.write('.') + f.close() + sys.stderr.write(' ') + + self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + f.prefetch() + data = f.read(1024) + self.assertEqual(data, kblob) + + chunk_size = 793 + base_offset = 512 * 1024 + k2blob = kblob + kblob + chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)] + for data in f.readv(chunks): + offset = base_offset % 1024 + self.assertEqual(chunk_size, len(data)) + self.assertEqual(k2blob[offset:offset + chunk_size], data) + base_offset += chunk_size + + f.close() + sys.stderr.write(' ') + finally: + sftp.remove('%s/hongry.txt' % FOLDER) + + def test_8_large_readv(self): + """ + verify that a very large readv is broken up correctly and still + returned as a single blob. + """ + sftp = get_sftp() + kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) + try: + f = sftp.open('%s/hongry.txt' % FOLDER, 'w') + f.set_pipelined(True) + for n in range(1024): + f.write(kblob) + if n % 128 == 0: + sys.stderr.write('.') + f.close() + sys.stderr.write(' ') + + self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + data = list(f.readv([(23 * 1024, 128 * 1024)])) + self.assertEqual(1, len(data)) + data = data[0] + self.assertEqual(128 * 1024, len(data)) + + f.close() + sys.stderr.write(' ') + finally: + sftp.remove('%s/hongry.txt' % FOLDER) + + def test_9_big_file_big_buffer(self): + """ + write a 1MB file, with no linefeeds, and a big buffer. + """ + sftp = get_sftp() + mblob = (1024 * 1024 * 'x') + try: + f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) + f.write(mblob) + f.close() + + self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + finally: + sftp.remove('%s/hongry.txt' % FOLDER) + + def test_A_big_file_renegotiate(self): + """ + write a 1MB file, forcing key renegotiation in the middle. + """ + sftp = get_sftp() + t = sftp.sock.get_transport() + t.packetizer.REKEY_BYTES = 512 * 1024 + k32blob = (32 * 1024 * 'x') + try: + f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) + for i in xrange(32): + f.write(k32blob) + f.close() + + self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + self.assertNotEquals(t.H, t.session_id) + + # try to read it too. + f = sftp.open('%s/hongry.txt' % FOLDER, 'r', 128 * 1024) + f.prefetch() + total = 0 + while total < 1024 * 1024: + total += len(f.read(32 * 1024)) + f.close() + finally: + sftp.remove('%s/hongry.txt' % FOLDER) + t.packetizer.REKEY_BYTES = pow(2, 30) diff --git a/tests/test_transport.py b/tests/test_transport.py index 5fcc786..4b52c4f 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # @@ -20,12 +20,21 @@ Some unit tests for the ssh2 protocol in Transport. """ -import sys, time, threading, unittest +from binascii import hexlify, unhexlify import select +import socket +import sys +import time +import threading +import unittest +import random + from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ - SSHException, BadAuthenticationType, InteractiveQuery, util + SSHException, BadAuthenticationType, InteractiveQuery, ChannelException from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL -from paramiko import OPEN_SUCCEEDED +from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED +from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST +from paramiko.message import Message from loop import LoopSocket @@ -37,50 +46,16 @@ class NullServer (ServerInterface): def get_allowed_auths(self, username): if username == 'slowdive': return 'publickey,password' - if username == 'paranoid': - if not self.paranoid_did_password and not self.paranoid_did_public_key: - return 'publickey,password' - elif self.paranoid_did_password: - return 'publickey' - else: - return 'password' - if username == 'commie': - return 'keyboard-interactive' return 'publickey' def check_auth_password(self, username, password): if (username == 'slowdive') and (password == 'pygmalion'): return AUTH_SUCCESSFUL - if (username == 'paranoid') and (password == 'paranoid'): - # 2-part auth (even openssh doesn't support this) - self.paranoid_did_password = True - if self.paranoid_did_public_key: - return AUTH_SUCCESSFUL - return AUTH_PARTIALLY_SUCCESSFUL - return AUTH_FAILED - - def check_auth_publickey(self, username, key): - if (username == 'paranoid') and (key == self.paranoid_key): - # 2-part auth - self.paranoid_did_public_key = True - if self.paranoid_did_password: - return AUTH_SUCCESSFUL - return AUTH_PARTIALLY_SUCCESSFUL - return AUTH_FAILED - - def check_auth_interactive(self, username, submethods): - if username == 'commie': - self.username = username - return InteractiveQuery('password', 'Please enter a password.', ('Password', False)) - return AUTH_FAILED - - def check_auth_interactive_response(self, responses): - if self.username == 'commie': - if (len(responses) == 1) and (responses[0] == 'cat'): - return AUTH_SUCCESSFUL return AUTH_FAILED def check_channel_request(self, kind, chanid): + if kind == 'bogus': + return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED return OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): @@ -94,10 +69,34 @@ class NullServer (ServerInterface): def check_global_request(self, kind, msg): self._global_request = kind return False + + def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): + self._x11_single_connection = single_connection + self._x11_auth_protocol = auth_protocol + self._x11_auth_cookie = auth_cookie + self._x11_screen_number = screen_number + return True + + def check_port_forward_request(self, addr, port): + self._listen = socket.socket() + self._listen.bind(('127.0.0.1', 0)) + self._listen.listen(1) + return self._listen.getsockname()[1] + + def cancel_port_forward_request(self, addr, port): + self._listen.close() + self._listen = None + + def check_channel_direct_tcpip_request(self, chanid, origin, destination): + self._tcpip_dest = destination + return OPEN_SUCCEEDED class TransportTest (unittest.TestCase): + assertTrue = unittest.TestCase.failUnless # for Python 2.3 and below + assertFalse = unittest.TestCase.failIf # for Python 2.3 and below + def setUp(self): self.socks = LoopSocket() self.sockc = LoopSocket() @@ -111,6 +110,26 @@ class TransportTest (unittest.TestCase): self.socks.close() self.sockc.close() + def setup_test_server(self, client_options=None, server_options=None): + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + + if client_options is not None: + client_options(self.tc.get_security_options()) + if server_options is not None: + server_options(self.ts.get_security_options()) + + event = threading.Event() + self.server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, self.server) + self.tc.connect(hostkey=public_host_key, + username='slowdive', password='pygmalion') + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + def test_1_security_options(self): o = self.tc.get_security_options() self.assertEquals(type(o), SecurityOptions) @@ -130,11 +149,11 @@ class TransportTest (unittest.TestCase): def test_2_compute_key(self): self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929L - self.tc.H = util.unhexify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3') + self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3') self.tc.session_id = self.tc.H key = self.tc._compute_key('C', 32) self.assertEquals('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995', - util.hexify(key)) + hexlify(key).upper()) def test_3_simple(self): """ @@ -168,193 +187,45 @@ class TransportTest (unittest.TestCase): verify that the client can demand odd handshake settings, and can renegotiate keys in mid-stream. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - options = self.tc.get_security_options() - options.ciphers = ('aes256-cbc',) - options.digests = ('hmac-md5-96',) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + def force_algorithms(options): + options.ciphers = ('aes256-cbc',) + options.digests = ('hmac-md5-96',) + self.setup_test_server(client_options=force_algorithms) self.assertEquals('aes256-cbc', self.tc.local_cipher) self.assertEquals('aes256-cbc', self.tc.remote_cipher) self.assertEquals(12, self.tc.packetizer.get_mac_size_out()) self.assertEquals(12, self.tc.packetizer.get_mac_size_in()) self.tc.send_ignore(1024) - self.assert_(self.tc.renegotiate_keys()) + self.tc.renegotiate_keys() self.ts.send_ignore(1024) def test_5_keepalive(self): """ verify that the keepalive will be sent. """ - self.tc.set_hexdump(True) - - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - self.assertEquals(None, getattr(server, '_global_request', None)) + self.setup_test_server() + self.assertEquals(None, getattr(self.server, '_global_request', None)) self.tc.set_keepalive(1) time.sleep(2) - self.assertEquals('keepalive@lag.net', server._global_request) - - def test_6_bad_auth_type(self): - """ - verify that we get the right exception when an unsupported auth - type is requested. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - try: - self.tc.connect(hostkey=public_host_key, - username='unknown', password='error') - self.assert_(False) - except: - etype, evalue, etb = sys.exc_info() - self.assertEquals(BadAuthenticationType, etype) - self.assertEquals(['publickey'], evalue.allowed_types) - - def test_7_bad_password(self): - """ - verify that a bad password gets the right exception, and that a retry - with the right password works. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - try: - self.tc.auth_password(username='slowdive', password='error') - self.assert_(False) - except: - etype, evalue, etb = sys.exc_info() - self.assertEquals(SSHException, etype) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_8_multipart_auth(self): - """ - verify that multipart auth works. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - remain = self.tc.auth_password(username='paranoid', password='paranoid') - self.assertEquals(['publickey'], remain) - key = DSSKey.from_private_key_file('tests/test_dss.key') - remain = self.tc.auth_publickey(username='paranoid', key=key) - self.assertEquals([], remain) - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_9_interactive_auth(self): - """ - verify keyboard-interactive auth works. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - - def handler(title, instructions, prompts): - self.got_title = title - self.got_instructions = instructions - self.got_prompts = prompts - return ['cat'] - remain = self.tc.auth_interactive('commie', handler) - self.assertEquals(self.got_title, 'password') - self.assertEquals(self.got_prompts, [('Password', False)]) - self.assertEquals([], remain) - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + self.assertEquals('keepalive@lag.net', self.server._global_request) - def test_A_interactive_auth_fallback(self): - """ - verify that a password auth attempt will fallback to "interactive" - if password auth isn't supported but interactive is. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - remain = self.tc.auth_password('commie', 'cat') - self.assertEquals([], remain) - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_B_exec_command(self): + def test_6_exec_command(self): """ verify that exec_command() does something reasonable. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) - self.assert_(not chan.exec_command('no')) + try: + chan.exec_command('no') + self.assert_(False) + except SSHException, x: + pass chan = self.tc.open_session() - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan = self.ts.accept(1.0) schan.send('Hello there.\n') schan.send_stderr('This is on stderr.\n') @@ -369,7 +240,7 @@ class TransportTest (unittest.TestCase): # now try it with combined stdout/stderr chan = self.tc.open_session() - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan = self.ts.accept(1.0) schan.send('Hello there.\n') schan.send_stderr('This is on stderr.\n') @@ -381,26 +252,13 @@ class TransportTest (unittest.TestCase): self.assertEquals('This is on stderr.\n', f.readline()) self.assertEquals('', f.readline()) - def test_C_invoke_shell(self): + def test_7_invoke_shell(self): """ verify that invoke_shell() does something reasonable. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() chan = self.tc.open_session() - self.assert_(chan.invoke_shell()) + chan.invoke_shell() schan = self.ts.accept(1.0) chan.send('communist j. cat\n') f = schan.makefile() @@ -408,28 +266,28 @@ class TransportTest (unittest.TestCase): chan.close() self.assertEquals('', f.readline()) - def test_D_exit_status(self): + def test_8_channel_exception(self): + """ + verify that ChannelException is thrown for a bad open-channel request. + """ + self.setup_test_server() + try: + chan = self.tc.open_channel('bogus') + self.fail('expected exception') + except ChannelException, x: + self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) + + def test_9_exit_status(self): """ verify that get_exit_status() works. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan.send('Hello there.\n') + self.assert_(not chan.exit_status_ready()) # trigger an EOF schan.shutdown_read() schan.shutdown_write() @@ -439,29 +297,22 @@ class TransportTest (unittest.TestCase): f = chan.makefile() self.assertEquals('Hello there.\n', f.readline()) self.assertEquals('', f.readline()) + count = 0 + while not chan.exit_status_ready(): + time.sleep(0.1) + count += 1 + if count > 50: + raise Exception("timeout") self.assertEquals(23, chan.recv_exit_status()) chan.close() - def test_E_select(self): + def test_A_select(self): """ verify that select() on a channel works. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() chan = self.tc.open_session() - self.assert_(chan.invoke_shell()) + chan.invoke_shell() schan = self.ts.accept(1.0) # nothing should be ready @@ -503,28 +354,21 @@ class TransportTest (unittest.TestCase): self.assertEquals([], e) self.assertEquals('', chan.recv(16)) + # make sure the pipe is still open for now... + p = chan._pipe + self.assertEquals(False, p._closed) chan.close() + # ...and now is closed. + self.assertEquals(True, p._closed) - def test_F_renegotiate(self): + def test_B_renegotiate(self): """ verify that a transport can correctly renegotiate mid-stream. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() self.tc.packetizer.REKEY_BYTES = 16384 - chan = self.tc.open_session() - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan = self.ts.accept(1.0) self.assertEquals(self.tc.H, self.tc.session_id) @@ -541,26 +385,15 @@ class TransportTest (unittest.TestCase): schan.close() - def test_G_compression(self): + def test_C_compression(self): """ verify that zlib compression is basically working. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - self.ts.get_security_options().compression = ('zlib',) - self.tc.get_security_options().compression = ('zlib',) - event = threading.Event() - server = NullServer() - self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + def force_compression(o): + o.compression = ('zlib',) + self.setup_test_server(force_compression, force_compression) chan = self.tc.open_session() - self.assert_(chan.exec_command('yes')) + chan.exec_command('yes') schan = self.ts.accept(1.0) bytes = self.tc.packetizer._Packetizer__sent_bytes @@ -568,6 +401,326 @@ class TransportTest (unittest.TestCase): bytes2 = self.tc.packetizer._Packetizer__sent_bytes # tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :) self.assert_(bytes2 - bytes < 1024) + self.assertEquals(52, bytes2 - bytes) chan.close() schan.close() + + def test_D_x11(self): + """ + verify that an x11 port can be requested and opened. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + requested = [] + def handler(c, (addr, port)): + requested.append((addr, port)) + self.tc._queue_incoming_channel(c) + + self.assertEquals(None, getattr(self.server, '_x11_screen_number', None)) + cookie = chan.request_x11(0, single_connection=True, handler=handler) + self.assertEquals(0, self.server._x11_screen_number) + self.assertEquals('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol) + self.assertEquals(cookie, self.server._x11_auth_cookie) + self.assertEquals(True, self.server._x11_single_connection) + + x11_server = self.ts.open_x11_channel(('localhost', 6093)) + x11_client = self.tc.accept() + self.assertEquals('localhost', requested[0][0]) + self.assertEquals(6093, requested[0][1]) + + x11_server.send('hello') + self.assertEquals('hello', x11_client.recv(5)) + + x11_server.close() + x11_client.close() + chan.close() + schan.close() + + def test_E_reverse_port_forwarding(self): + """ + verify that a client can ask the server to open a reverse port for + forwarding. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + requested = [] + def handler(c, (origin_addr, origin_port), (server_addr, server_port)): + requested.append((origin_addr, origin_port)) + requested.append((server_addr, server_port)) + self.tc._queue_incoming_channel(c) + + port = self.tc.request_port_forward('127.0.0.1', 0, handler) + self.assertEquals(port, self.server._listen.getsockname()[1]) + + cs = socket.socket() + cs.connect(('127.0.0.1', port)) + ss, _ = self.server._listen.accept() + sch = self.ts.open_forwarded_tcpip_channel(ss.getsockname(), ss.getpeername()) + cch = self.tc.accept() + + sch.send('hello') + self.assertEquals('hello', cch.recv(5)) + sch.close() + cch.close() + ss.close() + cs.close() + + # now cancel it. + self.tc.cancel_port_forward('127.0.0.1', port) + self.assertTrue(self.server._listen is None) + + def test_F_port_forwarding(self): + """ + verify that a client can forward new connections from a locally- + forwarded port. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + # open a port on the "server" that the client will ask to forward to. + greeting_server = socket.socket() + greeting_server.bind(('127.0.0.1', 0)) + greeting_server.listen(1) + greeting_port = greeting_server.getsockname()[1] + + cs = self.tc.open_channel('direct-tcpip', ('127.0.0.1', greeting_port), ('', 9000)) + sch = self.ts.accept(1.0) + cch = socket.socket() + cch.connect(self.server._tcpip_dest) + + ss, _ = greeting_server.accept() + ss.send('Hello!\n') + ss.close() + sch.send(cch.recv(8192)) + sch.close() + + self.assertEquals('Hello!\n', cs.recv(7)) + cs.close() + + def test_G_stderr_select(self): + """ + verify that select() on a channel works even if only stderr is + receiving data. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.invoke_shell() + schan = self.ts.accept(1.0) + + # nothing should be ready + r, w, e = select.select([chan], [], [], 0.1) + self.assertEquals([], r) + self.assertEquals([], w) + self.assertEquals([], e) + + schan.send_stderr('hello\n') + + # something should be ready now (give it 1 second to appear) + for i in range(10): + r, w, e = select.select([chan], [], [], 0.1) + if chan in r: + break + time.sleep(0.1) + self.assertEquals([chan], r) + self.assertEquals([], w) + self.assertEquals([], e) + + self.assertEquals('hello\n', chan.recv_stderr(6)) + + # and, should be dead again now + r, w, e = select.select([chan], [], [], 0.1) + self.assertEquals([], r) + self.assertEquals([], w) + self.assertEquals([], e) + + schan.close() + chan.close() + + def test_H_send_ready(self): + """ + verify that send_ready() indicates when a send would not block. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.invoke_shell() + schan = self.ts.accept(1.0) + + self.assertEquals(chan.send_ready(), True) + total = 0 + K = '*' * 1024 + while total < 1024 * 1024: + chan.send(K) + total += len(K) + if not chan.send_ready(): + break + self.assert_(total < 1024 * 1024) + + schan.close() + chan.close() + self.assertEquals(chan.send_ready(), True) + + def test_I_rekey_deadlock(self): + """ + Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent + + Note: When this test fails, it may leak threads. + """ + + # Test for an obscure deadlocking bug that can occur if we receive + # certain messages while initiating a key exchange. + # + # The deadlock occurs as follows: + # + # In the main thread: + # 1. The user's program calls Channel.send(), which sends + # MSG_CHANNEL_DATA to the remote host. + # 2. Packetizer discovers that REKEY_BYTES has been exceeded, and + # sets the __need_rekey flag. + # + # In the Transport thread: + # 3. Packetizer notices that the __need_rekey flag is set, and raises + # NeedRekeyException. + # 4. In response to NeedRekeyException, the transport thread sends + # MSG_KEXINIT to the remote host. + # + # On the remote host (using any SSH implementation): + # 5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent. + # 6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent. + # + # In the main thread: + # 7. The user's program calls Channel.send(). + # 8. Channel.send acquires Channel.lock, then calls Transport._send_user_message(). + # 9. Transport._send_user_message waits for Transport.clear_to_send + # to be set (i.e., it waits for re-keying to complete). + # Channel.lock is still held. + # + # In the Transport thread: + # 10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust + # is called to handle it. + # 11. Channel._window_adjust tries to acquire Channel.lock, but it + # blocks because the lock is already held by the main thread. + # + # The result is that the Transport thread never processes the remote + # host's MSG_KEXINIT packet, because it becomes deadlocked while + # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message. + + # We set up two separate threads for sending and receiving packets, + # while the main thread acts as a watchdog timer. If the timer + # expires, a deadlock is assumed. + + class SendThread(threading.Thread): + def __init__(self, chan, iterations, done_event): + threading.Thread.__init__(self, None, None, self.__class__.__name__) + self.setDaemon(True) + self.chan = chan + self.iterations = iterations + self.done_event = done_event + self.watchdog_event = threading.Event() + self.last = None + + def run(self): + try: + for i in xrange(1, 1+self.iterations): + if self.done_event.isSet(): + break + self.watchdog_event.set() + #print i, "SEND" + self.chan.send("x" * 2048) + finally: + self.done_event.set() + self.watchdog_event.set() + + class ReceiveThread(threading.Thread): + def __init__(self, chan, done_event): + threading.Thread.__init__(self, None, None, self.__class__.__name__) + self.setDaemon(True) + self.chan = chan + self.done_event = done_event + self.watchdog_event = threading.Event() + + def run(self): + try: + while not self.done_event.isSet(): + if self.chan.recv_ready(): + chan.recv(65536) + self.watchdog_event.set() + else: + if random.randint(0, 1): + time.sleep(random.randint(0, 500) / 1000.0) + finally: + self.done_event.set() + self.watchdog_event.set() + + self.setup_test_server() + self.ts.packetizer.REKEY_BYTES = 2048 + + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + # Monkey patch the client's Transport._handler_table so that the client + # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial + # MSG_KEXINIT. This is used to simulate the effect of network latency + # on a real MSG_CHANNEL_WINDOW_ADJUST message. + self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary + _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] + def _negotiate_keys_wrapper(self, m): + if self.local_kex_init is None: # Remote side sent KEXINIT + # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it + # before responding to the incoming MSG_KEXINIT. + m2 = Message() + m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) + m2.add_int(chan.remote_chanid) + m2.add_int(1) # bytes to add + self._send_message(m2) + return _negotiate_keys(self, m) + self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper + + # Parameters for the test + iterations = 500 # The deadlock does not happen every time, but it + # should after many iterations. + timeout = 5 + + # This event is set when the test is completed + done_event = threading.Event() + + # Start the sending thread + st = SendThread(schan, iterations, done_event) + st.start() + + # Start the receiving thread + rt = ReceiveThread(chan, done_event) + rt.start() + + # Act as a watchdog timer, checking + deadlocked = False + while not deadlocked and not done_event.isSet(): + for event in (st.watchdog_event, rt.watchdog_event): + event.wait(timeout) + if done_event.isSet(): + break + if not event.isSet(): + deadlocked = True + break + event.clear() + + # Tell the threads to stop (if they haven't already stopped). Note + # that if one or more threads are deadlocked, they might hang around + # forever (until the process exits). + done_event.set() + + # Assertion: We must not have detected a timeout. + self.assertFalse(deadlocked) + + # Close the channels + schan.close() + chan.close() diff --git a/tests/test_util.py b/tests/test_util.py index fa8c029..d385bab 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,6 +1,4 @@ -#!/usr/bin/python - -# Copyright (C) 2003-2005 Robey Pointer +# Copyright (C) 2003-2007 Robey Pointer # # This file is part of paramiko. # @@ -22,7 +20,9 @@ Some unit tests for utility functions. """ +from binascii import hexlify import cStringIO +import os import unittest from Crypto.Hash import SHA import paramiko.util @@ -43,27 +43,80 @@ Host spoo.example.com Crazy something else """ +test_hosts_file = """\ +secure.example.com ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA1PD6U2/TVxET6lkpKhOk5r\ +9q/kAYG6sP9f5zuUYP8i7FOFp/6ncCEbbtg/lB+A3iidyxoSWl+9jtoyyDOOVX4UIDV9G11Ml8om3\ +D+jrpI9cycZHqilK0HmxDeCuxbwyMuaCygU9gS2qoRvNLWZk70OpIKSSpBo0Wl3/XUmz9uhc= +happy.example.com ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31M\ +BGQ3GQ/Fc7SX6gkpXkwcZryoi4kNFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW\ +5ymME3bQ4J/k1IKxCtz/bAlAqFgKoc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M= +""" + + +# for test 1: +from paramiko import * + class UtilTest (unittest.TestCase): - K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504L + assertTrue = unittest.TestCase.failUnless # for Python 2.3 and below + assertFalse = unittest.TestCase.failIf # for Python 2.3 and below def setUp(self): pass def tearDown(self): pass + + def test_1_import(self): + """ + verify that all the classes can be imported from paramiko. + """ + symbols = globals().keys() + self.assertTrue('Transport' in symbols) + self.assertTrue('SSHClient' in symbols) + self.assertTrue('MissingHostKeyPolicy' in symbols) + self.assertTrue('AutoAddPolicy' in symbols) + self.assertTrue('RejectPolicy' in symbols) + self.assertTrue('WarningPolicy' in symbols) + self.assertTrue('SecurityOptions' in symbols) + self.assertTrue('SubsystemHandler' in symbols) + self.assertTrue('Channel' in symbols) + self.assertTrue('RSAKey' in symbols) + self.assertTrue('DSSKey' in symbols) + self.assertTrue('Message' in symbols) + self.assertTrue('SSHException' in symbols) + self.assertTrue('AuthenticationException' in symbols) + self.assertTrue('PasswordRequiredException' in symbols) + self.assertTrue('BadAuthenticationType' in symbols) + self.assertTrue('ChannelException' in symbols) + self.assertTrue('SFTP' in symbols) + self.assertTrue('SFTPFile' in symbols) + self.assertTrue('SFTPHandle' in symbols) + self.assertTrue('SFTPClient' in symbols) + self.assertTrue('SFTPServer' in symbols) + self.assertTrue('SFTPError' in symbols) + self.assertTrue('SFTPAttributes' in symbols) + self.assertTrue('SFTPServerInterface' in symbols) + self.assertTrue('ServerInterface' in symbols) + self.assertTrue('BufferedFile' in symbols) + self.assertTrue('Agent' in symbols) + self.assertTrue('AgentKey' in symbols) + self.assertTrue('HostKeys' in symbols) + self.assertTrue('SSHConfig' in symbols) + self.assertTrue('util' in symbols) - def test_1_parse_config(self): + def test_2_parse_config(self): global test_config_file f = cStringIO.StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) - self.assertEquals(config, [ {'identityfile': '~/.ssh/id_rsa', 'host': '*', 'user': 'robey', - 'crazy': 'something dumb '}, - {'host': '*.example.com', 'user': 'bjork', 'port': '3333'}, - {'host': 'spoo.example.com', 'crazy': 'something else'}]) + self.assertEquals(config._config, + [ {'identityfile': '~/.ssh/id_rsa', 'host': '*', 'user': 'robey', + 'crazy': 'something dumb '}, + {'host': '*.example.com', 'user': 'bjork', 'port': '3333'}, + {'host': 'spoo.example.com', 'crazy': 'something else'}]) - def test_2_host_config(self): + def test_3_host_config(self): global test_config_file f = cStringIO.StringIO(test_config_file) config = paramiko.util.parse_ssh_config(f) @@ -74,7 +127,28 @@ class UtilTest (unittest.TestCase): c = paramiko.util.lookup_ssh_host_config('spoo.example.com', config) self.assertEquals(c, {'identityfile': '~/.ssh/id_rsa', 'user': 'bjork', 'crazy': 'something else', 'port': '3333'}) - def test_3_generate_key_bytes(self): + def test_4_generate_key_bytes(self): x = paramiko.util.generate_key_bytes(SHA, 'ABCDEFGH', 'This is my secret passphrase.', 64) hex = ''.join(['%02x' % ord(c) for c in x]) self.assertEquals(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b') + + def test_5_host_keys(self): + f = open('hostfile.temp', 'w') + f.write(test_hosts_file) + f.close() + try: + hostdict = paramiko.util.load_host_keys('hostfile.temp') + self.assertEquals(2, len(hostdict)) + self.assertEquals(1, len(hostdict.values()[0])) + self.assertEquals(1, len(hostdict.values()[1])) + fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() + self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) + finally: + os.unlink('hostfile.temp') + + def test_6_random(self): + from paramiko.common import randpool + # just verify that we can pull out 32 bytes and not get an exception. + x = randpool.get_bytes(32) + self.assertEquals(len(x), 32) + -- cgit v1.2.3