# Copyright (C) 2003-2009 Robey Pointer
#
# This file is part of paramiko.
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.

"""
Some unit tests for SSHClient.
"""

from __future__ import with_statement  # Python 2.5 support

import socket
import threading
import time
import unittest
import weakref
import warnings
import os
import tempfile
from binascii import hexlify

import paramiko


class NullServer (paramiko.ServerInterface):
    """
    Minimal server policy used by the tests.

    Only user 'slowdive' may use password auth (password 'pygmalion'); the
    only accepted public key is the DSS key whose fingerprint is hard-coded
    below (the key the tests load from tests/test_dss.key). Every channel
    open is accepted, but exec requests succeed only for the command 'yes'.
    """

    def get_allowed_auths(self, username):
        """Advertise password auth for 'slowdive' only; publickey for everyone."""
        if username == 'slowdive':
            return 'publickey,password'
        return 'publickey'

    def check_auth_password(self, username, password):
        """Accept exactly the credentials slowdive / pygmalion."""
        if (username == 'slowdive') and (password == 'pygmalion'):
            return paramiko.AUTH_SUCCESSFUL
        return paramiko.AUTH_FAILED

    def check_auth_publickey(self, username, key):
        """Accept only the single known ssh-dss key, identified by fingerprint."""
        if (key.get_name() == 'ssh-dss') and \
                (hexlify(key.get_fingerprint()) == '4478f0b9a23cc5182009ff755bc1d26c'):
            return paramiko.AUTH_SUCCESSFUL
        return paramiko.AUTH_FAILED

    def check_channel_request(self, kind, chanid):
        """Allow every channel open request regardless of kind."""
        return paramiko.OPEN_SUCCEEDED

    def check_channel_exec_request(self, channel, command):
        """Permit only the command 'yes' to be executed."""
        return command == 'yes'


class SSHClientTest (unittest.TestCase):
    """
    Exercise SSHClient against an in-process paramiko server (NullServer)
    that listens on an ephemeral localhost port.
    """

    def setUp(self):
        """Open the listening socket and start the server thread (_run)."""
        self.sockl = socket.socket()
        self.sockl.bind(('localhost', 0))
        self.sockl.listen(1)
        self.addr, self.port = self.sockl.getsockname()
        self.event = threading.Event()
        thread = threading.Thread(target=self._run)
        # Daemonize so a test that never connects (e.g. test_5) cannot leave
        # a thread blocked in accept() that keeps the interpreter alive.
        thread.setDaemon(True)
        thread.start()

    def tearDown(self):
        """Close the client, server transport and sockets that were created."""
        # Attributes that were never set (e.g. no client connected) are skipped.
        for attr in "tc ts socks sockl".split():
            if hasattr(self, attr):
                getattr(self, attr).close()

    def _run(self):
        """
        Server thread body: accept one connection and start a server-mode
        Transport on it with the RSA host key from tests/test_rsa.key.

        ``self.event`` is handed to start_server; the tests wait on it
        before asserting on the server transport.
        """
        self.socks, addr = self.sockl.accept()
        self.ts = paramiko.Transport(self.socks)
        host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
        self.ts.add_server_key(host_key)
        server = NullServer()
        self.ts.start_server(self.event, server)

    def _public_host_key(self):
        """Return the public half of the server's RSA host key (tests/test_rsa.key)."""
        host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
        return paramiko.RSAKey(data=str(host_key))

    def _check_authenticated(self):
        """Assert the server transport is up and 'slowdive' has authenticated."""
        self.event.wait(1.0)
        self.assertTrue(self.event.isSet())
        self.assertTrue(self.ts.is_active())
        self.assertEqual('slowdive', self.ts.get_username())
        self.assertEqual(True, self.ts.is_authenticated())

    def _connect(self, **kwargs):
        """
        Connect ``self.tc`` as 'slowdive' with the server's host key already
        registered as known, then assert authentication succeeded.

        :param kwargs: credentials passed through to SSHClient.connect
            (e.g. ``password`` or ``key_filename``).
        """
        self.tc = paramiko.SSHClient()
        self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port),
                                    'ssh-rsa', self._public_host_key())
        self.tc.connect(self.addr, self.port, username='slowdive', **kwargs)
        self._check_authenticated()

    def _check_exec(self):
        """
        Run 'yes' (the only command NullServer permits) and verify stdout and
        stderr each deliver exactly the line the server sent, then EOF.
        """
        stdin, stdout, stderr = self.tc.exec_command('yes')
        try:
            schan = self.ts.accept(1.0)
            # Fail with a clear message rather than an AttributeError below
            # if the server never received the channel.
            self.assertTrue(schan is not None)
            schan.send('Hello there.\n')
            schan.send_stderr('This is on stderr.\n')
            schan.close()

            self.assertEqual('Hello there.\n', stdout.readline())
            self.assertEqual('', stdout.readline())
            self.assertEqual('This is on stderr.\n', stderr.readline())
            self.assertEqual('', stderr.readline())
        finally:
            # Always release the client-side streams, even on assertion failure.
            stdin.close()
            stdout.close()
            stderr.close()

    def test_1_client(self):
        """
        verify that the SSHClient stuff works too.
        """
        self._connect(password='pygmalion')
        self._check_exec()

    def test_2_client_dsa(self):
        """
        verify that SSHClient works with a DSA key.
        """
        self._connect(key_filename='tests/test_dss.key')
        self._check_exec()

    def test_3_multiple_key_files(self):
        """
        verify that SSHClient accepts and tries multiple key files.
        """
        # The RSA key is rejected by NullServer, so success proves the
        # client went on to try the DSS key.
        self._connect(key_filename=['tests/test_rsa.key', 'tests/test_dss.key'])

    def test_4_auto_add_policy(self):
        """
        verify that SSHClient's AutoAddPolicy works.
        """
        public_host_key = self._public_host_key()
        self.tc = paramiko.SSHClient()
        self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.assertEqual(0, len(self.tc.get_host_keys()))
        self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')

        self._check_authenticated()
        # The previously-unknown server key must now be recorded under the
        # bracketed host:port form used for non-default ports.
        self.assertEqual(1, len(self.tc.get_host_keys()))
        self.assertEqual(public_host_key,
                         self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa'])

    def test_5_save_host_keys(self):
        """
        verify that SSHClient correctly saves a known_hosts file.
        """
        public_host_key = self._public_host_key()
        # mkstemp creates a uniquely named file atomically, unlike the
        # race-prone os.tempnam() (removed in Python 3).
        handle, localname = tempfile.mkstemp()
        os.close(handle)
        try:
            client = paramiko.SSHClient()
            self.assertEqual(0, len(client.get_host_keys()))

            host_id = '[%s]:%d' % (self.addr, self.port)

            client.get_host_keys().add(host_id, 'ssh-rsa', public_host_key)
            self.assertEqual(1, len(client.get_host_keys()))
            self.assertEqual(public_host_key, client.get_host_keys()[host_id]['ssh-rsa'])

            client.save_host_keys(localname)

            with open(localname) as saved:
                self.assertTrue(host_id in saved.read())
        finally:
            # Remove the temp file even when an assertion above fails.
            os.unlink(localname)

    def test_6_cleanup(self):
        """
        verify that when an SSHClient is collected, its transport (and the
        transport's packetizer) is closed.
        """
        self.tc = paramiko.SSHClient()
        self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.assertEqual(0, len(self.tc.get_host_keys()))
        self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')

        self.event.wait(1.0)
        self.assertTrue(self.event.isSet())
        self.assertTrue(self.ts.is_active())

        p = weakref.ref(self.tc._transport.packetizer)
        self.assertTrue(p() is not None)
        del self.tc
        # hrm, sometimes p isn't cleared right away.  why is that?
        # Poll for up to 5 seconds rather than asserting immediately.
        st = time.time()
        while (time.time() - st < 5.0) and (p() is not None):
            time.sleep(0.1)
        self.assertTrue(p() is None)