diff options
Diffstat (limited to 'paramiko')
38 files changed, 3501 insertions, 911 deletions
diff --git a/paramiko/__init__.py b/paramiko/__init__.py index 0a312cb..9a8caec 100644 --- a/paramiko/__init__.py +++ b/paramiko/__init__.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2008 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -26,8 +26,9 @@ replaced C{telnet} and C{rsh} for secure access to remote shells, but the protocol also includes the ability to open arbitrary channels to remote services across an encrypted tunnel. (This is how C{sftp} works, for example.) -To use this package, pass a socket (or socket-like object) to a L{Transport}, -and use L{start_server <Transport.start_server>} or +The high-level client API starts with creation of an L{SSHClient} object. +For more direct control, pass a socket (or socket-like object) to a +L{Transport}, and use L{start_server <Transport.start_server>} or L{start_client <Transport.start_client>} to negoatite with the remote host as either a server or client. As a client, you are responsible for authenticating using a password or private key, and checking @@ -46,7 +47,7 @@ released under the GNU Lesser General Public License (LGPL). Website: U{http://www.lag.net/paramiko/} -@version: 1.5.2 (rhydon) +@version: 1.7.4 (Desmond) @author: Robey Pointer @contact: robey@lag.net @license: GNU Lesser General Public License (LGPL) @@ -59,20 +60,19 @@ if sys.version_info < (2, 2): __author__ = "Robey Pointer <robey@lag.net>" -__date__ = "04 Dec 2005" -__version__ = "1.5.2 (rhydon)" -__version_info__ = (1, 5, 2) +__date__ = "06 Jul 2008" +__version__ = "1.7.4 (Desmond)" +__version_info__ = (1, 7, 4) __license__ = "GNU Lesser General Public License (LGPL)" -import transport, auth_handler, channel, rsakey, dsskey, message -import ssh_exception, file, packet, agent, server, util -import sftp_client, sftp_attr, sftp_handle, sftp_server, sftp_si - from transport import randpool, SecurityOptions, Transport +from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, WarningPolicy from auth_handler import AuthHandler from channel import Channel, ChannelFile -from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType +from ssh_exception import SSHException, PasswordRequiredException, \ + BadAuthenticationType, ChannelException, BadHostKeyException, \ + AuthenticationException from server import ServerInterface, SubsystemHandler, InteractiveQuery from rsakey import RSAKey from dsskey import DSSKey @@ -88,15 +88,15 @@ from packet import Packetizer from file import BufferedFile from agent import Agent, AgentKey from pkey import PKey +from hostkeys import HostKeys +from config import SSHConfig # fix module names for epydoc -for x in [Transport, SecurityOptions, Channel, SFTPServer, SSHException, \ - PasswordRequiredException, BadAuthenticationType, ChannelFile, \ - SubsystemHandler, AuthHandler, RSAKey, DSSKey, SFTPError, \ - SFTP, SFTPClient, SFTPServer, Message, Packetizer, SFTPAttributes, \ - SFTPHandle, SFTPServerInterface, BufferedFile, Agent, AgentKey, \ - PKey, BaseSFTP, SFTPFile, ServerInterface]: - x.__module__ = 'paramiko' +for c in locals().values(): + if issubclass(type(c), type) or type(c).__name__ == 'classobj': + # classobj for exceptions :/ + c.__module__ = __name__ +del c from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \ OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, \ @@ -106,16 +106,24 @@ from sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, S SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED __all__ = [ 'Transport', + 'SSHClient', + 'MissingHostKeyPolicy', + 'AutoAddPolicy', + 'RejectPolicy', + 'WarningPolicy', 'SecurityOptions', 'SubsystemHandler', 'Channel', + 'PKey', 'RSAKey', 'DSSKey', - 'Agent', 'Message', 'SSHException', + 'AuthenticationException', 'PasswordRequiredException', 'BadAuthenticationType', + 'ChannelException', + 'BadHostKeyException', 'SFTP', 'SFTPFile', 'SFTPHandle', @@ -123,24 +131,11 @@ __all__ = [ 'Transport', 'SFTPServer', 'SFTPError', 'SFTPAttributes', - 'SFTPServerInterface' + 'SFTPServerInterface', 'ServerInterface', 'BufferedFile', 'Agent', 'AgentKey', - 'rsakey', - 'dsskey', - 'pkey', - 'message', - 'transport', - 'sftp', - 'sftp_client', - 'sftp_server', - 'sftp_attr', - 'sftp_file', - 'sftp_si', - 'sftp_handle', - 'server', - 'file', - 'agent', + 'HostKeys', + 'SSHConfig', 'util' ] diff --git a/paramiko/agent.py b/paramiko/agent.py index 3555512..71de8b8 100644 --- a/paramiko/agent.py +++ b/paramiko/agent.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 John Rochester <john@jrochester.org> +# Copyright (C) 2003-2007 John Rochester <john@jrochester.org> # # This file is part of paramiko. # @@ -55,20 +55,33 @@ class Agent: @raise SSHException: if an SSH agent is found, but speaks an incompatible protocol """ + self.keys = () if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - conn.connect(os.environ['SSH_AUTH_SOCK']) + try: + conn.connect(os.environ['SSH_AUTH_SOCK']) + except: + # probably a dangling env var: the ssh agent is gone + return self.conn = conn - type, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES)) - if type != SSH2_AGENT_IDENTITIES_ANSWER: - raise SSHException('could not get keys from ssh-agent') - keys = [] - for i in range(result.get_int()): - keys.append(AgentKey(self, result.get_string())) - result.get_string() - self.keys = tuple(keys) + elif sys.platform == 'win32': + import win_pageant + if win_pageant.can_talk_to_agent(): + self.conn = win_pageant.PageantConnection() + else: + return else: - self.keys = () + # no agent support + return + + ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES)) + if ptype != SSH2_AGENT_IDENTITIES_ANSWER: + raise SSHException('could not get keys from ssh-agent') + keys = [] + for i in range(result.get_int()): + keys.append(AgentKey(self, result.get_string())) + result.get_string() + self.keys = tuple(keys) def close(self): """ @@ -132,7 +145,7 @@ class AgentKey(PKey): msg.add_string(self.blob) msg.add_string(data) msg.add_int(0) - type, result = self.agent._send_message(msg) - if type != SSH2_AGENT_SIGN_RESPONSE: + ptype, result = self.agent._send_message(msg) + if ptype != SSH2_AGENT_SIGN_RESPONSE: raise SSHException('key cannot be used for signing') return result.get_string() diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py index 59aa376..39a0194 100644 --- a/paramiko/auth_handler.py +++ b/paramiko/auth_handler.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -21,6 +21,7 @@ L{AuthHandler} """ import threading +import weakref # this helps freezing utils import encodings.utf_8 @@ -28,7 +29,8 @@ import encodings.utf_8 from paramiko.common import * from paramiko import util from paramiko.message import Message -from paramiko.ssh_exception import SSHException, BadAuthenticationType, PartialAuthentication +from paramiko.ssh_exception import SSHException, AuthenticationException, \ + BadAuthenticationType, PartialAuthentication from paramiko.server import InteractiveQuery @@ -38,13 +40,15 @@ class AuthHandler (object): """ def __init__(self, transport): - self.transport = transport + self.transport = weakref.proxy(transport) self.username = None self.authenticated = False self.auth_event = None self.auth_method = '' self.password = None self.private_key = None + self.interactive_handler = None + self.submethods = None # for server mode: self.auth_username = None self.auth_fail_count = 0 @@ -154,15 +158,15 @@ class AuthHandler (object): event.wait(0.1) if not self.transport.is_active(): e = self.transport.get_exception() - if e is None: - e = SSHException('Authentication failed.') + if (e is None) or issubclass(e.__class__, EOFError): + e = AuthenticationException('Authentication failed.') raise e if event.isSet(): break if not self.is_authenticated(): e = self.transport.get_exception() if e is None: - e = SSHException('Authentication failed.') + e = AuthenticationException('Authentication failed.') # this is horrible. python Exception isn't yet descended from # object, so type(e) won't work. :( if issubclass(e.__class__, PartialAuthentication): @@ -193,7 +197,10 @@ class AuthHandler (object): m.add_string(self.auth_method) if self.auth_method == 'password': m.add_boolean(False) - m.add_string(self.password.encode('UTF-8')) + password = self.password + if isinstance(password, unicode): + password = password.encode('UTF-8') + m.add_string(password) elif self.auth_method == 'publickey': m.add_boolean(True) m.add_string(self.private_key.get_name()) @@ -276,12 +283,22 @@ class AuthHandler (object): result = self.transport.server_object.check_auth_none(username) elif method == 'password': changereq = m.get_boolean() - password = m.get_string().decode('UTF-8', 'replace') + password = m.get_string() + try: + password = password.decode('UTF-8') + except UnicodeError: + # some clients/servers expect non-utf-8 passwords! + # in this case, just return the raw byte string. + pass if changereq: # always treated as failure, since we don't support changing passwords, but collect # the list of valid auth types from the callback anyway self.transport._log(DEBUG, 'Auth request to change passwords (rejected)') - newpassword = m.get_string().decode('UTF-8', 'replace') + newpassword = m.get_string() + try: + newpassword = newpassword.decode('UTF-8', 'replace') + except UnicodeError: + pass result = AUTH_FAILED else: result = self.transport.server_object.check_auth_password(username, password) @@ -332,7 +349,7 @@ class AuthHandler (object): self._send_auth_result(username, method, result) def _parse_userauth_success(self, m): - self.transport._log(INFO, 'Authentication successful!') + self.transport._log(INFO, 'Authentication (%s) successful!' % self.auth_method) self.authenticated = True self.transport._auth_trigger() if self.auth_event != None: @@ -346,11 +363,11 @@ class AuthHandler (object): self.transport._log(DEBUG, 'Methods: ' + str(authlist)) self.transport.saved_exception = PartialAuthentication(authlist) elif self.auth_method not in authlist: - self.transport._log(INFO, 'Authentication type not permitted.') + self.transport._log(INFO, 'Authentication type (%s) not permitted.' % self.auth_method) self.transport._log(DEBUG, 'Allowed methods: ' + str(authlist)) self.transport.saved_exception = BadAuthenticationType('Bad authentication type', authlist) else: - self.transport._log(INFO, 'Authentication failed.') + self.transport._log(INFO, 'Authentication (%s) failed.' % self.auth_method) self.authenticated = False self.username = None if self.auth_event != None: @@ -407,4 +424,3 @@ class AuthHandler (object): MSG_USERAUTH_INFO_RESPONSE: _parse_userauth_info_response, } - diff --git a/paramiko/ber.py b/paramiko/ber.py index 6a7823d..9d8ddfa 100644 --- a/paramiko/ber.py +++ b/paramiko/ber.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -16,7 +16,7 @@ # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. -import struct + import util @@ -91,8 +91,9 @@ class BER(object): while True: x = b.decode_next() if x is None: - return out + break out.append(x) + return out decode_sequence = staticmethod(decode_sequence) def encode_tlv(self, ident, val): diff --git a/paramiko/buffered_pipe.py b/paramiko/buffered_pipe.py new file mode 100644 index 0000000..ae3d9d6 --- /dev/null +++ b/paramiko/buffered_pipe.py @@ -0,0 +1,200 @@ +# Copyright (C) 2006-2007 Robey Pointer <robey@lag.net> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Attempt to generalize the "feeder" part of a Channel: an object which can be +read from and closed, but is reading from a buffer fed by another thread. The +read operations are blocking and can have a timeout set. +""" + +import array +import threading +import time + + +class PipeTimeout (IOError): + """ + Indicates that a timeout was reached on a read from a L{BufferedPipe}. + """ + pass + + +class BufferedPipe (object): + """ + A buffer that obeys normal read (with timeout) & close semantics for a + file or socket, but is fed data from another thread. This is used by + L{Channel}. + """ + + def __init__(self): + self._lock = threading.Lock() + self._cv = threading.Condition(self._lock) + self._event = None + self._buffer = array.array('B') + self._closed = False + + def set_event(self, event): + """ + Set an event on this buffer. When data is ready to be read (or the + buffer has been closed), the event will be set. When no data is + ready, the event will be cleared. + + @param event: the event to set/clear + @type event: Event + """ + self._event = event + if len(self._buffer) > 0: + event.set() + else: + event.clear() + + def feed(self, data): + """ + Feed new data into this pipe. This method is assumed to be called + from a separate thread, so synchronization is done. + + @param data: the data to add + @type data: str + """ + self._lock.acquire() + try: + if self._event is not None: + self._event.set() + self._buffer.fromstring(data) + self._cv.notifyAll() + finally: + self._lock.release() + + def read_ready(self): + """ + Returns true if data is buffered and ready to be read from this + feeder. A C{False} result does not mean that the feeder has closed; + it means you may need to wait before more data arrives. + + @return: C{True} if a L{read} call would immediately return at least + one byte; C{False} otherwise. + @rtype: bool + """ + self._lock.acquire() + try: + if len(self._buffer) == 0: + return False + return True + finally: + self._lock.release() + + def read(self, nbytes, timeout=None): + """ + Read data from the pipe. The return value is a string representing + the data received. The maximum amount of data to be received at once + is specified by C{nbytes}. If a string of length zero is returned, + the pipe has been closed. + + The optional C{timeout} argument can be a nonnegative float expressing + seconds, or C{None} for no timeout. If a float is given, a + C{PipeTimeout} will be raised if the timeout period value has + elapsed before any data arrives. + + @param nbytes: maximum number of bytes to read + @type nbytes: int + @param timeout: maximum seconds to wait (or C{None}, the default, to + wait forever) + @type timeout: float + @return: data + @rtype: str + + @raise PipeTimeout: if a timeout was specified and no data was ready + before that timeout + """ + out = '' + self._lock.acquire() + try: + if len(self._buffer) == 0: + if self._closed: + return out + # should we block? + if timeout == 0.0: + raise PipeTimeout() + # loop here in case we get woken up but a different thread has + # grabbed everything in the buffer. + while (len(self._buffer) == 0) and not self._closed: + then = time.time() + self._cv.wait(timeout) + if timeout is not None: + timeout -= time.time() - then + if timeout <= 0.0: + raise PipeTimeout() + + # something's in the buffer and we have the lock! + if len(self._buffer) <= nbytes: + out = self._buffer.tostring() + del self._buffer[:] + if (self._event is not None) and not self._closed: + self._event.clear() + else: + out = self._buffer[:nbytes].tostring() + del self._buffer[:nbytes] + finally: + self._lock.release() + + return out + + def empty(self): + """ + Clear out the buffer and return all data that was in it. + + @return: any data that was in the buffer prior to clearing it out + @rtype: str + """ + self._lock.acquire() + try: + out = self._buffer.tostring() + del self._buffer[:] + if (self._event is not None) and not self._closed: + self._event.clear() + return out + finally: + self._lock.release() + + def close(self): + """ + Close this pipe object. Future calls to L{read} after the buffer + has been emptied will return immediately with an empty string. + """ + self._lock.acquire() + try: + self._closed = True + self._cv.notifyAll() + if self._event is not None: + self._event.set() + finally: + self._lock.release() + + def __len__(self): + """ + Return the number of bytes buffered. + + @return: number of bytes bufferes + @rtype: int + """ + self._lock.acquire() + try: + return len(self._buffer) + finally: + self._lock.release() + diff --git a/paramiko/channel.py b/paramiko/channel.py index 8a00233..910a03c 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -20,6 +20,7 @@ Abstraction for an SSH2 channel. """ +import binascii import sys import time import threading @@ -31,9 +32,14 @@ from paramiko import util from paramiko.message import Message from paramiko.ssh_exception import SSHException from paramiko.file import BufferedFile +from paramiko.buffered_pipe import BufferedPipe, PipeTimeout from paramiko import pipe +# lower bound on the max packet size we'll accept from the remote host +MIN_PACKET_SIZE = 1024 + + class Channel (object): """ A secure tunnel across an SSH L{Transport}. A Channel is meant to behave @@ -49,9 +55,6 @@ class Channel (object): is exactly like a normal network socket, so it shouldn't be too surprising. """ - # lower bound on the max packet size we'll accept from the remote host - MIN_PACKET_SIZE = 1024 - def __init__(self, chanid): """ Create a new channel. The channel is not associated with any @@ -69,14 +72,12 @@ class Channel (object): self.active = False self.eof_received = 0 self.eof_sent = 0 - self.in_buffer = '' - self.in_stderr_buffer = '' + self.in_buffer = BufferedPipe() + self.in_stderr_buffer = BufferedPipe() self.timeout = None self.closed = False self.ultra_debug = False self.lock = threading.Lock() - self.in_buffer_cv = threading.Condition(self.lock) - self.in_stderr_buffer_cv = threading.Condition(self.lock) self.out_buffer_cv = threading.Condition(self.lock) self.in_window_size = 0 self.out_window_size = 0 @@ -85,15 +86,19 @@ class Channel (object): self.in_window_threshold = 0 self.in_window_sofar = 0 self.status_event = threading.Event() - self.name = str(chanid) - self.logger = util.get_logger('paramiko.chan.' + str(chanid)) - self.pipe = None + self._name = str(chanid) + self.logger = util.get_logger('paramiko.transport') + self._pipe = None self.event = threading.Event() self.combine_stderr = False self.exit_status = -1 + self.origin_addr = None def __del__(self): - self.close() + try: + self.close() + except: + pass def __repr__(self): """ @@ -124,14 +129,15 @@ class Channel (object): It isn't necessary (or desirable) to call this method if you're going to exectue a single command with L{exec_command}. - @param term: the terminal type to emulate (for example, C{'vt100'}). + @param term: the terminal type to emulate (for example, C{'vt100'}) @type term: str @param width: width (in characters) of the terminal screen @type width: int @param height: height (in characters) of the terminal screen @type height: int - @return: C{True} if the operation succeeded; C{False} if not. - @rtype: bool + + @raise SSHException: if the request was rejected or the channel was + closed """ if self.closed or self.eof_received or self.eof_sent or not self.active: raise SSHException('Channel is not open') @@ -148,12 +154,7 @@ class Channel (object): m.add_string('') self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + self._wait_for_event() def invoke_shell(self): """ @@ -168,8 +169,8 @@ class Channel (object): When the shell exits, the channel will be closed and can't be reused. You must open a new channel if you wish to open another shell. - @return: C{True} if the operation succeeded; C{False} if not. - @rtype: bool + @raise SSHException: if the request was rejected or the channel was + closed """ if self.closed or self.eof_received or self.eof_sent or not self.active: raise SSHException('Channel is not open') @@ -180,12 +181,7 @@ class Channel (object): m.add_boolean(1) self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + self._wait_for_event() def exec_command(self, command): """ @@ -199,8 +195,9 @@ class Channel (object): @param command: a shell command to execute. @type command: str - @return: C{True} if the operation succeeded; C{False} if not. - @rtype: bool + + @raise SSHException: if the request was rejected or the channel was + closed """ if self.closed or self.eof_received or self.eof_sent or not self.active: raise SSHException('Channel is not open') @@ -208,16 +205,11 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_int(self.remote_chanid) m.add_string('exec') - m.add_boolean(1) + m.add_boolean(True) m.add_string(command) self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + self._wait_for_event() def invoke_subsystem(self, subsystem): """ @@ -230,8 +222,9 @@ class Channel (object): @param subsystem: name of the subsystem being requested. @type subsystem: str - @return: C{True} if the operation succeeded; C{False} if not. - @rtype: bool + + @raise SSHException: if the request was rejected or the channel was + closed """ if self.closed or self.eof_received or self.eof_sent or not self.active: raise SSHException('Channel is not open') @@ -239,16 +232,11 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_int(self.remote_chanid) m.add_string('subsystem') - m.add_boolean(1) + m.add_boolean(True) m.add_string(subsystem) self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + self._wait_for_event() def resize_pty(self, width=80, height=24): """ @@ -259,8 +247,9 @@ class Channel (object): @type width: int @param height: new height (in characters) of the terminal screen @type height: int - @return: C{True} if the operation succeeded; C{False} if not. - @rtype: bool + + @raise SSHException: if the request was rejected or the channel was + closed """ if self.closed or self.eof_received or self.eof_sent or not self.active: raise SSHException('Channel is not open') @@ -268,19 +257,27 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_int(self.remote_chanid) m.add_string('window-change') - m.add_boolean(1) + m.add_boolean(True) m.add_int(width) m.add_int(height) m.add_int(0).add_int(0) self.event.clear() self.transport._send_user_message(m) - while True: - self.event.wait(0.1) - if self.closed: - return False - if self.event.isSet(): - return True + self._wait_for_event() + def exit_status_ready(self): + """ + Return true if the remote process has exited and returned an exit + status. You may use this to poll the process status if you don't + want to block in L{recv_exit_status}. Note that the server may not + return an exit status in some cases (like bad servers). + + @return: True if L{recv_exit_status} will return immediately + @rtype: bool + @since: 1.7.3 + """ + return self.closed or self.status_event.isSet() + def recv_exit_status(self): """ Return the exit status from the process on the server. This is @@ -296,8 +293,9 @@ class Channel (object): """ while True: if self.closed or self.status_event.isSet(): - return self.exit_status + break self.status_event.wait(0.1) + return self.exit_status def send_exit_status(self, status): """ @@ -317,10 +315,73 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_int(self.remote_chanid) m.add_string('exit-status') - m.add_boolean(0) + m.add_boolean(False) m.add_int(status) self.transport._send_user_message(m) + + def request_x11(self, screen_number=0, auth_protocol=None, auth_cookie=None, + single_connection=False, handler=None): + """ + Request an x11 session on this channel. If the server allows it, + further x11 requests can be made from the server to the client, + when an x11 application is run in a shell session. + + From RFC4254:: + + It is RECOMMENDED that the 'x11 authentication cookie' that is + sent be a fake, random cookie, and that the cookie be checked and + replaced by the real cookie when a connection request is received. + If you omit the auth_cookie, a new secure random 128-bit value will be + generated, used, and returned. You will need to use this value to + verify incoming x11 requests and replace them with the actual local + x11 cookie (which requires some knoweldge of the x11 protocol). + + If a handler is passed in, the handler is called from another thread + whenever a new x11 connection arrives. The default handler queues up + incoming x11 connections, which may be retrieved using + L{Transport.accept}. The handler's calling signature is:: + + handler(channel: Channel, (address: str, port: int)) + + @param screen_number: the x11 screen number (0, 10, etc) + @type screen_number: int + @param auth_protocol: the name of the X11 authentication method used; + if none is given, C{"MIT-MAGIC-COOKIE-1"} is used + @type auth_protocol: str + @param auth_cookie: hexadecimal string containing the x11 auth cookie; + if none is given, a secure random 128-bit value is generated + @type auth_cookie: str + @param single_connection: if True, only a single x11 connection will be + forwarded (by default, any number of x11 connections can arrive + over this session) + @type single_connection: bool + @param handler: an optional handler to use for incoming X11 connections + @type handler: function + @return: the auth_cookie used + """ + if self.closed or self.eof_received or self.eof_sent or not self.active: + raise SSHException('Channel is not open') + if auth_protocol is None: + auth_protocol = 'MIT-MAGIC-COOKIE-1' + if auth_cookie is None: + auth_cookie = binascii.hexlify(self.transport.randpool.get_bytes(16)) + + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('x11-req') + m.add_boolean(True) + m.add_boolean(single_connection) + m.add_string(auth_protocol) + m.add_string(auth_cookie) + m.add_int(screen_number) + self.event.clear() + self.transport._send_user_message(m) + self._wait_for_event() + self.transport._set_x11_handler(handler) + return auth_cookie + def get_transport(self): """ Return the L{Transport} associated with this channel. @@ -333,14 +394,13 @@ class Channel (object): def set_name(self, name): """ Set a name for this channel. Currently it's only used to set the name - of the log level used for debugging. The name can be fetched with the + of the channel in logfile entries. The name can be fetched with the L{get_name} method. - @param name: new channel name. + @param name: new channel name @type name: str """ - self.name = name - self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self.name) + self._name = name def get_name(self): """ @@ -349,7 +409,7 @@ class Channel (object): @return: the name of this channel. @rtype: str """ - return self.name + return self._name def get_id(self): """ @@ -360,8 +420,6 @@ class Channel (object): @return: the ID of this channel. @rtype: int - - @since: ivysaur """ return self.chanid @@ -394,8 +452,7 @@ class Channel (object): self.combine_stderr = combine if combine and not old: # copy old stderr buffer into primary buffer - data = self.in_stderr_buffer - self.in_stderr_buffer = '' + data = self.in_stderr_buffer.empty() finally: self.lock.release() if len(data) > 0: @@ -419,7 +476,7 @@ class Channel (object): C{chan.settimeout(None)} is equivalent to C{chan.setblocking(1)}. @param timeout: seconds to wait for a pending read/write operation - before raising C{socket.timeout}, or C{None} for no timeout. + before raising C{socket.timeout}, or C{None} for no timeout. @type timeout: float """ self.timeout = timeout @@ -439,17 +496,19 @@ class Channel (object): """ Set blocking or non-blocking mode of the channel: if C{blocking} is 0, the channel is set to non-blocking mode; otherwise it's set to blocking - mode. Initially all channels are in blocking mode. + mode. Initially all channels are in blocking mode. In non-blocking mode, if a L{recv} call doesn't find any data, or if a L{send} call can't immediately dispose of the data, an error exception - is raised. In blocking mode, the calls block until they can proceed. + is raised. In blocking mode, the calls block until they can proceed. An + EOF condition is considered "immediate data" for L{recv}, so if the + channel is closed in the read direction, it will never block. C{chan.setblocking(0)} is equivalent to C{chan.settimeout(0)}; C{chan.setblocking(1)} is equivalent to C{chan.settimeout(None)}. @param blocking: 0 to set non-blocking mode; non-0 to set blocking - mode. + mode. @type blocking: int """ if blocking: @@ -457,6 +516,18 @@ class Channel (object): else: self.settimeout(0.0) + def getpeername(self): + """ + Return the address of the remote side of this Channel, if possible. + This is just a wrapper around C{'getpeername'} on the Transport, used + to provide enough of a socket-like interface to allow asyncore to work. + (asyncore likes to call C{'getpeername'}.) + + @return: the address if the remote host, if known + @rtype: tuple(str, int) + """ + return self.transport.getpeername() + def close(self): """ Close the channel. All future read/write operations on the channel @@ -466,15 +537,17 @@ class Channel (object): """ self.lock.acquire() try: + # only close the pipe when the user explicitly closes the channel. + # otherwise they will get unpleasant surprises. (and do it before + # checking self.closed, since the remote host may have already + # closed the connection.) + if self._pipe is not None: + self._pipe.close() + self._pipe = None + if not self.active or self.closed: return msgs = self._close_internal() - - # only close the pipe when the user explicitly closes the channel. - # otherwise they will get unpleasant surprises. - if self.pipe is not None: - self.pipe.close() - self.pipe = None finally: self.lock.release() for m in msgs: @@ -491,13 +564,7 @@ class Channel (object): return at least one byte; C{False} otherwise. @rtype: boolean """ - self.lock.acquire() - try: - if len(self.in_buffer) == 0: - return False - return True - finally: - self.lock.release() + return self.in_buffer.read_ready() def recv(self, nbytes): """ @@ -514,38 +581,12 @@ class Channel (object): @raise socket.timeout: if no data is ready before the timeout set by L{settimeout}. """ - out = '' - self.lock.acquire() try: - if len(self.in_buffer) == 0: - if self.closed or self.eof_received: - return out - # should we block? - if self.timeout == 0.0: - raise socket.timeout() - # loop here in case we get woken up but a different thread has grabbed everything in the buffer - timeout = self.timeout - while (len(self.in_buffer) == 0) and not self.closed and not self.eof_received: - then = time.time() - self.in_buffer_cv.wait(timeout) - if timeout != None: - timeout -= time.time() - then - if timeout <= 0.0: - raise socket.timeout() - # something in the buffer and we have the lock - if len(self.in_buffer) <= nbytes: - out = self.in_buffer - self.in_buffer = '' - if self.pipe is not None: - # clear the pipe, since no more data is buffered - self.pipe.clear() - else: - out = self.in_buffer[:nbytes] - self.in_buffer = self.in_buffer[nbytes:] - ack = self._check_add_window(len(out)) - finally: - self.lock.release() + out = self.in_buffer.read(nbytes, self.timeout) + except PipeTimeout, e: + raise socket.timeout() + ack = self._check_add_window(len(out)) # no need to hold the channel lock when sending this if ack > 0: m = Message() @@ -569,13 +610,7 @@ class Channel (object): @since: 1.1 """ - self.lock.acquire() - try: - if len(self.in_stderr_buffer) == 0: - return False - return True - finally: - self.lock.release() + return self.in_stderr_buffer.read_ready() def recv_stderr(self, nbytes): """ @@ -596,36 +631,43 @@ class Channel (object): @since: 1.1 """ - out = '' + try: + out = self.in_stderr_buffer.read(nbytes, self.timeout) + except PipeTimeout, e: + raise socket.timeout() + + ack = self._check_add_window(len(out)) + # no need to hold the channel lock when sending this + if ack > 0: + m = Message() + m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) + m.add_int(self.remote_chanid) + m.add_int(ack) + self.transport._send_user_message(m) + + return out + + def send_ready(self): + """ + Returns true if data can be written to this channel without blocking. + This means the channel is either closed (so any write attempt would + return immediately) or there is at least one byte of space in the + outbound buffer. If there is at least one byte of space in the + outbound buffer, a L{send} call will succeed immediately and return + the number of bytes actually written. + + @return: C{True} if a L{send} call on this channel would immediately + succeed or fail + @rtype: boolean + """ self.lock.acquire() try: - if len(self.in_stderr_buffer) == 0: - if self.closed or self.eof_received: - return out - # should we block? - if self.timeout == 0.0: - raise socket.timeout() - # loop here in case we get woken up but a different thread has grabbed everything in the buffer - timeout = self.timeout - while (len(self.in_stderr_buffer) == 0) and not self.closed and not self.eof_received: - then = time.time() - self.in_stderr_buffer_cv.wait(timeout) - if timeout != None: - timeout -= time.time() - then - if timeout <= 0.0: - raise socket.timeout() - # something in the buffer and we have the lock - if len(self.in_stderr_buffer) <= nbytes: - out = self.in_stderr_buffer - self.in_stderr_buffer = '' - else: - out = self.in_stderr_buffer[:nbytes] - self.in_stderr_buffer = self.in_stderr_buffer[nbytes:] - self._check_add_window(len(out)) + if self.closed or self.eof_sent: + return True + return self.out_window_size > 0 finally: self.lock.release() - return out - + def send(self, s): """ Send data to the channel. Returns the number of bytes sent, or 0 if @@ -634,9 +676,9 @@ class Channel (object): transmitted, the application needs to attempt delivery of the remaining data. - @param s: data to send. + @param s: data to send @type s: str - @return: number of bytes actually sent. + @return: number of bytes actually sent @rtype: int @raise socket.timeout: if no data could be sent before the timeout set @@ -653,9 +695,11 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_DATA)) m.add_int(self.remote_chanid) m.add_string(s[:size]) - self.transport._send_user_message(m) finally: self.lock.release() + # Note: We release self.lock before calling _send_user_message. + # Otherwise, we can deadlock during re-keying. + self.transport._send_user_message(m) return size def send_stderr(self, s): @@ -689,9 +733,11 @@ class Channel (object): m.add_int(self.remote_chanid) m.add_int(1) m.add_string(s[:size]) - self.transport._send_user_message(m) finally: self.lock.release() + # Note: We release self.lock before calling _send_user_message. + # Otherwise, we can deadlock during re-keying. + self.transport._send_user_message(m) return size def sendall(self, s): @@ -776,14 +822,14 @@ class Channel (object): def fileno(self): """ Returns an OS-level file descriptor which can be used for polling, but - but I{not} for reading or writing). This is primaily to allow python's + but I{not} for reading or writing. This is primaily to allow python's C{select} module to work. The first time C{fileno} is called on a channel, a pipe is created to simulate real OS-level file descriptor (FD) behavior. Because of this, two OS-level FDs are created, which will use up FDs faster than normal. - You won't notice this effect unless you open hundreds or thousands of - channels simultaneously, but it's still notable. + (You won't notice this effect unless you have hundreds of channels + open at the same time.) @return: an OS-level file descriptor @rtype: int @@ -793,13 +839,14 @@ class Channel (object): """ self.lock.acquire() try: - if self.pipe is not None: - return self.pipe.fileno() + if self._pipe is not None: + return self._pipe.fileno() # create the pipe and feed in any existing data - self.pipe = pipe.make_pipe() - if len(self.in_buffer) > 0: - self.pipe.set() - return self.pipe.fileno() + self._pipe = pipe.make_pipe() + p1, p2 = pipe.make_or_pipe(self._pipe) + self.in_buffer.set_event(p1) + self.in_stderr_buffer.set_event(p2) + return self._pipe.fileno() finally: self.lock.release() @@ -856,7 +903,7 @@ class Channel (object): def _set_transport(self, transport): self.transport = transport - self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self.name) + self.logger = util.get_logger(self.transport.get_log_channel()) def _set_window(self, window_size, max_packet_size): self.in_window_size = window_size @@ -869,7 +916,7 @@ class Channel (object): def _set_remote_channel(self, chanid, window_size, max_packet_size): self.remote_chanid = chanid self.out_window_size = window_size - self.out_max_packet_size = max(max_packet_size, self.MIN_PACKET_SIZE) + self.out_max_packet_size = max(max_packet_size, MIN_PACKET_SIZE) self.active = 1 self._log(DEBUG, 'Max packet out: %d bytes' % max_packet_size) @@ -894,16 +941,7 @@ class Channel (object): s = m else: s = m.get_string() - self.lock.acquire() - try: - if self.ultra_debug: - self._log(DEBUG, 'fed %d bytes' % len(s)) - if self.pipe is not None: - self.pipe.set() - self.in_buffer += s - self.in_buffer_cv.notifyAll() - finally: - self.lock.release() + self.in_buffer.feed(s) def _feed_extended(self, m): code = m.get_int() @@ -912,15 +950,9 @@ class Channel (object): self._log(ERROR, 'unknown extended_data type %d; discarding' % code) return if self.combine_stderr: - return self._feed(s) - self.lock.acquire() - try: - if self.ultra_debug: - self._log(DEBUG, 'fed %d stderr bytes' % len(s)) - self.in_stderr_buffer += s - self.in_stderr_buffer_cv.notifyAll() - finally: - self.lock.release() + self._feed(s) + else: + self.in_stderr_buffer.feed(s) def _window_adjust(self, m): nbytes = m.get_int() @@ -984,6 +1016,16 @@ class Channel (object): else: ok = server.check_channel_window_change_request(self, width, height, pixelwidth, pixelheight) + elif key == 'x11-req': + single_connection = m.get_boolean() + auth_proto = m.get_string() + auth_cookie = m.get_string() + screen_number = m.get_int() + if server is None: + ok = False + else: + ok = server.check_channel_x11_request(self, single_connection, + auth_proto, auth_cookie, screen_number) else: self._log(DEBUG, 'Unhandled channel request "%s"' % key) ok = False @@ -1001,13 +1043,13 @@ class Channel (object): try: if not self.eof_received: self.eof_received = True - self.in_buffer_cv.notifyAll() - self.in_stderr_buffer_cv.notifyAll() - if self.pipe is not None: - self.pipe.set_forever() + self.in_buffer.close() + self.in_stderr_buffer.close() + if self._pipe is not None: + self._pipe.set_forever() finally: self.lock.release() - self._log(DEBUG, 'EOF received') + self._log(DEBUG, 'EOF received (%s)', self._name) def _handle_close(self, m): self.lock.acquire() @@ -1024,17 +1066,29 @@ class Channel (object): ### internals... - def _log(self, level, msg): - self.logger.log(level, msg) + def _log(self, level, msg, *args): + self.logger.log(level, "[chan " + self._name + "] " + msg, *args) + + def _wait_for_event(self): + while True: + self.event.wait(0.1) + if self.event.isSet(): + return + if self.closed: + e = self.transport.get_exception() + if e is None: + e = SSHException('Channel closed.') + raise e + return def _set_closed(self): # you are holding the lock. self.closed = True - self.in_buffer_cv.notifyAll() - self.in_stderr_buffer_cv.notifyAll() + self.in_buffer.close() + self.in_stderr_buffer.close() self.out_buffer_cv.notifyAll() - if self.pipe is not None: - self.pipe.set_forever() + if self._pipe is not None: + self._pipe.set_forever() def _send_eof(self): # you are holding the lock. @@ -1044,7 +1098,7 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_EOF)) m.add_int(self.remote_chanid) self.eof_sent = True - self._log(DEBUG, 'EOF sent') + self._log(DEBUG, 'EOF sent (%s)', self._name) return m def _close_internal(self): @@ -1072,19 +1126,22 @@ class Channel (object): self.lock.release() def _check_add_window(self, n): - # already holding the lock! - if self.closed or self.eof_received or not self.active: - return 0 - if self.ultra_debug: - self._log(DEBUG, 'addwindow %d' % n) - self.in_window_sofar += n - if self.in_window_sofar <= self.in_window_threshold: - return 0 - if self.ultra_debug: - self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar) - out = self.in_window_sofar - self.in_window_sofar = 0 - return out + self.lock.acquire() + try: + if self.closed or self.eof_received or not self.active: + return 0 + if self.ultra_debug: + self._log(DEBUG, 'addwindow %d' % n) + self.in_window_sofar += n + if self.in_window_sofar <= self.in_window_threshold: + return 0 + if self.ultra_debug: + self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar) + out = self.in_window_sofar + self.in_window_sofar = 0 + return out + finally: + self.lock.release() def _wait_for_send_window(self, size): """ @@ -1155,8 +1212,6 @@ class ChannelFile (BufferedFile): def _write(self, data): self.channel.sendall(data) return len(data) - - seek = BufferedFile.seek class ChannelStderrFile (ChannelFile): diff --git a/paramiko/client.py b/paramiko/client.py new file mode 100644 index 0000000..7870ea9 --- /dev/null +++ b/paramiko/client.py @@ -0,0 +1,474 @@ +# Copyright (C) 2006-2007 Robey Pointer <robey@lag.net> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{SSHClient}. +""" + +from binascii import hexlify +import getpass +import os +import socket +import warnings + +from paramiko.agent import Agent +from paramiko.common import * +from paramiko.dsskey import DSSKey +from paramiko.hostkeys import HostKeys +from paramiko.resource import ResourceManager +from paramiko.rsakey import RSAKey +from paramiko.ssh_exception import SSHException, BadHostKeyException +from paramiko.transport import Transport + + +class MissingHostKeyPolicy (object): + """ + Interface for defining the policy that L{SSHClient} should use when the + SSH server's hostname is not in either the system host keys or the + application's keys. Pre-made classes implement policies for automatically + adding the key to the application's L{HostKeys} object (L{AutoAddPolicy}), + and for automatically rejecting the key (L{RejectPolicy}). + + This function may be used to ask the user to verify the key, for example. + """ + + def missing_host_key(self, client, hostname, key): + """ + Called when an L{SSHClient} receives a server key for a server that + isn't in either the system or local L{HostKeys} object. To accept + the key, simply return. To reject, raised an exception (which will + be passed to the calling application). + """ + pass + + +class AutoAddPolicy (MissingHostKeyPolicy): + """ + Policy for automatically adding the hostname and new host key to the + local L{HostKeys} object, and saving it. This is used by L{SSHClient}. + """ + + def missing_host_key(self, client, hostname, key): + client._host_keys.add(hostname, key.get_name(), key) + if client._host_keys_filename is not None: + client.save_host_keys(client._host_keys_filename) + client._log(DEBUG, 'Adding %s host key for %s: %s' % + (key.get_name(), hostname, hexlify(key.get_fingerprint()))) + + +class RejectPolicy (MissingHostKeyPolicy): + """ + Policy for automatically rejecting the unknown hostname & key. This is + used by L{SSHClient}. + """ + + def missing_host_key(self, client, hostname, key): + client._log(DEBUG, 'Rejecting %s host key for %s: %s' % + (key.get_name(), hostname, hexlify(key.get_fingerprint()))) + raise SSHException('Unknown server %s' % hostname) + + +class WarningPolicy (MissingHostKeyPolicy): + """ + Policy for logging a python-style warning for an unknown host key, but + accepting it. This is used by L{SSHClient}. + """ + def missing_host_key(self, client, hostname, key): + warnings.warn('Unknown %s host key for %s: %s' % + (key.get_name(), hostname, hexlify(key.get_fingerprint()))) + + +class SSHClient (object): + """ + A high-level representation of a session with an SSH server. This class + wraps L{Transport}, L{Channel}, and L{SFTPClient} to take care of most + aspects of authenticating and opening channels. A typical use case is:: + + client = SSHClient() + client.load_system_host_keys() + client.connect('ssh.example.com') + stdin, stdout, stderr = client.exec_command('ls -l') + + You may pass in explicit overrides for authentication and server host key + checking. The default mechanism is to try to use local key files or an + SSH agent (if one is running). + + @since: 1.6 + """ + + def __init__(self): + """ + Create a new SSHClient. + """ + self._system_host_keys = HostKeys() + self._host_keys = HostKeys() + self._host_keys_filename = None + self._log_channel = None + self._policy = RejectPolicy() + self._transport = None + + def load_system_host_keys(self, filename=None): + """ + Load host keys from a system (read-only) file. Host keys read with + this method will not be saved back by L{save_host_keys}. + + This method can be called multiple times. Each new set of host keys + will be merged with the existing set (new replacing old if there are + conflicts). + + If C{filename} is left as C{None}, an attempt will be made to read + keys from the user's local "known hosts" file, as used by OpenSSH, + and no exception will be raised if the file can't be read. This is + probably only useful on posix. + + @param filename: the filename to read, or C{None} + @type filename: str + + @raise IOError: if a filename was provided and the file could not be + read + """ + if filename is None: + # try the user's .ssh key file, and mask exceptions + filename = os.path.expanduser('~/.ssh/known_hosts') + try: + self._system_host_keys.load(filename) + except IOError: + pass + return + self._system_host_keys.load(filename) + + def load_host_keys(self, filename): + """ + Load host keys from a local host-key file. Host keys read with this + method will be checked I{after} keys loaded via L{load_system_host_keys}, + but will be saved back by L{save_host_keys} (so they can be modified). + The missing host key policy L{AutoAddPolicy} adds keys to this set and + saves them, when connecting to a previously-unknown server. + + This method can be called multiple times. Each new set of host keys + will be merged with the existing set (new replacing old if there are + conflicts). When automatically saving, the last hostname is used. + + @param filename: the filename to read + @type filename: str + + @raise IOError: if the filename could not be read + """ + self._host_keys_filename = filename + self._host_keys.load(filename) + + def save_host_keys(self, filename): + """ + Save the host keys back to a file. Only the host keys loaded with + L{load_host_keys} (plus any added directly) will be saved -- not any + host keys loaded with L{load_system_host_keys}. + + @param filename: the filename to save to + @type filename: str + + @raise IOError: if the file could not be written + """ + f = open(filename, 'w') + f.write('# SSH host keys collected by paramiko\n') + for hostname, keys in self._host_keys.iteritems(): + for keytype, key in keys.iteritems(): + f.write('%s %s %s\n' % (hostname, keytype, key.get_base64())) + f.close() + + def get_host_keys(self): + """ + Get the local L{HostKeys} object. This can be used to examine the + local host keys or change them. + + @return: the local host keys + @rtype: L{HostKeys} + """ + return self._host_keys + + def set_log_channel(self, name): + """ + Set the channel for logging. The default is C{"paramiko.transport"} + but it can be set to anything you want. + + @param name: new channel name for logging + @type name: str + """ + self._log_channel = name + + def set_missing_host_key_policy(self, policy): + """ + Set the policy to use when connecting to a server that doesn't have a + host key in either the system or local L{HostKeys} objects. The + default policy is to reject all unknown servers (using L{RejectPolicy}). + You may substitute L{AutoAddPolicy} or write your own policy class. + + @param policy: the policy to use when receiving a host key from a + previously-unknown server + @type policy: L{MissingHostKeyPolicy} + """ + self._policy = policy + + def connect(self, hostname, port=22, username=None, password=None, pkey=None, + key_filename=None, timeout=None, allow_agent=True, look_for_keys=True): + """ + Connect to an SSH server and authenticate to it. The server's host key + is checked against the system host keys (see L{load_system_host_keys}) + and any local host keys (L{load_host_keys}). If the server's hostname + is not found in either set of host keys, the missing host key policy + is used (see L{set_missing_host_key_policy}). The default policy is + to reject the key and raise an L{SSHException}. + + Authentication is attempted in the following order of priority: + + - The C{pkey} or C{key_filename} passed in (if any) + - Any key we can find through an SSH agent + - Any "id_rsa" or "id_dsa" key discoverable in C{~/.ssh/} + - Plain username/password auth, if a password was given + + If a private key requires a password to unlock it, and a password is + passed in, that password will be used to attempt to unlock the key. + + @param hostname: the server to connect to + @type hostname: str + @param port: the server port to connect to + @type port: int + @param username: the username to authenticate as (defaults to the + current local username) + @type username: str + @param password: a password to use for authentication or for unlocking + a private key + @type password: str + @param pkey: an optional private key to use for authentication + @type pkey: L{PKey} + @param key_filename: the filename, or list of filenames, of optional + private key(s) to try for authentication + @type key_filename: str or list(str) + @param timeout: an optional timeout (in seconds) for the TCP connect + @type timeout: float + @param allow_agent: set to False to disable connecting to the SSH agent + @type allow_agent: bool + @param look_for_keys: set to False to disable searching for discoverable + private key files in C{~/.ssh/} + @type look_for_keys: bool + + @raise BadHostKeyException: if the server's host key could not be + verified + @raise AuthenticationException: if authentication failed + @raise SSHException: if there was any other error connecting or + establishing an SSH session + @raise socket.error: if a socket error occurred while connecting + """ + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if timeout is not None: + try: + sock.settimeout(timeout) + except: + pass + + sock.connect((hostname, port)) + t = self._transport = Transport(sock) + + if self._log_channel is not None: + t.set_log_channel(self._log_channel) + t.start_client() + ResourceManager.register(self, t) + + server_key = t.get_remote_server_key() + keytype = server_key.get_name() + + our_server_key = self._system_host_keys.get(hostname, {}).get(keytype, None) + if our_server_key is None: + our_server_key = self._host_keys.get(hostname, {}).get(keytype, None) + if our_server_key is None: + # will raise exception if the key is rejected; let that fall out + self._policy.missing_host_key(self, hostname, server_key) + # if the callback returns, assume the key is ok + our_server_key = server_key + + if server_key != our_server_key: + raise BadHostKeyException(hostname, server_key, our_server_key) + + if username is None: + username = getpass.getuser() + + if key_filename is None: + key_filenames = [] + elif isinstance(key_filename, (str, unicode)): + key_filenames = [ key_filename ] + else: + key_filenames = key_filename + self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys) + + def close(self): + """ + Close this SSHClient and its underlying L{Transport}. + """ + if self._transport is None: + return + self._transport.close() + self._transport = None + + def exec_command(self, command, bufsize=-1): + """ + Execute a command on the SSH server. A new L{Channel} is opened and + the requested command is executed. The command's input and output + streams are returned as python C{file}-like objects representing + stdin, stdout, and stderr. + + @param command: the command to execute + @type command: str + @param bufsize: interpreted the same way as by the built-in C{file()} function in python + @type bufsize: int + @return: the stdin, stdout, and stderr of the executing command + @rtype: tuple(L{ChannelFile}, L{ChannelFile}, L{ChannelFile}) + + @raise SSHException: if the server fails to execute the command + """ + chan = self._transport.open_session() + chan.exec_command(command) + stdin = chan.makefile('wb', bufsize) + stdout = chan.makefile('rb', bufsize) + stderr = chan.makefile_stderr('rb', bufsize) + return stdin, stdout, stderr + + def invoke_shell(self, term='vt100', width=80, height=24): + """ + Start an interactive shell session on the SSH server. A new L{Channel} + is opened and connected to a pseudo-terminal using the requested + terminal type and size. + + @param term: the terminal type to emulate (for example, C{"vt100"}) + @type term: str + @param width: the width (in characters) of the terminal window + @type width: int + @param height: the height (in characters) of the terminal window + @type height: int + @return: a new channel connected to the remote shell + @rtype: L{Channel} + + @raise SSHException: if the server fails to invoke a shell + """ + chan = self._transport.open_session() + chan.get_pty(term, width, height) + chan.invoke_shell() + return chan + + def open_sftp(self): + """ + Open an SFTP session on the SSH server. + + @return: a new SFTP session object + @rtype: L{SFTPClient} + """ + return self._transport.open_sftp_client() + + def get_transport(self): + """ + Return the underlying L{Transport} object for this SSH connection. + This can be used to perform lower-level tasks, like opening specific + kinds of channels. + + @return: the Transport for this connection + @rtype: L{Transport} + """ + return self._transport + + def _auth(self, username, password, pkey, key_filenames, allow_agent, look_for_keys): + """ + Try, in order: + + - The key passed in, if one was passed in. + - Any key we can find through an SSH agent (if allowed). + - Any "id_rsa" or "id_dsa" key discoverable in ~/.ssh/ (if allowed). + - Plain username/password auth, if a password was given. + + (The password might be needed to unlock a private key.) + """ + saved_exception = None + + if pkey is not None: + try: + self._log(DEBUG, 'Trying SSH key %s' % hexlify(pkey.get_fingerprint())) + self._transport.auth_publickey(username, pkey) + return + except SSHException, e: + saved_exception = e + + for key_filename in key_filenames: + for pkey_class in (RSAKey, DSSKey): + try: + key = pkey_class.from_private_key_file(key_filename, password) + self._log(DEBUG, 'Trying key %s from %s' % (hexlify(key.get_fingerprint()), key_filename)) + self._transport.auth_publickey(username, key) + return + except SSHException, e: + saved_exception = e + + if allow_agent: + for key in Agent().get_keys(): + try: + self._log(DEBUG, 'Trying SSH agent key %s' % hexlify(key.get_fingerprint())) + self._transport.auth_publickey(username, key) + return + except SSHException, e: + saved_exception = e + + keyfiles = [] + rsa_key = os.path.expanduser('~/.ssh/id_rsa') + dsa_key = os.path.expanduser('~/.ssh/id_dsa') + if os.path.isfile(rsa_key): + keyfiles.append((RSAKey, rsa_key)) + if os.path.isfile(dsa_key): + keyfiles.append((DSSKey, dsa_key)) + # look in ~/ssh/ for windows users: + rsa_key = os.path.expanduser('~/ssh/id_rsa') + dsa_key = os.path.expanduser('~/ssh/id_dsa') + if os.path.isfile(rsa_key): + keyfiles.append((RSAKey, rsa_key)) + if os.path.isfile(dsa_key): + keyfiles.append((DSSKey, dsa_key)) + + if not look_for_keys: + keyfiles = [] + + for pkey_class, filename in keyfiles: + try: + key = pkey_class.from_private_key_file(filename, password) + self._log(DEBUG, 'Trying discovered key %s in %s' % (hexlify(key.get_fingerprint()), filename)) + self._transport.auth_publickey(username, key) + return + except SSHException, e: + saved_exception = e + except IOError, e: + saved_exception = e + + if password is not None: + try: + self._transport.auth_password(username, password) + return + except SSHException, e: + saved_exception = e + + # if we got an auth-failed exception earlier, re-raise it + if saved_exception is not None: + raise saved_exception + raise SSHException('No authentication methods available') + + def _log(self, level, msg): + self._transport._log(level, msg) + diff --git a/paramiko/common.py b/paramiko/common.py index c5999e6..f4a4d81 100644 --- a/paramiko/common.py +++ b/paramiko/common.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -21,7 +21,7 @@ Common constants and global variables. """ MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, MSG_SERVICE_REQUEST, \ - MSG_SERVICE_ACCEPT = range(1, 7) + MSG_SERVICE_ACCEPT = range(1, 7) MSG_KEXINIT, MSG_NEWKEYS = range(20, 22) MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \ MSG_USERAUTH_BANNER = range(50, 54) @@ -29,9 +29,9 @@ MSG_USERAUTH_PK_OK = 60 MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62) MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83) MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \ - MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \ - MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ - MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) + MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \ + MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ + MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) # for debugging: @@ -95,21 +95,10 @@ CONNECTION_FAILED_CODE = { DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \ DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14 - -from Crypto.Util.randpool import PersistentRandomPool, RandomPool +from rng import StrongLockingRandomPool # keep a crypto-strong PRNG nearby -try: - randpool = PersistentRandomPool(os.path.join(os.path.expanduser('~'), '/.randpool')) -except: - # the above will likely fail on Windows - fall back to non-persistent random pool - randpool = RandomPool() - -try: - randpool.randomize() -except: - # earlier versions of pyCrypto (pre-2.0) don't have randomize() - pass +randpool = StrongLockingRandomPool() import sys if sys.version_info < (2, 3): @@ -129,6 +118,7 @@ else: import logging PY22 = False + DEBUG = logging.DEBUG INFO = logging.INFO WARNING = logging.WARNING diff --git a/paramiko/compress.py b/paramiko/compress.py index bdf4b42..08fffb1 100644 --- a/paramiko/compress.py +++ b/paramiko/compress.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # diff --git a/paramiko/config.py b/paramiko/config.py new file mode 100644 index 0000000..1e3d680 --- /dev/null +++ b/paramiko/config.py @@ -0,0 +1,105 @@ +# Copyright (C) 2006-2007 Robey Pointer <robey@lag.net> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{SSHConfig}. +""" + +import fnmatch + + +class SSHConfig (object): + """ + Representation of config information as stored in the format used by + OpenSSH. Queries can be made via L{lookup}. The format is described in + OpenSSH's C{ssh_config} man page. This class is provided primarily as a + convenience to posix users (since the OpenSSH format is a de-facto + standard on posix) but should work fine on Windows too. + + @since: 1.6 + """ + + def __init__(self): + """ + Create a new OpenSSH config object. + """ + self._config = [ { 'host': '*' } ] + + def parse(self, file_obj): + """ + Read an OpenSSH config from the given file object. + + @param file_obj: a file-like object to read the config file from + @type file_obj: file + """ + config = self._config[0] + for line in file_obj: + line = line.rstrip('\n').lstrip() + if (line == '') or (line[0] == '#'): + continue + if '=' in line: + key, value = line.split('=', 1) + key = key.strip().lower() + else: + # find first whitespace, and split there + i = 0 + while (i < len(line)) and not line[i].isspace(): + i += 1 + if i == len(line): + raise Exception('Unparsable line: %r' % line) + key = line[:i].lower() + value = line[i:].lstrip() + + if key == 'host': + # do we have a pre-existing host config to append to? + matches = [c for c in self._config if c['host'] == value] + if len(matches) > 0: + config = matches[0] + else: + config = { 'host': value } + self._config.append(config) + else: + config[key] = value + + def lookup(self, hostname): + """ + Return a dict of config options for a given hostname. + + The host-matching rules of OpenSSH's C{ssh_config} man page are used, + which means that all configuration options from matching host + specifications are merged, with more specific hostmasks taking + precedence. In other words, if C{"Port"} is set under C{"Host *"} + and also C{"Host *.example.com"}, and the lookup is for + C{"ssh.example.com"}, then the port entry for C{"Host *.example.com"} + will win out. + + The keys in the returned dict are all normalized to lowercase (look for + C{"port"}, not C{"Port"}. No other processing is done to the keys or + values. + + @param hostname: the hostname to lookup + @type hostname: str + """ + matches = [x for x in self._config if fnmatch.fnmatch(hostname, x['host'])] + # sort in order of shortest match (usually '*') to longest + matches.sort(lambda x,y: cmp(len(x['host']), len(y['host']))) + ret = {} + for m in matches: + ret.update(m) + del ret['host'] + return ret diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py index 2b31372..9f381d2 100644 --- a/paramiko/dsskey.py +++ b/paramiko/dsskey.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -37,7 +37,15 @@ class DSSKey (PKey): data. """ - def __init__(self, msg=None, data=None, filename=None, password=None, vals=None): + def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None): + self.p = None + self.q = None + self.g = None + self.y = None + self.x = None + if file_obj is not None: + self._from_private_key(file_obj, password) + return if filename is not None: self._from_private_key_file(filename, password) return @@ -81,7 +89,7 @@ class DSSKey (PKey): return self.size def can_sign(self): - return hasattr(self, 'x') + return self.x is not None def sign_ssh_data(self, rpool, data): digest = SHA.new(data).digest() @@ -123,14 +131,22 @@ class DSSKey (PKey): dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q))) return dss.verify(sigM, (sigR, sigS)) - def write_private_key_file(self, filename, password=None): + def _encode_key(self): + if self.x is None: + raise SSHException('Not enough key information') keylist = [ 0, self.p, self.q, self.g, self.y, self.x ] try: b = BER() b.encode(keylist) except BERException: raise SSHException('Unable to create ber encoding of key') - self._write_private_key_file('DSA', filename, str(b), password) + return str(b) + + def write_private_key_file(self, filename, password=None): + self._write_private_key_file('DSA', filename, self._encode_key(), password) + + def write_private_key(self, file_obj, password=None): + self._write_private_key('DSA', file_obj, self._encode_key(), password) def generate(bits=1024, progress_func=None): """ @@ -144,8 +160,6 @@ class DSSKey (PKey): @type progress_func: function @return: new private key @rtype: L{DSSKey} - - @since: fearow """ randpool.stir() dsa = DSA.generate(bits, randpool.get_bytes, progress_func) @@ -159,9 +173,16 @@ class DSSKey (PKey): def _from_private_key_file(self, filename, password): + data = self._read_private_key_file('DSA', filename, password) + self._decode_key(data) + + def _from_private_key(self, file_obj, password): + data = self._read_private_key('DSA', file_obj, password) + self._decode_key(data) + + def _decode_key(self, data): # private key file contains: # DSAPrivateKey = { version = 0, p, q, g, y, x } - data = self._read_private_key_file('DSA', filename, password) try: keylist = BER(data).decode() except BERException, x: diff --git a/paramiko/file.py b/paramiko/file.py index c29e7c4..7db4401 100644 --- a/paramiko/file.py +++ b/paramiko/file.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -23,15 +23,6 @@ BufferedFile. from cStringIO import StringIO -_FLAG_READ = 0x1 -_FLAG_WRITE = 0x2 -_FLAG_APPEND = 0x4 -_FLAG_BINARY = 0x10 -_FLAG_BUFFERED = 0x20 -_FLAG_LINE_BUFFERED = 0x40 -_FLAG_UNIVERSAL_NEWLINE = 0x80 - - class BufferedFile (object): """ Reusable base class to implement python-style file buffering around a @@ -44,7 +35,16 @@ class BufferedFile (object): SEEK_CUR = 1 SEEK_END = 2 + FLAG_READ = 0x1 + FLAG_WRITE = 0x2 + FLAG_APPEND = 0x4 + FLAG_BINARY = 0x10 + FLAG_BUFFERED = 0x20 + FLAG_LINE_BUFFERED = 0x40 + FLAG_UNIVERSAL_NEWLINE = 0x80 + def __init__(self): + self.newlines = None self._flags = 0 self._bufsize = self._DEFAULT_BUFSIZE self._wbuffer = StringIO() @@ -55,6 +55,8 @@ class BufferedFile (object): # realpos - position according the OS # (these may be different because we buffer for line reading) self._pos = self._realpos = 0 + # size only matters for seekable files + self._size = 0 def __del__(self): self.close() @@ -112,16 +114,16 @@ class BufferedFile (object): file first). If the C{size} argument is negative or omitted, read all the remaining data in the file. - @param size: maximum number of bytes to read. + @param size: maximum number of bytes to read @type size: int @return: data read from the file, or an empty string if EOF was - encountered immediately. + encountered immediately @rtype: str """ if self._closed: raise IOError('File is closed') - if not (self._flags & _FLAG_READ): - raise IOError('File not open for reading') + if not (self._flags & self.FLAG_READ): + raise IOError('File is not open for reading') if (size is None) or (size < 0): # go for broke result = self._rbuffer @@ -144,8 +146,11 @@ class BufferedFile (object): self._pos += len(result) return result while len(self._rbuffer) < size: + read_size = size - len(self._rbuffer) + if self._flags & self.FLAG_BUFFERED: + read_size = max(self._bufsize, read_size) try: - new_data = self._read(max(self._bufsize, size - len(self._rbuffer))) + new_data = self._read(read_size) except EOFError: new_data = None if (new_data is None) or (len(new_data) == 0): @@ -178,11 +183,11 @@ class BufferedFile (object): # it's almost silly how complex this function is. if self._closed: raise IOError('File is closed') - if not (self._flags & _FLAG_READ): + if not (self._flags & self.FLAG_READ): raise IOError('File not open for reading') line = self._rbuffer while True: - if self._at_trailing_cr and (self._flags & _FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0): + if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0): # edge case: the newline may be '\r\n' and we may have read # only the first '\r' last time. if line[0] == '\n': @@ -202,8 +207,8 @@ class BufferedFile (object): return line n = size - len(line) else: - n = self._DEFAULT_BUFSIZE - if ('\n' in line) or ((self._flags & _FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)): + n = self._bufsize + if ('\n' in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)): break try: new_data = self._read(n) @@ -217,7 +222,7 @@ class BufferedFile (object): self._realpos += len(new_data) # find the newline pos = line.find('\n') - if self._flags & _FLAG_UNIVERSAL_NEWLINE: + if self._flags & self.FLAG_UNIVERSAL_NEWLINE: rpos = line.find('\r') if (rpos >= 0) and ((rpos < pos) or (pos < 0)): pos = rpos @@ -250,7 +255,7 @@ class BufferedFile (object): """ lines = [] bytes = 0 - while 1: + while True: line = self.readline() if len(line) == 0: break @@ -303,13 +308,13 @@ class BufferedFile (object): """ if self._closed: raise IOError('File is closed') - if not (self._flags & _FLAG_WRITE): + if not (self._flags & self.FLAG_WRITE): raise IOError('File not open for writing') - if not (self._flags & _FLAG_BUFFERED): + if not (self._flags & self.FLAG_BUFFERED): self._write_all(data) return self._wbuffer.write(data) - if self._flags & _FLAG_LINE_BUFFERED: + if self._flags & self.FLAG_LINE_BUFFERED: # only scan the new data for linefeed, to avoid wasting time. last_newline_pos = data.rfind('\n') if last_newline_pos >= 0: @@ -387,26 +392,37 @@ class BufferedFile (object): """ Subclasses call this method to initialize the BufferedFile. """ + # set bufsize in any event, because it's used for readline(). + self._bufsize = self._DEFAULT_BUFSIZE + if bufsize < 0: + # do no buffering by default, because otherwise writes will get + # buffered in a way that will probably confuse people. + bufsize = 0 if bufsize == 1: # apparently, line buffering only affects writes. reads are only # buffered if you call readline (directly or indirectly: iterating # over a file will indirectly call readline). - self._flags |= _FLAG_BUFFERED | _FLAG_LINE_BUFFERED + self._flags |= self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED elif bufsize > 1: self._bufsize = bufsize - self._flags |= _FLAG_BUFFERED + self._flags |= self.FLAG_BUFFERED + self._flags &= ~self.FLAG_LINE_BUFFERED + elif bufsize == 0: + # unbuffered + self._flags &= ~(self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED) + if ('r' in mode) or ('+' in mode): - self._flags |= _FLAG_READ + self._flags |= self.FLAG_READ if ('w' in mode) or ('+' in mode): - self._flags |= _FLAG_WRITE + self._flags |= self.FLAG_WRITE if ('a' in mode): - self._flags |= _FLAG_WRITE | _FLAG_APPEND + self._flags |= self.FLAG_WRITE | self.FLAG_APPEND self._size = self._get_size() self._pos = self._realpos = self._size if ('b' in mode): - self._flags |= _FLAG_BINARY + self._flags |= self.FLAG_BINARY if ('U' in mode): - self._flags |= _FLAG_UNIVERSAL_NEWLINE + self._flags |= self.FLAG_UNIVERSAL_NEWLINE # built-in file objects have this attribute to store which kinds of # line terminations they've seen: # <http://www.python.org/doc/current/lib/built-in-funcs.html> @@ -418,7 +434,7 @@ class BufferedFile (object): while len(data) > 0: count = self._write(data) data = data[count:] - if self._flags & _FLAG_APPEND: + if self._flags & self.FLAG_APPEND: self._size += count self._pos = self._realpos = self._size else: @@ -430,7 +446,7 @@ class BufferedFile (object): # silliness about tracking what kinds of newlines we've seen. # i don't understand why it can be None, a string, or a tuple, instead # of just always being a tuple, but we'll emulate that behavior anyway. - if not (self._flags & _FLAG_UNIVERSAL_NEWLINE): + if not (self._flags & self.FLAG_UNIVERSAL_NEWLINE): return if self.newlines is None: self.newlines = newline diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py new file mode 100644 index 0000000..0c0ac8c --- /dev/null +++ b/paramiko/hostkeys.py @@ -0,0 +1,315 @@ +# Copyright (C) 2006-2007 Robey Pointer <robey@lag.net> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{HostKeys} +""" + +import base64 +from Crypto.Hash import SHA, HMAC +import UserDict + +from paramiko.common import * +from paramiko.dsskey import DSSKey +from paramiko.rsakey import RSAKey + + +class HostKeyEntry: + """ + Representation of a line in an OpenSSH-style "known hosts" file. + """ + + def __init__(self, hostnames=None, key=None): + self.valid = (hostnames is not None) and (key is not None) + self.hostnames = hostnames + self.key = key + + def from_line(cls, line): + """ + Parses the given line of text to find the names for the host, + the type of key, and the key data. The line is expected to be in the + format used by the openssh known_hosts file. + + Lines are expected to not have leading or trailing whitespace. + We don't bother to check for comments or empty lines. All of + that should be taken care of before sending the line to us. + + @param line: a line from an OpenSSH known_hosts file + @type line: str + """ + fields = line.split(' ') + if len(fields) != 3: + # Bad number of fields + return None + + names, keytype, key = fields + names = names.split(',') + + # Decide what kind of key we're looking at and create an object + # to hold it accordingly. + if keytype == 'ssh-rsa': + key = RSAKey(data=base64.decodestring(key)) + elif keytype == 'ssh-dss': + key = DSSKey(data=base64.decodestring(key)) + else: + return None + + return cls(names, key) + from_line = classmethod(from_line) + + def to_line(self): + """ + Returns a string in OpenSSH known_hosts file format, or None if + the object is not in a valid state. A trailing newline is + included. + """ + if self.valid: + return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(), + self.key.get_base64()) + return None + + def __repr__(self): + return '<HostKeyEntry %r: %r>' % (self.hostnames, self.key) + + +class HostKeys (UserDict.DictMixin): + """ + Representation of an openssh-style "known hosts" file. Host keys can be + read from one or more files, and then individual hosts can be looked up to + verify server keys during SSH negotiation. + + A HostKeys object can be treated like a dict; any dict lookup is equivalent + to calling L{lookup}. + + @since: 1.5.3 + """ + + def __init__(self, filename=None): + """ + Create a new HostKeys object, optionally loading keys from an openssh + style host-key file. + + @param filename: filename to load host keys from, or C{None} + @type filename: str + """ + # emulate a dict of { hostname: { keytype: PKey } } + self._entries = [] + if filename is not None: + self.load(filename) + + def add(self, hostname, keytype, key): + """ + Add a host key entry to the table. Any existing entry for a + C{(hostname, keytype)} pair will be replaced. + + @param hostname: the hostname (or IP) to add + @type hostname: str + @param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"}) + @type keytype: str + @param key: the key to add + @type key: L{PKey} + """ + for e in self._entries: + if (hostname in e.hostnames) and (e.key.get_name() == keytype): + e.key = key + return + self._entries.append(HostKeyEntry([hostname], key)) + + def load(self, filename): + """ + Read a file of known SSH host keys, in the format used by openssh. + This type of file unfortunately doesn't exist on Windows, but on + posix, it will usually be stored in + C{os.path.expanduser("~/.ssh/known_hosts")}. + + If this method is called multiple times, the host keys are merged, + not cleared. So multiple calls to C{load} will just call L{add}, + replacing any existing entries and adding new ones. + + @param filename: name of the file to read host keys from + @type filename: str + + @raise IOError: if there was an error reading the file + """ + f = open(filename, 'r') + for line in f: + line = line.strip() + if (len(line) == 0) or (line[0] == '#'): + continue + e = HostKeyEntry.from_line(line) + if e is not None: + self._entries.append(e) + f.close() + + def save(self, filename): + """ + Save host keys into a file, in the format used by openssh. The order of + keys in the file will be preserved when possible (if these keys were + loaded from a file originally). The single exception is that combined + lines will be split into individual key lines, which is arguably a bug. + + @param filename: name of the file to write + @type filename: str + + @raise IOError: if there was an error writing the file + + @since: 1.6.1 + """ + f = open(filename, 'w') + for e in self._entries: + line = e.to_line() + if line: + f.write(line) + f.close() + + def lookup(self, hostname): + """ + Find a hostkey entry for a given hostname or IP. If no entry is found, + C{None} is returned. Otherwise a dictionary of keytype to key is + returned. The keytype will be either C{"ssh-rsa"} or C{"ssh-dss"}. + + @param hostname: the hostname (or IP) to lookup + @type hostname: str + @return: keys associated with this host (or C{None}) + @rtype: dict(str, L{PKey}) + """ + class SubDict (UserDict.DictMixin): + def __init__(self, hostname, entries, hostkeys): + self._hostname = hostname + self._entries = entries + self._hostkeys = hostkeys + + def __getitem__(self, key): + for e in self._entries: + if e.key.get_name() == key: + return e.key + raise KeyError(key) + + def __setitem__(self, key, val): + for e in self._entries: + if e.key is None: + continue + if e.key.get_name() == key: + # replace + e.key = val + break + else: + # add a new one + e = HostKeyEntry([hostname], val) + self._entries.append(e) + self._hostkeys._entries.append(e) + + def keys(self): + return [e.key.get_name() for e in self._entries if e.key is not None] + + entries = [] + for e in self._entries: + for h in e.hostnames: + if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname): + entries.append(e) + if len(entries) == 0: + return None + return SubDict(hostname, entries, self) + + def check(self, hostname, key): + """ + Return True if the given key is associated with the given hostname + in this dictionary. + + @param hostname: hostname (or IP) of the SSH server + @type hostname: str + @param key: the key to check + @type key: L{PKey} + @return: C{True} if the key is associated with the hostname; C{False} + if not + @rtype: bool + """ + k = self.lookup(hostname) + if k is None: + return False + host_key = k.get(key.get_name(), None) + if host_key is None: + return False + return str(host_key) == str(key) + + def clear(self): + """ + Remove all host keys from the dictionary. + """ + self._entries = [] + + def __getitem__(self, key): + ret = self.lookup(key) + if ret is None: + raise KeyError(key) + return ret + + def __setitem__(self, hostname, entry): + # don't use this please. + if len(entry) == 0: + self._entries.append(HostKeyEntry([hostname], None)) + return + for key_type in entry.keys(): + found = False + for e in self._entries: + if (hostname in e.hostnames) and (e.key.get_name() == key_type): + # replace + e.key = entry[key_type] + found = True + if not found: + self._entries.append(HostKeyEntry([hostname], entry[key_type])) + + def keys(self): + # python 2.4 sets would be nice here. + ret = [] + for e in self._entries: + for h in e.hostnames: + if h not in ret: + ret.append(h) + return ret + + def values(self): + ret = [] + for k in self.keys(): + ret.append(self.lookup(k)) + return ret + + def hash_host(hostname, salt=None): + """ + Return a "hashed" form of the hostname, as used by openssh when storing + hashed hostnames in the known_hosts file. + + @param hostname: the hostname to hash + @type hostname: str + @param salt: optional salt to use when hashing (must be 20 bytes long) + @type salt: str + @return: the hashed hostname + @rtype: str + """ + if salt is None: + salt = randpool.get_bytes(SHA.digest_size) + else: + if salt.startswith('|1|'): + salt = salt.split('|')[2] + salt = base64.decodestring(salt) + assert len(salt) == SHA.digest_size + hmac = HMAC.HMAC(salt, hostname, SHA).digest() + hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac)) + return hostkey.replace('\n', '') + hash_host = staticmethod(hash_host) + diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py index 994d76c..63a0c99 100644 --- a/paramiko/kex_gex.py +++ b/paramiko/kex_gex.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -31,7 +31,8 @@ from paramiko.message import Message from paramiko.ssh_exception import SSHException -_MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(31, 35) +_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \ + _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35) class KexGex (object): @@ -43,19 +44,32 @@ class KexGex (object): def __init__(self, transport): self.transport = transport - - def start_kex(self): + self.p = None + self.q = None + self.g = None + self.x = None + self.e = None + self.f = None + self.old_style = False + + def start_kex(self, _test_old_style=False): if self.transport.server_mode: - self.transport._expect_packet(_MSG_KEXDH_GEX_REQUEST) + self.transport._expect_packet(_MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD) return # request a bit range: we accept (min_bits) to (max_bits), but prefer # (preferred_bits). according to the spec, we shouldn't pull the # minimum up above 1024. m = Message() - m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST)) - m.add_int(self.min_bits) - m.add_int(self.preferred_bits) - m.add_int(self.max_bits) + if _test_old_style: + # only used for unit tests: we shouldn't ever send this + m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST_OLD)) + m.add_int(self.preferred_bits) + self.old_style = True + else: + m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST)) + m.add_int(self.min_bits) + m.add_int(self.preferred_bits) + m.add_int(self.max_bits) self.transport._send_message(m) self.transport._expect_packet(_MSG_KEXDH_GEX_GROUP) @@ -68,6 +82,8 @@ class KexGex (object): return self._parse_kexdh_gex_init(m) elif ptype == _MSG_KEXDH_GEX_REPLY: return self._parse_kexdh_gex_reply(m) + elif ptype == _MSG_KEXDH_GEX_REQUEST_OLD: + return self._parse_kexdh_gex_request_old(m) raise SSHException('KexGex asked to handle packet type %d' % ptype) @@ -126,6 +142,28 @@ class KexGex (object): self.transport._send_message(m) self.transport._expect_packet(_MSG_KEXDH_GEX_INIT) + def _parse_kexdh_gex_request_old(self, m): + # same as above, but without min_bits or max_bits (used by older clients like putty) + self.preferred_bits = m.get_int() + # smoosh the user's preferred size into our own limits + if self.preferred_bits > self.max_bits: + self.preferred_bits = self.max_bits + if self.preferred_bits < self.min_bits: + self.preferred_bits = self.min_bits + # generate prime + pack = self.transport._get_modulus_pack() + if pack is None: + raise SSHException('Can\'t do server-side gex with no modulus pack') + self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,)) + self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits) + m = Message() + m.add_byte(chr(_MSG_KEXDH_GEX_GROUP)) + m.add_mpint(self.p) + m.add_mpint(self.g) + self.transport._send_message(m) + self.transport._expect_packet(_MSG_KEXDH_GEX_INIT) + self.old_style = True + def _parse_kexdh_gex_group(self, m): self.p = m.get_mpint() self.g = m.get_mpint() @@ -156,9 +194,11 @@ class KexGex (object): hm.add(self.transport.remote_version, self.transport.local_version, self.transport.remote_kex_init, self.transport.local_kex_init, key) - hm.add_int(self.min_bits) + if not self.old_style: + hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) - hm.add_int(self.max_bits) + if not self.old_style: + hm.add_int(self.max_bits) hm.add_mpint(self.p) hm.add_mpint(self.g) hm.add_mpint(self.e) @@ -189,9 +229,11 @@ class KexGex (object): hm.add(self.transport.local_version, self.transport.remote_version, self.transport.local_kex_init, self.transport.remote_kex_init, host_key) - hm.add_int(self.min_bits) + if not self.old_style: + hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) - hm.add_int(self.max_bits) + if not self.old_style: + hm.add_int(self.max_bits) hm.add_mpint(self.p) hm.add_mpint(self.g) hm.add_mpint(self.e) diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py index a13cf3a..843a6d8 100644 --- a/paramiko/kex_group1.py +++ b/paramiko/kex_group1.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # diff --git a/paramiko/logging22.py b/paramiko/logging22.py index ac11a73..9bf7656 100644 --- a/paramiko/logging22.py +++ b/paramiko/logging22.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # diff --git a/paramiko/message.py b/paramiko/message.py index 1d75a01..1a5151c 100644 --- a/paramiko/message.py +++ b/paramiko/message.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -285,7 +285,7 @@ class Message (object): elif type(i) is list: return self.add_list(i) else: - raise exception('Unknown type') + raise Exception('Unknown type') def add(self, *seq): """ diff --git a/paramiko/packet.py b/paramiko/packet.py index 277d68e..4bde2f7 100644 --- a/paramiko/packet.py +++ b/paramiko/packet.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -20,12 +20,12 @@ Packetizer. """ +import errno import select import socket import struct import threading import time -from Crypto.Hash import HMAC from paramiko.common import * from paramiko import util @@ -33,6 +33,19 @@ from paramiko.ssh_exception import SSHException from paramiko.message import Message +got_r_hmac = False +try: + import r_hmac + got_r_hmac = True +except ImportError: + pass +def compute_hmac(key, message, digest_class): + if got_r_hmac: + return r_hmac.HMAC(key, message, digest_class).digest() + from Crypto.Hash import HMAC + return HMAC.HMAC(key, message, digest_class).digest() + + class NeedRekeyException (Exception): pass @@ -54,6 +67,7 @@ class Packetizer (object): self.__dump_packets = False self.__need_rekey = False self.__init_count = 0 + self.__remainder = '' # used for noticing when to re-key: self.__sent_bytes = 0 @@ -86,13 +100,6 @@ class Packetizer (object): self.__keepalive_last = time.time() self.__keepalive_callback = None - def __del__(self): - # this is not guaranteed to be called, but we should try. - try: - self.__socket.close() - except: - pass - def set_log(self, log): """ Set the python log object to use for logging. @@ -142,6 +149,7 @@ class Packetizer (object): def close(self): self.__closed = True + self.__socket.close() def set_hexdump(self, hexdump): self.__dump_packets = hexdump @@ -186,10 +194,16 @@ class Packetizer (object): @raise EOFError: if the socket was closed before all the bytes could be read """ - if PY22: - return self._py22_read_all(n) out = '' + # handle over-reading from reading the banner line + if len(self.__remainder) > 0: + out = self.__remainder[:n] + self.__remainder = self.__remainder[n:] + n -= len(out) + if PY22: + return self._py22_read_all(n, out) while n > 0: + got_timeout = False try: x = self.__socket.recv(n) if len(x) == 0: @@ -197,6 +211,21 @@ class Packetizer (object): out += x n -= len(x) except socket.timeout: + got_timeout = True + except socket.error, e: + # on Linux, sometimes instead of socket.timeout, we get + # EAGAIN. this is a bug in recent (> 2.6.9) kernels but + # we need to work around it. + if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN): + got_timeout = True + elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR): + # syscall interrupted; try again + pass + elif self.__closed: + raise EOFError() + else: + raise + if got_timeout: if self.__closed: raise EOFError() if check_rekey and (len(out) == 0) and self.__need_rekey: @@ -207,32 +236,44 @@ class Packetizer (object): def write_all(self, out): self.__keepalive_last = time.time() while len(out) > 0: + got_timeout = False try: n = self.__socket.send(out) except socket.timeout: - n = 0 - if self.__closed: + got_timeout = True + except socket.error, e: + if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN): + got_timeout = True + elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR): + # syscall interrupted; try again + pass + else: n = -1 except Exception: # could be: (32, 'Broken pipe') n = -1 + if got_timeout: + n = 0 + if self.__closed: + n = -1 if n < 0: raise EOFError() if n == len(out): - return + break out = out[n:] return def readline(self, timeout): """ - Read a line from the socket. This is done in a fairly inefficient - way, but is only used for initial banner negotiation so it's not worth - optimising. + Read a line from the socket. We assume no data is pending after the + line, so it's okay to attempt large reads. """ buf = '' while not '\n' in buf: buf += self._read_timeout(timeout) - buf = buf[:-1] + n = buf.index('\n') + self.__remainder += buf[n+1:] + buf = buf[:n] if (len(buf) > 0) and (buf[-1] == '\r'): buf = buf[:-1] return buf @@ -242,21 +283,21 @@ class Packetizer (object): Write a block of data using the current cipher, as an SSH block. """ # encrypt this sucka - randpool.stir() data = str(data) cmd = ord(data[0]) if cmd in MSG_NAMES: cmd_name = MSG_NAMES[cmd] else: cmd_name = '$%x' % cmd - self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data))) - if self.__compress_engine_out is not None: - data = self.__compress_engine_out(data) - packet = self._build_packet(data) - if self.__dump_packets: - self._log(DEBUG, util.format_binary(packet, 'OUT: ')) + orig_len = len(data) self.__write_lock.acquire() try: + if self.__compress_engine_out is not None: + data = self.__compress_engine_out(data) + packet = self._build_packet(data) + if self.__dump_packets: + self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, orig_len)) + self._log(DEBUG, util.format_binary(packet, 'OUT: ')) if self.__block_engine_out != None: out = self.__block_engine_out.encrypt(packet) else: @@ -264,12 +305,15 @@ class Packetizer (object): # + mac if self.__block_engine_out != None: payload = struct.pack('>I', self.__sequence_number_out) + packet - out += HMAC.HMAC(self.__mac_key_out, payload, self.__mac_engine_out).digest()[:self.__mac_size_out] + out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out] self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL self.write_all(out) self.__sent_bytes += len(out) self.__sent_packets += 1 + if (self.__sent_packets % 100) == 0: + # stirring the randpool takes 30ms on my ibook!! + randpool.stir() if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \ and not self.__need_rekey: # only ask once for rekeying @@ -310,12 +354,12 @@ class Packetizer (object): if self.__mac_size_in > 0: mac = post_packet[:self.__mac_size_in] mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet - my_mac = HMAC.HMAC(self.__mac_key_in, mac_payload, self.__mac_engine_in).digest()[:self.__mac_size_in] + my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in] if my_mac != mac: raise SSHException('Mismatched MAC') padding = ord(packet[0]) payload = packet[1:packet_size - padding] - randpool.add_event(packet[packet_size - padding]) + randpool.add_event() if self.__dump_packets: self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) @@ -348,7 +392,8 @@ class Packetizer (object): cmd_name = MSG_NAMES[cmd] else: cmd_name = '$%x' % cmd - self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload))) + if self.__dump_packets: + self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload))) return cmd, msg @@ -374,8 +419,7 @@ class Packetizer (object): self.__keepalive_callback() self.__keepalive_last = now - def _py22_read_all(self, n): - out = '' + def _py22_read_all(self, n, out): while n > 0: r, w, e = select.select([self.__socket], [], [], 0.1) if self.__socket not in r: @@ -398,23 +442,24 @@ class Packetizer (object): x = self.__socket.recv(1) if len(x) == 0: raise EOFError() - return x + break if self.__closed: raise EOFError() now = time.time() if now - start >= timeout: raise socket.timeout() + return x def _read_timeout(self, timeout): if PY22: - return self._py22_read_timeout(n) + return self._py22_read_timeout(timeout) start = time.time() while True: try: - x = self.__socket.recv(1) + x = self.__socket.recv(128) if len(x) == 0: raise EOFError() - return x + break except socket.timeout: pass if self.__closed: @@ -422,6 +467,7 @@ class Packetizer (object): now = time.time() if now - start >= timeout: raise socket.timeout() + return x def _build_packet(self, payload): # pad up at least 4 bytes, to nearest block-size (usually 8) diff --git a/paramiko/pipe.py b/paramiko/pipe.py index cc28f43..1cfed2d 100644 --- a/paramiko/pipe.py +++ b/paramiko/pipe.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -19,6 +19,9 @@ """ Abstraction of a one-way pipe where the read end can be used in select(). Normally this is trivial, but Windows makes it nearly impossible. + +The pipe acts like an Event, which can be set or cleared. When set, the pipe +will trigger as readable in select(). """ import sys @@ -28,8 +31,10 @@ import socket def make_pipe (): if sys.platform[:3] != 'win': - return PosixPipe() - return WindowsPipe() + p = PosixPipe() + else: + p = WindowsPipe() + return p class PosixPipe (object): @@ -37,10 +42,13 @@ class PosixPipe (object): self._rfd, self._wfd = os.pipe() self._set = False self._forever = False + self._closed = False def close (self): os.close(self._rfd) os.close(self._wfd) + # used for unit tests: + self._closed = True def fileno (self): return self._rfd @@ -52,7 +60,7 @@ class PosixPipe (object): self._set = False def set (self): - if self._set: + if self._set or self._closed: return self._set = True os.write(self._wfd, '*') @@ -80,10 +88,13 @@ class WindowsPipe (object): serv.close() self._set = False self._forever = False + self._closed = False def close (self): self._rsock.close() self._wsock.close() + # used for unit tests: + self._closed = True def fileno (self): return self._rsock.fileno() @@ -95,7 +106,7 @@ class WindowsPipe (object): self._set = False def set (self): - if self._set: + if self._set or self._closed: return self._set = True self._wsock.send('*') @@ -103,3 +114,34 @@ class WindowsPipe (object): def set_forever (self): self._forever = True self.set() + + +class OrPipe (object): + def __init__(self, pipe): + self._set = False + self._partner = None + self._pipe = pipe + + def set(self): + self._set = True + if not self._partner._set: + self._pipe.set() + + def clear(self): + self._set = False + if not self._partner._set: + self._pipe.clear() + + +def make_or_pipe(pipe): + """ + wraps a pipe into two pipe-like objects which are "or"d together to + affect the real pipe. if either returned pipe is set, the wrapped pipe + is set. when both are cleared, the wrapped pipe is cleared. + """ + p1 = OrPipe(pipe) + p2 = OrPipe(pipe) + p1._partner = p2 + p2._partner = p1 + return p1, p2 + diff --git a/paramiko/pkey.py b/paramiko/pkey.py index 75db8e5..4e8b26b 100644 --- a/paramiko/pkey.py +++ b/paramiko/pkey.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -21,6 +21,7 @@ Common API for all public keys. """ import base64 +from binascii import hexlify, unhexlify import os from Crypto.Hash import MD5 @@ -139,8 +140,6 @@ class PKey (object): @return: a base64 string containing the public part of the key. @rtype: str - - @since: fearow """ return base64.encodestring(str(self)).replace('\n', '') @@ -173,7 +172,7 @@ class PKey (object): """ return False - def from_private_key_file(cl, filename, password=None): + def from_private_key_file(cls, filename, password=None): """ Create a key object by reading a private key file. If the private key is encrypted and C{password} is not C{None}, the given password @@ -182,41 +181,76 @@ class PKey (object): exist in all subclasses of PKey (such as L{RSAKey} or L{DSSKey}), but is useless on the abstract PKey class. - @param filename: name of the file to read. + @param filename: name of the file to read @type filename: str @param password: an optional password to use to decrypt the key file, if it's encrypted @type password: str - @return: a new key object based on the given private key. + @return: a new key object based on the given private key @rtype: L{PKey} - @raise IOError: if there was an error reading the file. + @raise IOError: if there was an error reading the file @raise PasswordRequiredException: if the private key file is - encrypted, and C{password} is C{None}. - @raise SSHException: if the key file is invalid. - - @since: fearow + encrypted, and C{password} is C{None} + @raise SSHException: if the key file is invalid """ - key = cl(filename=filename, password=password) + key = cls(filename=filename, password=password) return key from_private_key_file = classmethod(from_private_key_file) + def from_private_key(cls, file_obj, password=None): + """ + Create a key object by reading a private key from a file (or file-like) + object. If the private key is encrypted and C{password} is not C{None}, + the given password will be used to decrypt the key (otherwise + L{PasswordRequiredException} is thrown). + + @param file_obj: the file to read from + @type file_obj: file + @param password: an optional password to use to decrypt the key, if it's + encrypted + @type password: str + @return: a new key object based on the given private key + @rtype: L{PKey} + + @raise IOError: if there was an error reading the key + @raise PasswordRequiredException: if the private key file is encrypted, + and C{password} is C{None} + @raise SSHException: if the key file is invalid + """ + key = cls(file_obj=file_obj, password=password) + return key + from_private_key = classmethod(from_private_key) + def write_private_key_file(self, filename, password=None): """ Write private key contents into a file. If the password is not C{None}, the key is encrypted before writing. - @param filename: name of the file to write. + @param filename: name of the file to write @type filename: str - @param password: an optional password to use to encrypt the key file. + @param password: an optional password to use to encrypt the key file @type password: str - @raise IOError: if there was an error writing the file. - @raise SSHException: if the key is invalid. - - @since: fearow + @raise IOError: if there was an error writing the file + @raise SSHException: if the key is invalid + """ + raise Exception('Not implemented in PKey') + + def write_private_key(self, file_obj, password=None): """ - raise exception('Not implemented in PKey') + Write private key contents into a file (or file-like) object. If the + password is not C{None}, the key is encrypted before writing. + + @param file_obj: the file object to write into + @type file_obj: file + @param password: an optional password to use to encrypt the key + @type password: str + + @raise IOError: if there was an error writing to the file + @raise SSHException: if the key is invalid + """ + raise Exception('Not implemented in PKey') def _read_private_key_file(self, tag, filename, password=None): """ @@ -242,8 +276,12 @@ class PKey (object): @raise SSHException: if the key file is invalid. """ f = open(filename, 'r') - lines = f.readlines() + data = self._read_private_key(tag, f, password) f.close() + return data + + def _read_private_key(self, tag, f, password=None): + lines = f.readlines() start = 0 while (start < len(lines)) and (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----'): start += 1 @@ -265,9 +303,9 @@ class PKey (object): # if we trudged to the end of the file, just try to cope. try: data = base64.decodestring(''.join(lines[start:end])) - except binascii.Error, e: + except base64.binascii.Error, e: raise SSHException('base64 decoding error: ' + str(e)) - if not headers.has_key('proc-type'): + if 'proc-type' not in headers: # unencryped: done return data # encrypted keyfile: will need a password @@ -277,7 +315,7 @@ class PKey (object): encryption_type, saltstr = headers['dek-info'].split(',') except: raise SSHException('Can\'t parse DEK-info in private key file') - if not self._CIPHER_TABLE.has_key(encryption_type): + if encryption_type not in self._CIPHER_TABLE: raise SSHException('Unknown private key cipher "%s"' % encryption_type) # if no password was passed in, raise an exception pointing out that we need one if password is None: @@ -285,7 +323,7 @@ class PKey (object): cipher = self._CIPHER_TABLE[encryption_type]['cipher'] keysize = self._CIPHER_TABLE[encryption_type]['keysize'] mode = self._CIPHER_TABLE[encryption_type]['mode'] - salt = util.unhexify(saltstr) + salt = unhexlify(saltstr) key = util.generate_key_bytes(MD5, salt, password, keysize) return cipher.new(key, mode, salt).decrypt(data) @@ -310,6 +348,10 @@ class PKey (object): f = open(filename, 'w', 0600) # grrr... the mode doesn't always take hold os.chmod(filename, 0600) + self._write_private_key(tag, f, data, password) + f.close() + + def _write_private_key(self, tag, f, data, password=None): f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag) if password is not None: # since we only support one cipher here, use it @@ -327,7 +369,7 @@ class PKey (object): data += '\0' * n data = cipher.new(key, mode, salt).encrypt(data) f.write('Proc-Type: 4,ENCRYPTED\n') - f.write('DEK-Info: %s,%s\n' % (cipher_name, util.hexify(salt))) + f.write('DEK-Info: %s,%s\n' % (cipher_name, hexlify(salt).upper())) f.write('\n') s = base64.encodestring(data) # re-wrap to 64-char lines @@ -336,4 +378,3 @@ class PKey (object): f.write(s) f.write('\n') f.write('-----END %s PRIVATE KEY-----\n' % tag) - f.close() diff --git a/paramiko/primes.py b/paramiko/primes.py index 3677394..7b35736 100644 --- a/paramiko/primes.py +++ b/paramiko/primes.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -23,6 +23,7 @@ Utility functions for dealing with primes. from Crypto.Util import number from paramiko import util +from paramiko.ssh_exception import SSHException def _generate_prime(bits, randpool): @@ -39,7 +40,8 @@ def _generate_prime(bits, randpool): while not number.isPrime(n): n += 2 if util.bit_length(n) == bits: - return n + break + return n def _roll_random(rpool, n): "returns a random # from 0 to N-1" @@ -59,7 +61,8 @@ def _roll_random(rpool, n): x = chr(ord(x[0]) & hbyte_mask) + x[1:] num = util.inflate_long(x, 1) if num < n: - return num + break + return num class ModulusPack (object): @@ -75,8 +78,8 @@ class ModulusPack (object): self.randpool = rpool def _parse_modulus(self, line): - timestamp, type, tests, tries, size, generator, modulus = line.split() - type = int(type) + timestamp, mod_type, tests, tries, size, generator, modulus = line.split() + mod_type = int(mod_type) tests = int(tests) tries = int(tries) size = int(size) @@ -87,7 +90,7 @@ class ModulusPack (object): # type 2 (meets basic structural requirements) # test 4 (more than just a small-prime sieve) # tries < 100 if test & 4 (at least 100 tries of miller-rabin) - if (type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)): + if (mod_type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)): self.discarded.append((modulus, 'does not meet basic requirements')) return if generator == 0: @@ -100,7 +103,7 @@ class ModulusPack (object): if (bl != size) and (bl != size + 1): self.discarded.append((modulus, 'incorrectly reported bit length %d' % size)) return - if not self.pack.has_key(bl): + if bl not in self.pack: self.pack[bl] = [] self.pack[bl].append((generator, modulus)) diff --git a/paramiko/resource.py b/paramiko/resource.py new file mode 100644 index 0000000..a089754 --- /dev/null +++ b/paramiko/resource.py @@ -0,0 +1,72 @@ +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Resource manager. +""" + +import weakref + + +class ResourceManager (object): + """ + A registry of objects and resources that should be closed when those + objects are deleted. + + This is meant to be a safer alternative to python's C{__del__} method, + which can cause reference cycles to never be collected. Objects registered + with the ResourceManager can be collected but still free resources when + they die. + + Resources are registered using L{register}, and when an object is garbage + collected, each registered resource is closed by having its C{close()} + method called. Multiple resources may be registered per object, but a + resource will only be closed once, even if multiple objects register it. + (The last object to register it wins.) + """ + + def __init__(self): + self._table = {} + + def register(self, obj, resource): + """ + Register a resource to be closed with an object is collected. + + When the given C{obj} is garbage-collected by the python interpreter, + the C{resource} will be closed by having its C{close()} method called. + Any exceptions are ignored. + + @param obj: the object to track + @type obj: object + @param resource: the resource to close when the object is collected + @type resource: object + """ + def callback(ref): + try: + resource.close() + except: + pass + del self._table[id(resource)] + + # keep the weakref in a table so it sticks around long enough to get + # its callback called. :) + self._table[id(resource)] = weakref.ref(obj, callback) + + +# singleton +ResourceManager = ResourceManager() diff --git a/paramiko/rng.py b/paramiko/rng.py new file mode 100644 index 0000000..46329d1 --- /dev/null +++ b/paramiko/rng.py @@ -0,0 +1,112 @@ +#!/usr/bin/python +# -*- coding: ascii -*- +# Copyright (C) 2008 Dwayne C. Litzenberger <dlitz@dlitz.net> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +import sys +import threading +from Crypto.Util.randpool import RandomPool as _RandomPool + +try: + import platform +except ImportError: + platform = None # Not available using Python 2.2 + +def _strxor(a, b): + assert len(a) == len(b) + return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), a, b)) + +## +## Find a strong random entropy source, depending on the detected platform. +## WARNING TO DEVELOPERS: This will fail on some systems, but do NOT use +## Crypto.Util.randpool.RandomPool as a fall-back. RandomPool will happily run +## with very little entropy, thus _silently_ defeating any security that +## Paramiko attempts to provide. (This is current as of PyCrypto 2.0.1). +## See http://www.lag.net/pipermail/paramiko/2008-January/000599.html +## and http://www.lag.net/pipermail/paramiko/2008-April/000678.html +## + +if ((platform is not None and platform.system().lower() == 'windows') or + sys.platform == 'win32'): + # MS Windows + from paramiko import rng_win32 + rng_device = rng_win32.open_rng_device() +else: + # Assume POSIX (any system where /dev/urandom exists) + from paramiko import rng_posix + rng_device = rng_posix.open_rng_device() + + +class StrongLockingRandomPool(object): + """Wrapper around RandomPool guaranteeing strong random numbers. + + Crypto.Util.randpool.RandomPool will silently operate even if it is seeded + with little or no entropy, and it provides no prediction resistance if its + state is ever compromised throughout its runtime. It is also not thread-safe. + + This wrapper augments RandomPool by XORing its output with random bits from + the operating system, and by controlling access to the underlying + RandomPool using an exclusive lock. + """ + + def __init__(self, instance=None): + if instance is None: + instance = _RandomPool() + self.randpool = instance + self.randpool_lock = threading.Lock() + self.entropy = rng_device + + # Stir 256 bits of entropy from the RNG device into the RandomPool. + self.randpool.stir(self.entropy.read(32)) + self.entropy.randomize() + + def stir(self, s=''): + self.randpool_lock.acquire() + try: + self.randpool.stir(s) + finally: + self.randpool_lock.release() + self.entropy.randomize() + + def randomize(self, N=0): + self.randpool_lock.acquire() + try: + self.randpool.randomize(N) + finally: + self.randpool_lock.release() + self.entropy.randomize() + + def add_event(self, s=''): + self.randpool_lock.acquire() + try: + self.randpool.add_event(s) + finally: + self.randpool_lock.release() + + def get_bytes(self, N): + self.randpool_lock.acquire() + try: + randpool_data = self.randpool.get_bytes(N) + finally: + self.randpool_lock.release() + entropy_data = self.entropy.read(N) + result = _strxor(randpool_data, entropy_data) + assert len(randpool_data) == N and len(entropy_data) == N and len(result) == N + return result + +# vim:set ts=4 sw=4 sts=4 expandtab: diff --git a/paramiko/rng_posix.py b/paramiko/rng_posix.py new file mode 100644 index 0000000..1e6d72c --- /dev/null +++ b/paramiko/rng_posix.py @@ -0,0 +1,97 @@ +#!/usr/bin/python +# -*- coding: ascii -*- +# Copyright (C) 2008 Dwayne C. Litzenberger <dlitz@dlitz.net> +# Copyright (C) 2008 Open Systems Canada Limited +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +import os +import stat + +class error(Exception): + pass + +class _RNG(object): + def __init__(self, file): + self.file = file + + def read(self, bytes): + return self.file.read(bytes) + + def close(self): + return self.file.close() + + def randomize(self): + return + +def open_rng_device(device_path=None): + """Open /dev/urandom and perform some sanity checks.""" + + f = None + g = None + + if device_path is None: + device_path = "/dev/urandom" + + try: + # Try to open /dev/urandom now so that paramiko will be able to access + # it even if os.chroot() is invoked later. + try: + f = open(device_path, "rb", 0) + except EnvironmentError: + raise error("Unable to open /dev/urandom") + + # Open a second file descriptor for sanity checking later. + try: + g = open(device_path, "rb", 0) + except EnvironmentError: + raise error("Unable to open /dev/urandom") + + # Check that /dev/urandom is a character special device, not a regular file. + st = os.fstat(f.fileno()) # f + if stat.S_ISREG(st.st_mode) or not stat.S_ISCHR(st.st_mode): + raise error("/dev/urandom is not a character special device") + + st = os.fstat(g.fileno()) # g + if stat.S_ISREG(st.st_mode) or not stat.S_ISCHR(st.st_mode): + raise error("/dev/urandom is not a character special device") + + # Check that /dev/urandom always returns the number of bytes requested + x = f.read(20) + y = g.read(20) + if len(x) != 20 or len(y) != 20: + raise error("Error reading from /dev/urandom: input truncated") + + # Check that different reads return different data + if x == y: + raise error("/dev/urandom is broken; returning identical data: %r == %r" % (x, y)) + + # Close the duplicate file object + g.close() + + # Return the first file object + return _RNG(f) + + except error: + if f is not None: + f.close() + if g is not None: + g.close() + raise + +# vim:set ts=4 sw=4 sts=4 expandtab: + diff --git a/paramiko/rng_win32.py b/paramiko/rng_win32.py new file mode 100644 index 0000000..3cb8b84 --- /dev/null +++ b/paramiko/rng_win32.py @@ -0,0 +1,121 @@ +#!/usr/bin/python +# -*- coding: ascii -*- +# Copyright (C) 2008 Dwayne C. Litzenberger <dlitz@dlitz.net> +# Copyright (C) 2008 Open Systems Canada Limited +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +class error(Exception): + pass + +# Try to import the "winrandom" module +try: + from Crypto.Util import winrandom as _winrandom +except ImportError: + _winrandom = None + +# Try to import the "urandom" module +try: + from os import urandom as _urandom +except ImportError: + _urandom = None + + +class _RNG(object): + def __init__(self, readfunc): + self.read = readfunc + + def randomize(self): + # According to "Cryptanalysis of the Random Number Generator of the + # Windows Operating System", by Leo Dorrendorf and Zvi Gutterman + # and Benny Pinkas <http://eprint.iacr.org/2007/419>, + # CryptGenRandom only updates its internal state using kernel-provided + # random data every 128KiB of output. + self.read(128*1024) # discard 128 KiB of output + +def _open_winrandom(): + if _winrandom is None: + raise error("Crypto.Util.winrandom module not found") + + # Check that we can open the winrandom module + try: + r0 = _winrandom.new() + r1 = _winrandom.new() + except Exception, exc: + raise error("winrandom.new() failed: %s" % str(exc), exc) + + # Check that we can read from the winrandom module + try: + x = r0.get_bytes(20) + y = r1.get_bytes(20) + except Exception, exc: + raise error("winrandom get_bytes failed: %s" % str(exc), exc) + + # Check that the requested number of bytes are returned + if len(x) != 20 or len(y) != 20: + raise error("Error reading from winrandom: input truncated") + + # Check that different reads return different data + if x == y: + raise error("winrandom broken: returning identical data") + + return _RNG(r0.get_bytes) + +def _open_urandom(): + if _urandom is None: + raise error("os.urandom function not found") + + # Check that we can read from os.urandom() + try: + x = _urandom(20) + y = _urandom(20) + except Exception, exc: + raise error("os.urandom failed: %s" % str(exc), exc) + + # Check that the requested number of bytes are returned + if len(x) != 20 or len(y) != 20: + raise error("os.urandom failed: input truncated") + + # Check that different reads return different data + if x == y: + raise error("os.urandom failed: returning identical data") + + return _RNG(_urandom) + +def open_rng_device(): + # Try using the Crypto.Util.winrandom module + try: + return _open_winrandom() + except error: + pass + + # Several versions of PyCrypto do not contain the winrandom module, but + # Python >= 2.4 has os.urandom, so try to use that. + try: + return _open_urandom() + except error: + pass + + # SECURITY NOTE: DO NOT USE Crypto.Util.randpool.RandomPool HERE! + # If we got to this point, RandomPool will silently run with very little + # entropy. (This is current as of PyCrypto 2.0.1). + # See http://www.lag.net/pipermail/paramiko/2008-January/000599.html + # and http://www.lag.net/pipermail/paramiko/2008-April/000678.html + + raise error("Unable to find a strong random entropy source. You cannot run this software securely under the current configuration.") + +# vim:set ts=4 sw=4 sts=4 expandtab: diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py index 780ea1b..d72d175 100644 --- a/paramiko/rsakey.py +++ b/paramiko/rsakey.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -38,7 +38,15 @@ class RSAKey (PKey): data. """ - def __init__(self, msg=None, data=None, filename=None, password=None, vals=None): + def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None): + self.n = None + self.e = None + self.d = None + self.p = None + self.q = None + if file_obj is not None: + self._from_private_key(file_obj, password) + return if filename is not None: self._from_private_key_file(filename, password) return @@ -75,7 +83,7 @@ class RSAKey (PKey): return self.size def can_sign(self): - return hasattr(self, 'd') + return self.d is not None def sign_ssh_data(self, rpool, data): digest = SHA.new(data).digest() @@ -93,11 +101,13 @@ class RSAKey (PKey): # verify the signature by SHA'ing the data and encrypting it using the # public key. some wackiness ensues where we "pkcs1imify" the 20-byte # hash into a string as long as the RSA key. - hash = util.inflate_long(self._pkcs1imify(SHA.new(data).digest()), True) + hash_obj = util.inflate_long(self._pkcs1imify(SHA.new(data).digest()), True) rsa = RSA.construct((long(self.n), long(self.e))) - return rsa.verify(hash, (sig,)) + return rsa.verify(hash_obj, (sig,)) - def write_private_key_file(self, filename, password=None): + def _encode_key(self): + if (self.p is None) or (self.q is None): + raise SSHException('Not enough key info to write private key file') keylist = [ 0, self.n, self.e, self.d, self.p, self.q, self.d % (self.p - 1), self.d % (self.q - 1), util.mod_inverse(self.q, self.p) ] @@ -106,7 +116,13 @@ class RSAKey (PKey): b.encode(keylist) except BERException: raise SSHException('Unable to create ber encoding of key') - self._write_private_key_file('RSA', filename, str(b), password) + return str(b) + + def write_private_key_file(self, filename, password=None): + self._write_private_key_file('RSA', filename, self._encode_key(), password) + + def write_private_key(self, file_obj, password=None): + self._write_private_key('RSA', file_obj, self._encode_key(), password) def generate(bits, progress_func=None): """ @@ -120,8 +136,6 @@ class RSAKey (PKey): @type progress_func: function @return: new private key @rtype: L{RSAKey} - - @since: fearow """ randpool.stir() rsa = RSA.generate(bits, randpool.get_bytes, progress_func) @@ -147,9 +161,16 @@ class RSAKey (PKey): return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data def _from_private_key_file(self, filename, password): + data = self._read_private_key_file('RSA', filename, password) + self._decode_key(data) + + def _from_private_key(self, file_obj, password): + data = self._read_private_key('RSA', file_obj, password) + self._decode_key(data) + + def _decode_key(self, data): # private key file contains: # RSAPrivateKey = { version = 0, n, e, d, p, q, d mod p-1, d mod q-1, q**-1 mod p } - data = self._read_private_key_file('RSA', filename, password) try: keylist = BER(data).decode() except BERException: diff --git a/paramiko/server.py b/paramiko/server.py index a0e3988..bcaa4be 100644 --- a/paramiko/server.py +++ b/paramiko/server.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -41,6 +41,8 @@ class InteractiveQuery (object): @type name: str @param instructions: user instructions (usually short) about this query @type instructions: str + @param prompts: one or more authentication prompts + @type prompts: str """ self.name = name self.instructions = instructions @@ -90,6 +92,7 @@ class ServerInterface (object): - L{check_channel_shell_request} - L{check_channel_subsystem_request} - L{check_channel_window_change_request} + - L{check_channel_x11_request} The C{chanid} parameter is a small number that uniquely identifies the channel within a L{Transport}. A L{Channel} object is not created @@ -273,6 +276,42 @@ class ServerInterface (object): """ return AUTH_FAILED + def check_port_forward_request(self, address, port): + """ + Handle a request for port forwarding. The client is asking that + connections to the given address and port be forwarded back across + this ssh connection. An address of C{"0.0.0.0"} indicates a global + address (any address associated with this server) and a port of C{0} + indicates that no specific port is requested (usually the OS will pick + a port). + + The default implementation always returns C{False}, rejecting the + port forwarding request. If the request is accepted, you should return + the port opened for listening. + + @param address: the requested address + @type address: str + @param port: the requested port + @type port: int + @return: the port number that was opened for listening, or C{False} to + reject + @rtype: int + """ + return False + + def cancel_port_forward_request(self, address, port): + """ + The client would like to cancel a previous port-forwarding request. + If the given address and port is being forwarded across this ssh + connection, the port should be closed. + + @param address: the forwarded address + @type address: str + @param port: the forwarded port + @type port: int + """ + pass + def check_global_request(self, kind, msg): """ Handle a global request of the given C{kind}. This method is called @@ -291,6 +330,9 @@ class ServerInterface (object): The default implementation always returns C{False}, indicating that it does not support any global requests. + + @note: Port forwarding requests are handled separately, in + L{check_port_forward_request}. @param kind: the kind of global request being made. @type kind: str @@ -426,6 +468,71 @@ class ServerInterface (object): @rtype: bool """ return False + + def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): + """ + Determine if the client will be provided with an X11 session. If this + method returns C{True}, X11 applications should be routed through new + SSH channels, using L{Transport.open_x11_channel}. + + The default implementation always returns C{False}. + + @param channel: the L{Channel} the X11 request arrived on + @type channel: L{Channel} + @param single_connection: C{True} if only a single X11 channel should + be opened + @type single_connection: bool + @param auth_protocol: the protocol used for X11 authentication + @type auth_protocol: str + @param auth_cookie: the cookie used to authenticate to X11 + @type auth_cookie: str + @param screen_number: the number of the X11 screen to connect to + @type screen_number: int + @return: C{True} if the X11 session was opened; C{False} if not + @rtype: bool + """ + return False + + def check_channel_direct_tcpip_request(self, chanid, origin, destination): + """ + Determine if a local port forwarding channel will be granted, and + return C{OPEN_SUCCEEDED} or an error code. This method is + called in server mode when the client requests a channel, after + authentication is complete. + + The C{chanid} parameter is a small number that uniquely identifies the + channel within a L{Transport}. A L{Channel} object is not created + unless this method returns C{OPEN_SUCCEEDED} -- once a + L{Channel} object is created, you can call L{Channel.get_id} to + retrieve the channel ID. + + The origin and destination parameters are (ip_address, port) tuples + that correspond to both ends of the TCP connection in the forwarding + tunnel. + + The return value should either be C{OPEN_SUCCEEDED} (or + C{0}) to allow the channel request, or one of the following error + codes to reject it: + - C{OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED} + - C{OPEN_FAILED_CONNECT_FAILED} + - C{OPEN_FAILED_UNKNOWN_CHANNEL_TYPE} + - C{OPEN_FAILED_RESOURCE_SHORTAGE} + + The default implementation always returns + C{OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED}. + + @param chanid: ID of the channel + @type chanid: int + @param origin: 2-tuple containing the IP address and port of the + originator (client side) + @type origin: tuple + @param destination: 2-tuple containing the IP address and port of the + destination (server side) + @type destination: tuple + @return: a success or failure code (listed above) + @rtype: int + """ + return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED class SubsystemHandler (threading.Thread): @@ -443,8 +550,6 @@ class SubsystemHandler (threading.Thread): authenticated and requests subsytem C{"mp3"}, an object of class C{MP3Handler} will be created, and L{start_subsystem} will be called on it from a new thread. - - @since: ivysaur """ def __init__(self, channel, name, server): """ diff --git a/paramiko/sftp.py b/paramiko/sftp.py index 58d7103..2296d85 100644 --- a/paramiko/sftp.py +++ b/paramiko/sftp.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -16,6 +16,7 @@ # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. +import select import socket import struct @@ -113,24 +114,22 @@ class BaseSFTP (object): return version def _send_server_version(self): + # winscp will freak out if the server sends version info before the + # client finishes sending INIT. + t, data = self._read_packet() + if t != CMD_INIT: + raise SFTPError('Incompatible sftp protocol') + version = struct.unpack('>I', data[:4])[0] # advertise that we support "check-file" extension_pairs = [ 'check-file', 'md5,sha1' ] msg = Message() msg.add_int(_VERSION) msg.add(*extension_pairs) self._send_packet(CMD_VERSION, str(msg)) - t, data = self._read_packet() - if t != CMD_INIT: - raise SFTPError('Incompatible sftp protocol') - version = struct.unpack('>I', data[:4])[0] return version - def _log(self, level, msg): - if issubclass(type(msg), list): - for m in msg: - self.logger.log(level, m) - else: - self.logger.log(level, msg) + def _log(self, level, msg, *args): + self.logger.log(level, msg, *args) def _write_all(self, out): while len(out) > 0: @@ -145,7 +144,20 @@ class BaseSFTP (object): def _read_all(self, n): out = '' while n > 0: - x = self.sock.recv(n) + if isinstance(self.sock, socket.socket): + # sometimes sftp is used directly over a socket instead of + # through a paramiko channel. in this case, check periodically + # if the socket is closed. (for some reason, recv() won't ever + # return or raise an exception, but calling select on a closed + # socket will.) + while True: + read, write, err = select.select([ self.sock ], [], [], 0.1) + if len(read) > 0: + x = self.sock.recv(n) + break + else: + x = self.sock.recv(n) + if len(x) == 0: raise EOFError() out += x @@ -153,16 +165,24 @@ class BaseSFTP (object): return out def _send_packet(self, t, packet): + #self._log(DEBUG2, 'write: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(packet))) out = struct.pack('>I', len(packet) + 1) + chr(t) + packet if self.ultra_debug: self._log(DEBUG, util.format_binary(out, 'OUT: ')) self._write_all(out) def _read_packet(self): - size = struct.unpack('>I', self._read_all(4))[0] + x = self._read_all(4) + # most sftp servers won't accept packets larger than about 32k, so + # anything with the high byte set (> 16MB) is just garbage. + if x[0] != '\x00': + raise SFTPError('Garbage packet received') + size = struct.unpack('>I', x)[0] data = self._read_all(size) if self.ultra_debug: self._log(DEBUG, util.format_binary(data, 'IN: ')); if size > 0: - return ord(data[0]), data[1:] + t = ord(data[0]) + #self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t), '0x%02x' % t, len(data)-1)) + return t, data[1:] return 0, '' diff --git a/paramiko/sftp_attr.py b/paramiko/sftp_attr.py index eae7c99..9c92862 100644 --- a/paramiko/sftp_attr.py +++ b/paramiko/sftp_attr.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2006 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -51,6 +51,12 @@ class SFTPAttributes (object): Create a new (empty) SFTPAttributes object. All fields will be empty. """ self._flags = 0 + self.st_size = None + self.st_uid = None + self.st_gid = None + self.st_mode = None + self.st_atime = None + self.st_mtime = None self.attr = {} def from_stat(cls, obj, filename=None): @@ -80,18 +86,17 @@ class SFTPAttributes (object): def __repr__(self): return '<SFTPAttributes: %s>' % self._debug_str() - def __str__(self): - return self._debug_str() - ### internals... - def _from_msg(cls, msg, filename=None): + def _from_msg(cls, msg, filename=None, longname=None): attr = cls() attr._unpack(msg) if filename is not None: attr.filename = filename + if longname is not None: + attr.longname = longname return attr _from_msg = classmethod(_from_msg) @@ -114,13 +119,13 @@ class SFTPAttributes (object): def _pack(self, msg): self._flags = 0 - if hasattr(self, 'st_size'): + if self.st_size is not None: self._flags |= self.FLAG_SIZE - if hasattr(self, 'st_uid') or hasattr(self, 'st_gid'): + if (self.st_uid is not None) and (self.st_gid is not None): self._flags |= self.FLAG_UIDGID - if hasattr(self, 'st_mode'): + if self.st_mode is not None: self._flags |= self.FLAG_PERMISSIONS - if hasattr(self, 'st_atime') or hasattr(self, 'st_mtime'): + if (self.st_atime is not None) and (self.st_mtime is not None): self._flags |= self.FLAG_AMTIME if len(self.attr) > 0: self._flags |= self.FLAG_EXTENDED @@ -128,13 +133,14 @@ class SFTPAttributes (object): if self._flags & self.FLAG_SIZE: msg.add_int64(self.st_size) if self._flags & self.FLAG_UIDGID: - msg.add_int(getattr(self, 'st_uid', 0)) - msg.add_int(getattr(self, 'st_gid', 0)) + msg.add_int(self.st_uid) + msg.add_int(self.st_gid) if self._flags & self.FLAG_PERMISSIONS: msg.add_int(self.st_mode) if self._flags & self.FLAG_AMTIME: - msg.add_int(getattr(self, 'st_atime', 0)) - msg.add_int(getattr(self, 'st_mtime', 0)) + # throw away any fractional seconds + msg.add_int(long(self.st_atime)) + msg.add_int(long(self.st_mtime)) if self._flags & self.FLAG_EXTENDED: msg.add_int(len(self.attr)) for key, val in self.attr.iteritems(): @@ -144,15 +150,14 @@ class SFTPAttributes (object): def _debug_str(self): out = '[ ' - if hasattr(self, 'st_size'): + if self.st_size is not None: out += 'size=%d ' % self.st_size - if hasattr(self, 'st_uid') or hasattr(self, 'st_gid'): - out += 'uid=%d gid=%d ' % (getattr(self, 'st_uid', 0), getattr(self, 'st_gid', 0)) - if hasattr(self, 'st_mode'): + if (self.st_uid is not None) and (self.st_gid is not None): + out += 'uid=%d gid=%d ' % (self.st_uid, self.st_gid) + if self.st_mode is not None: out += 'mode=' + oct(self.st_mode) + ' ' - if hasattr(self, 'st_atime') or hasattr(self, 'st_mtime'): - out += 'atime=%d mtime=%d ' % (getattr(self, 'st_atime', 0), - getattr(self, 'st_mtime', 0)) + if (self.st_atime is not None) and (self.st_mtime is not None): + out += 'atime=%d mtime=%d ' % (self.st_atime, self.st_mtime) for k, v in self.attr.iteritems(): out += '"%s"=%r ' % (str(k), v) out += ']' @@ -171,7 +176,7 @@ class SFTPAttributes (object): def __str__(self): "create a unix-style long description of the file (like ls -l)" - if hasattr(self, 'st_mode'): + if self.st_mode is not None: kind = stat.S_IFMT(self.st_mode) if kind == stat.S_IFIFO: ks = 'p' @@ -194,15 +199,25 @@ class SFTPAttributes (object): ks += self._rwx(self.st_mode & 7, self.st_mode & stat.S_ISVTX, True) else: ks = '?---------' - uid = getattr(self, 'st_uid', -1) - gid = getattr(self, 'st_gid', -1) - size = getattr(self, 'st_size', -1) - mtime = getattr(self, 'st_mtime', 0) # compute display date - if abs(time.time() - mtime) > 15552000: - # (15552000 = 6 months) - datestr = time.strftime('%d %b %Y', time.localtime(mtime)) + if (self.st_mtime is None) or (self.st_mtime == 0xffffffff): + # shouldn't really happen + datestr = '(unknown date)' else: - datestr = time.strftime('%d %b %H:%M', time.localtime(mtime)) + if abs(time.time() - self.st_mtime) > 15552000: + # (15552000 = 6 months) + datestr = time.strftime('%d %b %Y', time.localtime(self.st_mtime)) + else: + datestr = time.strftime('%d %b %H:%M', time.localtime(self.st_mtime)) filename = getattr(self, 'filename', '?') - return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, size, datestr, filename) + + # not all servers support uid/gid + uid = self.st_uid + gid = self.st_gid + if uid is None: + uid = 0 + if gid is None: + gid = 0 + + return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, self.st_size, datestr, filename) + diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index 2fe89e9..b3d2d56 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -20,21 +20,32 @@ Client-mode SFTP support. """ +from binascii import hexlify import errno import os import threading +import time import weakref + from paramiko.sftp import * from paramiko.sftp_attr import SFTPAttributes +from paramiko.ssh_exception import SSHException from paramiko.sftp_file import SFTPFile def _to_unicode(s): - "if a str is not ascii, decode its utf8 into unicode" + """ + decode a string as ascii or utf8 if possible (as required by the sftp + protocol). if neither works, just return a byte string because the server + probably doesn't know the filename's encoding. + """ try: return s.encode('ascii') - except: - return s.decode('utf-8') + except UnicodeError: + try: + return s.decode('utf-8') + except UnicodeError: + return s class SFTPClient (BaseSFTP): @@ -51,8 +62,11 @@ class SFTPClient (BaseSFTP): An alternate way to create an SFTP client context is by using L{from_transport}. - @param sock: an open L{Channel} using the C{"sftp"} subsystem. + @param sock: an open L{Channel} using the C{"sftp"} subsystem @type sock: L{Channel} + + @raise SSHException: if there's an exception while negotiating + sftp """ BaseSFTP.__init__(self) self.sock = sock @@ -66,31 +80,33 @@ class SFTPClient (BaseSFTP): if type(sock) is Channel: # override default logger transport = self.sock.get_transport() - self.logger = util.get_logger(transport.get_log_channel() + '.' + - self.sock.get_name() + '.sftp') + self.logger = util.get_logger(transport.get_log_channel() + '.sftp') self.ultra_debug = transport.get_hexdump() - self._send_version() - - def __del__(self): - self.close() + try: + server_version = self._send_version() + except EOFError, x: + raise SSHException('EOF during negotiation') + self._log(INFO, 'Opened sftp connection (server version %d)' % server_version) - def from_transport(selfclass, t): + def from_transport(cls, t): """ Create an SFTP client channel from an open L{Transport}. - @param t: an open L{Transport} which is already authenticated. + @param t: an open L{Transport} which is already authenticated @type t: L{Transport} @return: a new L{SFTPClient} object, referring to an sftp session - (channel) across the transport. + (channel) across the transport @rtype: L{SFTPClient} """ chan = t.open_session() if chan is None: return None - if not chan.invoke_subsystem('sftp'): - raise SFTPError('Failed to invoke sftp subsystem') - return selfclass(chan) + chan.invoke_subsystem('sftp') + return cls(chan) from_transport = classmethod(from_transport) + + def _log(self, level, msg, *args): + super(SFTPClient, self)._log(level, "[chan %s] " + msg, *([ self.sock.get_name() ] + list(args))) def close(self): """ @@ -98,7 +114,20 @@ class SFTPClient (BaseSFTP): @since: 1.4 """ + self._log(INFO, 'sftp session closed.') self.sock.close() + + def get_channel(self): + """ + Return the underlying L{Channel} object for this SFTP session. This + might be useful for doing things like setting a timeout on the channel. + + @return: the SSH channel + @rtype: L{Channel} + + @since: 1.7.1 + """ + return self.sock def listdir(self, path='.'): """ @@ -121,6 +150,11 @@ class SFTPClient (BaseSFTP): files in the given C{path}. The list is in arbitrary order. It does not include the special entries C{'.'} and C{'..'} even if they are present in the folder. + + The returned L{SFTPAttributes} objects will each have an additional + field: C{longname}, which may contain a formatted string of the file's + attributes, in unix format. The content of this string will probably + depend on the SFTP server implementation. @param path: path to list (defaults to C{'.'}) @type path: str @@ -130,6 +164,7 @@ class SFTPClient (BaseSFTP): @since: 1.2 """ path = self._adjust_cwd(path) + self._log(DEBUG, 'listdir(%r)' % path) t, msg = self._request(CMD_OPENDIR, path) if t != CMD_HANDLE: raise SFTPError('Expected handle') @@ -147,13 +182,13 @@ class SFTPClient (BaseSFTP): for i in range(count): filename = _to_unicode(msg.get_string()) longname = _to_unicode(msg.get_string()) - attr = SFTPAttributes._from_msg(msg, filename) + attr = SFTPAttributes._from_msg(msg, filename, longname) if (filename != '.') and (filename != '..'): filelist.append(attr) self._request(CMD_CLOSE, handle) return filelist - def file(self, filename, mode='r', bufsize=-1): + def open(self, filename, mode='r', bufsize=-1): """ Open a file on the remote server. The arguments are the same as for python's built-in C{file} (aka C{open}). A file-like object is @@ -177,18 +212,19 @@ class SFTPClient (BaseSFTP): buffering, C{1} uses line buffering, and any number greater than 1 (C{>1}) uses that specific buffer size. - @param filename: name of the file to open. - @type filename: string - @param mode: mode (python-style) to open in. - @type mode: string + @param filename: name of the file to open + @type filename: str + @param mode: mode (python-style) to open in + @type mode: str @param bufsize: desired buffering (-1 = default buffer size) @type bufsize: int - @return: a file object representing the open file. + @return: a file object representing the open file @rtype: SFTPFile @raise IOError: if the file could not be opened. """ filename = self._adjust_cwd(filename) + self._log(DEBUG, 'open(%r, %r)' % (filename, mode)) imode = 0 if ('r' in mode) or ('+' in mode): imode |= SFTP_FLAG_READ @@ -205,23 +241,24 @@ class SFTPClient (BaseSFTP): if t != CMD_HANDLE: raise SFTPError('Expected handle') handle = msg.get_string() + self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, hexlify(handle))) return SFTPFile(self, handle, mode, bufsize) - # python has migrated toward file() instead of open(). - # and really, that's more easily identifiable. - open = file + # python continues to vacillate about "open" vs "file"... + file = open def remove(self, path): """ - Remove the file at the given path. + Remove the file at the given path. This only works on files; for + removing folders (directories), use L{rmdir}. - @param path: path (absolute or relative) of the file to remove. - @type path: string + @param path: path (absolute or relative) of the file to remove + @type path: str - @raise IOError: if the path refers to a folder (directory). Use - L{rmdir} to remove a folder. + @raise IOError: if the path refers to a folder (directory) """ path = self._adjust_cwd(path) + self._log(DEBUG, 'remove(%r)' % path) self._request(CMD_REMOVE, path) unlink = remove @@ -230,16 +267,17 @@ class SFTPClient (BaseSFTP): """ Rename a file or folder from C{oldpath} to C{newpath}. - @param oldpath: existing name of the file or folder. - @type oldpath: string - @param newpath: new name for the file or folder. - @type newpath: string + @param oldpath: existing name of the file or folder + @type oldpath: str + @param newpath: new name for the file or folder + @type newpath: str @raise IOError: if C{newpath} is a folder, or something else goes - wrong. + wrong """ oldpath = self._adjust_cwd(oldpath) newpath = self._adjust_cwd(newpath) + self._log(DEBUG, 'rename(%r, %r)' % (oldpath, newpath)) self._request(CMD_RENAME, oldpath, newpath) def mkdir(self, path, mode=0777): @@ -248,12 +286,13 @@ class SFTPClient (BaseSFTP): The default mode is 0777 (octal). On some systems, mode is ignored. Where it is used, the current umask value is first masked out. - @param path: name of the folder to create. - @type path: string - @param mode: permissions (posix-style) for the newly-created folder. + @param path: name of the folder to create + @type path: str + @param mode: permissions (posix-style) for the newly-created folder @type mode: int """ path = self._adjust_cwd(path) + self._log(DEBUG, 'mkdir(%r, %r)' % (path, mode)) attr = SFTPAttributes() attr.st_mode = mode self._request(CMD_MKDIR, path, attr) @@ -262,10 +301,11 @@ class SFTPClient (BaseSFTP): """ Remove the folder named C{path}. - @param path: name of the folder to remove. - @type path: string + @param path: name of the folder to remove + @type path: str """ path = self._adjust_cwd(path) + self._log(DEBUG, 'rmdir(%r)' % path) self._request(CMD_RMDIR, path) def stat(self, path): @@ -282,12 +322,13 @@ class SFTPClient (BaseSFTP): The fields supported are: C{st_mode}, C{st_size}, C{st_uid}, C{st_gid}, C{st_atime}, and C{st_mtime}. - @param path: the filename to stat. - @type path: string - @return: an object containing attributes about the given file. + @param path: the filename to stat + @type path: str + @return: an object containing attributes about the given file @rtype: SFTPAttributes """ path = self._adjust_cwd(path) + self._log(DEBUG, 'stat(%r)' % path) t, msg = self._request(CMD_STAT, path) if t != CMD_ATTRS: raise SFTPError('Expected attributes') @@ -299,12 +340,13 @@ class SFTPClient (BaseSFTP): following symbolic links (shortcuts). This otherwise behaves exactly the same as L{stat}. - @param path: the filename to stat. - @type path: string - @return: an object containing attributes about the given file. + @param path: the filename to stat + @type path: str + @return: an object containing attributes about the given file @rtype: SFTPAttributes """ path = self._adjust_cwd(path) + self._log(DEBUG, 'lstat(%r)' % path) t, msg = self._request(CMD_LSTAT, path) if t != CMD_ATTRS: raise SFTPError('Expected attributes') @@ -315,12 +357,13 @@ class SFTPClient (BaseSFTP): Create a symbolic link (shortcut) of the C{source} path at C{destination}. - @param source: path of the original file. - @type source: string - @param dest: path of the newly created symlink. - @type dest: string + @param source: path of the original file + @type source: str + @param dest: path of the newly created symlink + @type dest: str """ dest = self._adjust_cwd(dest) + self._log(DEBUG, 'symlink(%r, %r)' % (source, dest)) if type(source) is unicode: source = source.encode('utf-8') self._request(CMD_SYMLINK, source, dest) @@ -331,12 +374,13 @@ class SFTPClient (BaseSFTP): unix-style and identical to those used by python's C{os.chmod} function. - @param path: path of the file to change the permissions of. - @type path: string - @param mode: new permissions. + @param path: path of the file to change the permissions of + @type path: str + @param mode: new permissions @type mode: int """ path = self._adjust_cwd(path) + self._log(DEBUG, 'chmod(%r, %r)' % (path, mode)) attr = SFTPAttributes() attr.st_mode = mode self._request(CMD_SETSTAT, path, attr) @@ -348,14 +392,15 @@ class SFTPClient (BaseSFTP): only want to change one, use L{stat} first to retrieve the current owner and group. - @param path: path of the file to change the owner and group of. - @type path: string + @param path: path of the file to change the owner and group of + @type path: str @param uid: new owner's uid @type uid: int @param gid: new group id @type gid: int """ path = self._adjust_cwd(path) + self._log(DEBUG, 'chown(%r, %r, %r)' % (path, uid, gid)) attr = SFTPAttributes() attr.st_uid, attr.st_gid = uid, gid self._request(CMD_SETSTAT, path, attr) @@ -369,31 +414,50 @@ class SFTPClient (BaseSFTP): modified times, respectively. This bizarre API is mimicked from python for the sake of consistency -- I apologize. - @param path: path of the file to modify. - @type path: string + @param path: path of the file to modify + @type path: str @param times: C{None} or a tuple of (access time, modified time) in - standard internet epoch time (seconds since 01 January 1970 GMT). - @type times: tuple of int + standard internet epoch time (seconds since 01 January 1970 GMT) + @type times: tuple(int) """ path = self._adjust_cwd(path) if times is None: times = (time.time(), time.time()) + self._log(DEBUG, 'utime(%r, %r)' % (path, times)) attr = SFTPAttributes() attr.st_atime, attr.st_mtime = times self._request(CMD_SETSTAT, path, attr) + def truncate(self, path, size): + """ + Change the size of the file specified by C{path}. This usually extends + or shrinks the size of the file, just like the C{truncate()} method on + python file objects. + + @param path: path of the file to modify + @type path: str + @param size: the new size of the file + @type size: int or long + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'truncate(%r, %r)' % (path, size)) + attr = SFTPAttributes() + attr.st_size = size + self._request(CMD_SETSTAT, path, attr) + def readlink(self, path): """ Return the target of a symbolic link (shortcut). You can use L{symlink} to create these. The result may be either an absolute or relative pathname. - @param path: path of the symbolic link file. + @param path: path of the symbolic link file @type path: str - @return: target path. + @return: target path @rtype: str """ path = self._adjust_cwd(path) + self._log(DEBUG, 'readlink(%r)' % path) t, msg = self._request(CMD_READLINK, path) if t != CMD_NAME: raise SFTPError('Expected name response') @@ -411,14 +475,15 @@ class SFTPClient (BaseSFTP): server is considering to be the "current folder" (by passing C{'.'} as C{path}). - @param path: path to be normalized. + @param path: path to be normalized @type path: str - @return: normalized form of the given path. + @return: normalized form of the given path @rtype: str @raise IOError: if the path can't be resolved on the server """ path = self._adjust_cwd(path) + self._log(DEBUG, 'normalize(%r)' % path) t, msg = self._request(CMD_REALPATH, path) if t != CMD_NAME: raise SFTPError('Expected name response') @@ -457,7 +522,7 @@ class SFTPClient (BaseSFTP): """ return self._cwd - def put(self, localpath, remotepath): + def put(self, localpath, remotepath, callback=None): """ Copy a local file (C{localpath}) to the SFTP server as C{remotepath}. Any exception raised by operations will be passed through. This @@ -469,9 +534,17 @@ class SFTPClient (BaseSFTP): @type localpath: str @param remotepath: the destination path on the SFTP server @type remotepath: str + @param callback: optional callback function that accepts the bytes + transferred so far and the total bytes to be transferred + (since 1.7.4) + @type callback: function(int, int) + @return: an object containing attributes about the given file + (since 1.7.4) + @rtype: SFTPAttributes @since: 1.4 """ + file_size = os.stat(localpath).st_size fl = file(localpath, 'rb') fr = self.file(remotepath, 'wb') fr.set_pipelined(True) @@ -482,13 +555,16 @@ class SFTPClient (BaseSFTP): break fr.write(data) size += len(data) + if callback is not None: + callback(size, file_size) fl.close() fr.close() s = self.stat(remotepath) if s.st_size != size: raise IOError('size mismatch in put! %d != %d' % (s.st_size, size)) + return s - def get(self, remotepath, localpath): + def get(self, remotepath, localpath, callback=None): """ Copy a remote file (C{remotepath}) from the SFTP server to the local host as C{localpath}. Any exception raised by operations will be @@ -498,10 +574,15 @@ class SFTPClient (BaseSFTP): @type remotepath: str @param localpath: the destination path on the local host @type localpath: str + @param callback: optional callback function that accepts the bytes + transferred so far and the total bytes to be transferred + (since 1.7.4) + @type callback: function(int, int) @since: 1.4 """ fr = self.file(remotepath, 'rb') + file_size = self.stat(remotepath).st_size fr.prefetch() fl = file(localpath, 'wb') size = 0 @@ -511,6 +592,8 @@ class SFTPClient (BaseSFTP): break fl.write(data) size += len(data) + if callback is not None: + callback(size, file_size) fl.close() fr.close() s = os.stat(localpath) @@ -552,7 +635,10 @@ class SFTPClient (BaseSFTP): def _read_response(self, waitfor=None): while True: - t, data = self._read_packet() + try: + t, data = self._read_packet() + except EOFError, e: + raise SSHException('Server connection dropped: %s' % (str(e),)) msg = Message(data) num = msg.get_int() if num not in self._expecting: @@ -560,7 +646,7 @@ class SFTPClient (BaseSFTP): self._log(DEBUG, 'Unexpected response #%d' % (num,)) if waitfor is None: # just doing a single check - return + break continue fileobj = self._expecting[num] del self._expecting[num] @@ -573,7 +659,8 @@ class SFTPClient (BaseSFTP): fileobj._async_response(t, msg) if waitfor is None: # just doing a single check - return + break + return (None, None) def _finish_responses(self, fileobj): while fileobj in self._expecting.values(): @@ -610,6 +697,8 @@ class SFTPClient (BaseSFTP): if (len(path) > 0) and (path[0] == '/'): # absolute path return path + if self._cwd == '/': + return self._cwd + path return self._cwd + '/' + path diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py index f224f02..cfa7db1 100644 --- a/paramiko/sftp_file.py +++ b/paramiko/sftp_file.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -20,7 +20,11 @@ L{SFTPFile} """ +from binascii import hexlify +import socket import threading +import time + from paramiko.common import * from paramiko.sftp import * from paramiko.file import BufferedFile @@ -43,12 +47,18 @@ class SFTPFile (BufferedFile): BufferedFile._set_mode(self, mode, bufsize) self.pipelined = False self._prefetching = False + self._prefetch_done = False + self._prefetch_data = {} + self._prefetch_reads = [] self._saved_exception = None def __del__(self): - self.close(_async=True) + self._close(async=True) + + def close(self): + self._close(async=False) - def close(self, _async=False): + def _close(self, async=False): # We allow double-close without signaling an error, because real # Python file objects do. However, we must protect against actually # sending multiple CMD_CLOSE packets, because after we close our @@ -58,11 +68,12 @@ class SFTPFile (BufferedFile): # __del__.) if self._closed: return + self.sftp._log(DEBUG, 'close(%s)' % hexlify(self.handle)) if self.pipelined: self.sftp._finish_responses(self) BufferedFile.close(self) try: - if _async: + if async: # GC'd file handle could be called from an arbitrary thread -- don't wait for a response self.sftp._async_request(type(None), CMD_CLOSE, self.handle) else: @@ -70,34 +81,77 @@ class SFTPFile (BufferedFile): except EOFError: # may have outlived the Transport connection pass - except IOError: + except (IOError, socket.error): # may have outlived the Transport connection pass + def _data_in_prefetch_requests(self, offset, size): + k = [i for i in self._prefetch_reads if i[0] <= offset] + if len(k) == 0: + return False + k.sort(lambda x, y: cmp(x[0], y[0])) + buf_offset, buf_size = k[-1] + if buf_offset + buf_size <= offset: + # prefetch request ends before this one begins + return False + if buf_offset + buf_size >= offset + size: + # inclusive + return True + # well, we have part of the request. see if another chunk has the rest. + return self._data_in_prefetch_requests(buf_offset + buf_size, offset + size - buf_offset - buf_size) + + def _data_in_prefetch_buffers(self, offset): + """ + if a block of data is present in the prefetch buffers, at the given + offset, return the offset of the relevant prefetch buffer. otherwise, + return None. this guarantees nothing about the number of bytes + collected in the prefetch buffer so far. + """ + k = [i for i in self._prefetch_data.keys() if i <= offset] + if len(k) == 0: + return None + index = max(k) + buf_offset = offset - index + if buf_offset >= len(self._prefetch_data[index]): + # it's not here + return None + return index + def _read_prefetch(self, size): + """ + read data out of the prefetch buffer, if possible. if the data isn't + in the buffer, return None. otherwise, behaves like a normal read. + """ # while not closed, and haven't fetched past the current position, and haven't reached EOF... - while (self._prefetch_so_far <= self._realpos) and \ - (self._prefetch_so_far < self._prefetch_size) and not self._closed: + while True: + offset = self._data_in_prefetch_buffers(self._realpos) + if offset is not None: + break + if self._prefetch_done or self._closed: + break self.sftp._read_response() - self._check_exception() - k = self._prefetch_data.keys() - k.sort() - while (len(k) > 0) and (k[0] + len(self._prefetch_data[k[0]]) <= self._realpos): - # done with that block - del self._prefetch_data[k[0]] - k.pop(0) - if len(k) == 0: + self._check_exception() + if offset is None: self._prefetching = False - return '' - assert k[0] <= self._realpos - buf_offset = self._realpos - k[0] - buf_length = len(self._prefetch_data[k[0]]) - buf_offset - return self._prefetch_data[k[0]][buf_offset : buf_offset + buf_length] + return None + prefetch = self._prefetch_data[offset] + del self._prefetch_data[offset] + + buf_offset = self._realpos - offset + if buf_offset > 0: + self._prefetch_data[offset] = prefetch[:buf_offset] + prefetch = prefetch[buf_offset:] + if size < len(prefetch): + self._prefetch_data[self._realpos + size] = prefetch[size:] + prefetch = prefetch[:size] + return prefetch def _read(self, size): size = min(size, self.MAX_REQUEST_SIZE) if self._prefetching: - return self._read_prefetch(size) + data = self._read_prefetch(size) + if data is not None: + return data t, msg = self.sftp._request(CMD_READ, self.handle, long(self._realpos), int(size)) if t != CMD_DATA: raise SFTPError('Expected data') @@ -106,8 +160,7 @@ class SFTPFile (BufferedFile): def _write(self, data): # may write less than requested if it would exceed max packet size chunk = min(len(data), self.MAX_REQUEST_SIZE) - req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), - str(data[:chunk])) + req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk])) if not self.pipelined or self.sftp.sock.recv_ready(): t, msg = self.sftp._read_response(req) if t != CMD_STATUS: @@ -173,6 +226,71 @@ class SFTPFile (BufferedFile): if t != CMD_ATTRS: raise SFTPError('Expected attributes') return SFTPAttributes._from_msg(msg) + + def chmod(self, mode): + """ + Change the mode (permissions) of this file. The permissions are + unix-style and identical to those used by python's C{os.chmod} + function. + + @param mode: new permissions + @type mode: int + """ + self.sftp._log(DEBUG, 'chmod(%s, %r)' % (hexlify(self.handle), mode)) + attr = SFTPAttributes() + attr.st_mode = mode + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def chown(self, uid, gid): + """ + Change the owner (C{uid}) and group (C{gid}) of this file. As with + python's C{os.chown} function, you must pass both arguments, so if you + only want to change one, use L{stat} first to retrieve the current + owner and group. + + @param uid: new owner's uid + @type uid: int + @param gid: new group id + @type gid: int + """ + self.sftp._log(DEBUG, 'chown(%s, %r, %r)' % (hexlify(self.handle), uid, gid)) + attr = SFTPAttributes() + attr.st_uid, attr.st_gid = uid, gid + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def utime(self, times): + """ + Set the access and modified times of this file. If + C{times} is C{None}, then the file's access and modified times are set + to the current time. Otherwise, C{times} must be a 2-tuple of numbers, + of the form C{(atime, mtime)}, which is used to set the access and + modified times, respectively. This bizarre API is mimicked from python + for the sake of consistency -- I apologize. + + @param times: C{None} or a tuple of (access time, modified time) in + standard internet epoch time (seconds since 01 January 1970 GMT) + @type times: tuple(int) + """ + if times is None: + times = (time.time(), time.time()) + self.sftp._log(DEBUG, 'utime(%s, %r)' % (hexlify(self.handle), times)) + attr = SFTPAttributes() + attr.st_atime, attr.st_mtime = times + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def truncate(self, size): + """ + Change the size of this file. This usually extends + or shrinks the size of the file, just like the C{truncate()} method on + python file objects. + + @param size: the new size of the file + @type size: int or long + """ + self.sftp._log(DEBUG, 'truncate(%s, %r)' % (hexlify(self.handle), size)) + attr = SFTPAttributes() + attr.st_size = size + self.sftp._request(CMD_FSETSTAT, self.handle, attr) def check(self, hash_algorithm, offset=0, length=0, block_size=0): """ @@ -255,26 +373,60 @@ class SFTPFile (BufferedFile): dramatically improve the download speed by avoiding roundtrip latency. The file's contents are incrementally buffered in a background thread. + The prefetched data is stored in a buffer until read via the L{read} + method. Once data has been read, it's removed from the buffer. The + data may be read in a random order (using L{seek}); chunks of the + buffer that haven't been read will continue to be buffered. + @since: 1.5.1 """ size = self.stat().st_size # queue up async reads for the rest of the file - self._prefetching = True - self._prefetch_so_far = self._realpos - self._prefetch_size = size - self._prefetch_data = {} - t = threading.Thread(target=self._prefetch) - t.setDaemon(True) - t.start() - - def _prefetch(self): + chunks = [] n = self._realpos - size = self._prefetch_size while n < size: chunk = min(self.MAX_REQUEST_SIZE, size - n) - self.sftp._async_request(self, CMD_READ, self.handle, long(n), int(chunk)) + chunks.append((n, chunk)) n += chunk + if len(chunks) > 0: + self._start_prefetch(chunks) + + def readv(self, chunks): + """ + Read a set of blocks from the file by (offset, length). This is more + efficient than doing a series of L{seek} and L{read} calls, since the + prefetch machinery is used to retrieve all the requested blocks at + once. + + @param chunks: a list of (offset, length) tuples indicating which + sections of the file to read + @type chunks: list(tuple(long, int)) + @return: a list of blocks read, in the same order as in C{chunks} + @rtype: list(str) + + @since: 1.5.4 + """ + self.sftp._log(DEBUG, 'readv(%s, %r)' % (hexlify(self.handle), chunks)) + read_chunks = [] + for offset, size in chunks: + # don't fetch data that's already in the prefetch buffer + if self._data_in_prefetch_buffers(offset) or self._data_in_prefetch_requests(offset, size): + continue + + # break up anything larger than the max read size + while size > 0: + chunk_size = min(size, self.MAX_REQUEST_SIZE) + read_chunks.append((offset, chunk_size)) + offset += chunk_size + size -= chunk_size + + self._start_prefetch(read_chunks) + # now we can just devolve to a bunch of read()s :) + for x in chunks: + self.seek(x[0]) + yield self.read(x[1]) + ### internals... @@ -285,6 +437,21 @@ class SFTPFile (BufferedFile): except: return 0 + def _start_prefetch(self, chunks): + self._prefetching = True + self._prefetch_done = False + self._prefetch_reads.extend(chunks) + + t = threading.Thread(target=self._prefetch_thread, args=(chunks,)) + t.setDaemon(True) + t.start() + + def _prefetch_thread(self, chunks): + # do these read requests in a temporary thread because there may be + # a lot of them, so it may block. + for offset, length in chunks: + self.sftp._async_request(self, CMD_READ, self.handle, long(offset), int(length)) + def _async_response(self, t, msg): if t == CMD_STATUS: # save exception and re-raise it on next file operation @@ -296,8 +463,10 @@ class SFTPFile (BufferedFile): if t != CMD_DATA: raise SFTPError('Expected data') data = msg.get_string() - self._prefetch_data[self._prefetch_so_far] = data - self._prefetch_so_far += len(data) + offset, length = self._prefetch_reads.pop(0) + self._prefetch_data[offset] = data + if len(self._prefetch_reads) == 0: + self._prefetch_done = True def _check_exception(self): "if there's a saved exception, raise & clear it" diff --git a/paramiko/sftp_handle.py b/paramiko/sftp_handle.py index e1d93e9..e976f43 100644 --- a/paramiko/sftp_handle.py +++ b/paramiko/sftp_handle.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -35,7 +35,16 @@ class SFTPHandle (object): Server implementations can (and should) subclass SFTPHandle to implement features of a file handle, like L{stat} or L{chattr}. """ - def __init__(self): + def __init__(self, flags=0): + """ + Create a new file handle representing a local file being served over + SFTP. If C{flags} is passed in, it's used to determine if the file + is open in append mode. + + @param flags: optional flags as passed to L{SFTPServerInterface.open} + @type flags: int + """ + self.__flags = flags self.__name = None # only for handles to folders: self.__files = { } @@ -81,15 +90,16 @@ class SFTPHandle (object): @return: data read from the file, or an SFTP error code. @rtype: str """ - if not hasattr(self, 'readfile') or (self.readfile is None): + readfile = getattr(self, 'readfile', None) + if readfile is None: return SFTP_OP_UNSUPPORTED try: if self.__tell is None: - self.__tell = self.readfile.tell() + self.__tell = readfile.tell() if offset != self.__tell: - self.readfile.seek(offset) + readfile.seek(offset) self.__tell = offset - data = self.readfile.read(length) + data = readfile.read(length) except IOError, e: self.__tell = None return SFTPServer.convert_errno(e.errno) @@ -116,20 +126,24 @@ class SFTPHandle (object): @type data: str @return: an SFTP error code like L{SFTP_OK}. """ - if not hasattr(self, 'writefile') or (self.writefile is None): + writefile = getattr(self, 'writefile', None) + if writefile is None: return SFTP_OP_UNSUPPORTED try: - if self.__tell is None: - self.__tell = self.writefile.tell() - if offset != self.__tell: - self.writefile.seek(offset) - self.__tell = offset - self.writefile.write(data) - self.writefile.flush() + # in append mode, don't care about seeking + if (self.__flags & os.O_APPEND) == 0: + if self.__tell is None: + self.__tell = writefile.tell() + if offset != self.__tell: + writefile.seek(offset) + self.__tell = offset + writefile.write(data) + writefile.flush() except IOError, e: self.__tell = None return SFTPServer.convert_errno(e.errno) - self.__tell += len(data) + if self.__tell is not None: + self.__tell += len(data) return SFTP_OK def stat(self): diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py index 5905843..099ac12 100644 --- a/paramiko/sftp_server.py +++ b/paramiko/sftp_server.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -66,15 +66,21 @@ class SFTPServer (BaseSFTP, SubsystemHandler): BaseSFTP.__init__(self) SubsystemHandler.__init__(self, channel, name, server) transport = channel.get_transport() - self.logger = util.get_logger(transport.get_log_channel() + '.' + - channel.get_name() + '.sftp') + self.logger = util.get_logger(transport.get_log_channel() + '.sftp') self.ultra_debug = transport.get_hexdump() self.next_handle = 1 # map of handle-string to SFTPHandle for files & folders: self.file_table = { } self.folder_table = { } self.server = sftp_si(server, *largs, **kwargs) - + + def _log(self, level, msg): + if issubclass(type(msg), list): + for m in msg: + super(SFTPServer, self)._log(level, "[chan " + self.sock.get_name() + "] " + m) + else: + super(SFTPServer, self)._log(level, "[chan " + self.sock.get_name() + "] " + msg) + def start_subsystem(self, name, transport, channel): self.sock = channel self._log(DEBUG, 'Started sftp server on channel %s' % repr(channel)) @@ -92,10 +98,20 @@ class SFTPServer (BaseSFTP, SubsystemHandler): return msg = Message(data) request_number = msg.get_int() - self._process(t, request_number, msg) + try: + self._process(t, request_number, msg) + except Exception, e: + self._log(DEBUG, 'Exception in server processing: ' + str(e)) + self._log(DEBUG, util.tb_strings()) + # send some kind of failure message, at least + try: + self._send_status(request_number, SFTP_FAILURE) + except: + pass def finish_subsystem(self): self.server.session_ended() + super(SFTPServer, self).finish_subsystem() # close any file handles that were left open (so we can return them to the OS quickly) for f in self.file_table.itervalues(): f.close() @@ -118,7 +134,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): if e == errno.EACCES: # permission denied return SFTP_PERMISSION_DENIED - elif e == errno.ENOENT: + elif (e == errno.ENOENT) or (e == errno.ENOTDIR): # no such file return SFTP_NO_SUCH_FILE else: @@ -141,12 +157,16 @@ class SFTPServer (BaseSFTP, SubsystemHandler): @param attr: attributes to change. @type attr: L{SFTPAttributes} """ - if attr._flags & attr.FLAG_PERMISSIONS: - os.chmod(filename, attr.st_mode) - if attr._flags & attr.FLAG_UIDGID: - os.chown(filename, attr.st_uid, attr.st_gid) + if sys.platform != 'win32': + # mode operations are meaningless on win32 + if attr._flags & attr.FLAG_PERMISSIONS: + os.chmod(filename, attr.st_mode) + if attr._flags & attr.FLAG_UIDGID: + os.chown(filename, attr.st_uid, attr.st_gid) if attr._flags & attr.FLAG_AMTIME: os.utime(filename, (attr.st_atime, attr.st_mtime)) + if attr._flags & attr.FLAG_SIZE: + open(filename, 'w+').truncate(attr.st_size) set_file_attr = staticmethod(set_file_attr) @@ -184,8 +204,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler): def _send_status(self, request_number, code, desc=None): if desc is None: - desc = SFTP_DESC[code] - self._response(request_number, CMD_STATUS, code, desc) + try: + desc = SFTP_DESC[code] + except IndexError: + desc = 'Unknown' + # some clients expect a "langauge" tag at the end (but don't mind it being blank) + self._response(request_number, CMD_STATUS, code, desc, '') def _open_folder(self, request_number, path): resp = self.server.list_folder(path) @@ -222,7 +246,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): start = msg.get_int64() length = msg.get_int64() block_size = msg.get_int() - if not self.file_table.has_key(handle): + if handle not in self.file_table: self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') return f = self.file_table[handle] @@ -246,29 +270,29 @@ class SFTPServer (BaseSFTP, SubsystemHandler): self._send_status(request_number, SFTP_FAILURE, 'Block size too small') return - sum = '' + sum_out = '' offset = start while offset < start + length: blocklen = min(block_size, start + length - offset) # don't try to read more than about 64KB at a time chunklen = min(blocklen, 65536) count = 0 - hash = alg.new() + hash_obj = alg.new() while count < blocklen: data = f.read(offset, chunklen) if not type(data) is str: self._send_status(request_number, data, 'Unable to hash file') return - hash.update(data) + hash_obj.update(data) count += len(data) offset += count - sum += hash.digest() + sum_out += hash_obj.digest() msg = Message() msg.add_int(request_number) msg.add_string('check-file') msg.add_string(algname) - msg.add_bytes(sum) + msg.add_bytes(sum_out) self._send_packet(CMD_EXTENDED_REPLY, str(msg)) def _convert_pflags(self, pflags): @@ -298,11 +322,11 @@ class SFTPServer (BaseSFTP, SubsystemHandler): self._send_handle_response(request_number, self.server.open(path, flags, attr)) elif t == CMD_CLOSE: handle = msg.get_string() - if self.folder_table.has_key(handle): + if handle in self.folder_table: del self.folder_table[handle] self._send_status(request_number, SFTP_OK) return - if self.file_table.has_key(handle): + if handle in self.file_table: self.file_table[handle].close() del self.file_table[handle] self._send_status(request_number, SFTP_OK) @@ -312,7 +336,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): handle = msg.get_string() offset = msg.get_int64() length = msg.get_int() - if not self.file_table.has_key(handle): + if handle not in self.file_table: self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') return data = self.file_table[handle].read(offset, length) @@ -327,7 +351,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): handle = msg.get_string() offset = msg.get_int64() data = msg.get_string() - if not self.file_table.has_key(handle): + if handle not in self.file_table: self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') return self._send_status(request_number, self.file_table[handle].write(offset, data)) @@ -351,7 +375,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): return elif t == CMD_READDIR: handle = msg.get_string() - if not self.folder_table.has_key(handle): + if handle not in self.folder_table: self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') return folder = self.folder_table[handle] @@ -372,7 +396,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): self._send_status(request_number, resp) elif t == CMD_FSTAT: handle = msg.get_string() - if not self.file_table.has_key(handle): + if handle not in self.file_table: self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') return resp = self.file_table[handle].stat() @@ -387,7 +411,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): elif t == CMD_FSETSTAT: handle = msg.get_string() attr = SFTPAttributes._from_msg(msg) - if not self.file_table.has_key(handle): + if handle not in self.file_table: self._response(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') return self._send_status(request_number, self.file_table[handle].chattr(attr)) @@ -412,7 +436,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): if tag == 'check-file': self._check_file(request_number, msg) else: - send._send_status(request_number, SFTP_OP_UNSUPPORTED) + self._send_status(request_number, SFTP_OP_UNSUPPORTED) else: self._send_status(request_number, SFTP_OP_UNSUPPORTED) diff --git a/paramiko/sftp_si.py b/paramiko/sftp_si.py index 16005d4..47dd25d 100644 --- a/paramiko/sftp_si.py +++ b/paramiko/sftp_si.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -36,6 +36,9 @@ class SFTPServerInterface (object): SFTP sessions). However, raising an exception will usually cause the SFTP session to abruptly end, so you will usually want to catch exceptions and return an appropriate error code. + + All paths are in string form instead of unicode because not all SFTP + clients & servers obey the requirement that paths be encoded in UTF-8. """ def __init__ (self, server, *largs, **kwargs): @@ -268,9 +271,13 @@ class SFTPServerInterface (object): The default implementation returns C{os.path.normpath('/' + path)}. """ if os.path.isabs(path): - return os.path.normpath(path) + out = os.path.normpath(path) else: - return os.path.normpath('/' + path) + out = os.path.normpath('/' + path) + if sys.platform == 'win32': + # on windows, normalize backslashes to sftp/posix format + out = out.replace('\\', '/') + return out def readlink(self, path): """ diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py index 900d4a0..e3120bb 100644 --- a/paramiko/ssh_exception.py +++ b/paramiko/ssh_exception.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -28,14 +28,25 @@ class SSHException (Exception): pass -class PasswordRequiredException (SSHException): +class AuthenticationException (SSHException): + """ + Exception raised when authentication failed for some reason. It may be + possible to retry with different credentials. (Other classes specify more + specific reasons.) + + @since: 1.6 + """ + pass + + +class PasswordRequiredException (AuthenticationException): """ Exception raised when a password is needed to unlock a private key file. """ pass -class BadAuthenticationType (SSHException): +class BadAuthenticationType (AuthenticationException): """ Exception raised when an authentication type (like password) is used, but the server isn't allowing that type. (It may only allow public-key, for @@ -51,19 +62,54 @@ class BadAuthenticationType (SSHException): allowed_types = [] def __init__(self, explanation, types): - SSHException.__init__(self, explanation) + AuthenticationException.__init__(self, explanation) self.allowed_types = types def __str__(self): return SSHException.__str__(self) + ' (allowed_types=%r)' % self.allowed_types -class PartialAuthentication (SSHException): +class PartialAuthentication (AuthenticationException): """ An internal exception thrown in the case of partial authentication. """ allowed_types = [] def __init__(self, types): - SSHException.__init__(self, 'partial authentication') + AuthenticationException.__init__(self, 'partial authentication') self.allowed_types = types + + +class ChannelException (SSHException): + """ + Exception raised when an attempt to open a new L{Channel} fails. + + @ivar code: the error code returned by the server + @type code: int + + @since: 1.6 + """ + def __init__(self, code, text): + SSHException.__init__(self, text) + self.code = code + + +class BadHostKeyException (SSHException): + """ + The host key given by the SSH server did not match what we were expecting. + + @ivar hostname: the hostname of the SSH server + @type hostname: str + @ivar key: the host key presented by the server + @type key: L{PKey} + @ivar expected_key: the host key expected + @type expected_key: L{PKey} + + @since: 1.6 + """ + def __init__(self, hostname, got_key, expected_key): + SSHException.__init__(self, 'Host key for server %s does not match!' % hostname) + self.hostname = hostname + self.key = got_key + self.expected_key = expected_key + diff --git a/paramiko/transport.py b/paramiko/transport.py index 8714a96..a18e05b 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -30,19 +30,20 @@ import time import weakref from paramiko import util +from paramiko.auth_handler import AuthHandler +from paramiko.channel import Channel from paramiko.common import * from paramiko.compress import ZlibCompressor, ZlibDecompressor -from paramiko.ssh_exception import SSHException, BadAuthenticationType -from paramiko.message import Message -from paramiko.channel import Channel -from paramiko.sftp_client import SFTPClient -from paramiko.packet import Packetizer, NeedRekeyException -from paramiko.rsakey import RSAKey from paramiko.dsskey import DSSKey -from paramiko.kex_group1 import KexGroup1 from paramiko.kex_gex import KexGex +from paramiko.kex_group1 import KexGroup1 +from paramiko.message import Message +from paramiko.packet import Packetizer, NeedRekeyException from paramiko.primes import ModulusPack -from paramiko.auth_handler import AuthHandler +from paramiko.rsakey import RSAKey +from paramiko.server import ServerInterface +from paramiko.sftp_client import SFTPClient +from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException # these come from PyCrypt # http://www.amk.ca/python/writing/pycrypt/ @@ -50,7 +51,7 @@ from paramiko.auth_handler import AuthHandler # PyCrypt compiled for Win32 can be downloaded from the HashTar homepage: # http://nitace.bsd.uchicago.edu:8080/hashtar from Crypto.Cipher import Blowfish, AES, DES3 -from Crypto.Hash import SHA, MD5, HMAC +from Crypto.Hash import SHA, MD5 # for thread cleanup @@ -73,8 +74,6 @@ class SecurityOptions (object): If you try to add an algorithm that paramiko doesn't recognize, C{ValueError} will be raised. If you try to assign something besides a tuple to one of the fields, C{TypeError} will be raised. - - @since: ivysaur """ __slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ] @@ -110,7 +109,8 @@ class SecurityOptions (object): if type(x) is not tuple: raise TypeError('expected tuple or list') possible = getattr(self._transport, orig).keys() - if len(filter(lambda n: n not in possible, x)) > 0: + forbidden = filter(lambda n: n not in possible, x) + if len(forbidden) > 0: raise ValueError('unknown cipher') setattr(self._transport, name, x) @@ -140,6 +140,51 @@ class SecurityOptions (object): "Compression algorithms") +class ChannelMap (object): + def __init__(self): + # (id -> Channel) + self._map = weakref.WeakValueDictionary() + self._lock = threading.Lock() + + def put(self, chanid, chan): + self._lock.acquire() + try: + self._map[chanid] = chan + finally: + self._lock.release() + + def get(self, chanid): + self._lock.acquire() + try: + return self._map.get(chanid, None) + finally: + self._lock.release() + + def delete(self, chanid): + self._lock.acquire() + try: + try: + del self._map[chanid] + except KeyError: + pass + finally: + self._lock.release() + + def values(self): + self._lock.acquire() + try: + return self._map.values() + finally: + self._lock.release() + + def __len__(self): + self._lock.acquire() + try: + return len(self._map) + finally: + self._lock.release() + + class Transport (threading.Thread): """ An SSH Transport attaches to a stream (usually a socket), negotiates an @@ -149,7 +194,7 @@ class Transport (threading.Thread): """ _PROTO_ID = '2.0' - _CLIENT_ID = 'paramiko_1.5.2' + _CLIENT_ID = 'paramiko_1.7.4' _preferred_ciphers = ( 'aes128-cbc', 'blowfish-cbc', 'aes256-cbc', '3des-cbc' ) _preferred_macs = ( 'hmac-sha1', 'hmac-md5', 'hmac-sha1-96', 'hmac-md5-96' ) @@ -245,25 +290,41 @@ class Transport (threading.Thread): self.sock.settimeout(0.1) except AttributeError: pass + # negotiated crypto parameters self.packetizer = Packetizer(sock) self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID self.remote_version = '' self.local_cipher = self.remote_cipher = '' self.local_kex_init = self.remote_kex_init = None + self.local_mac = self.remote_mac = None + self.local_compression = self.remote_compression = None self.session_id = None - # /negotiated crypto parameters - self.expected_packet = 0 + self.host_key_type = None + self.host_key = None + + # state used during negotiation + self.kex_engine = None + self.H = None + self.K = None + self.active = False self.initial_kex_done = False self.in_kex = False + self.authenticated = False + self._expected_packet = tuple() self.lock = threading.Lock() # synchronization (always higher level than write_lock) - self.channels = weakref.WeakValueDictionary() # (id -> Channel) + + # tracking open channels + self._channels = ChannelMap() self.channel_events = { } # (id -> Event) self.channels_seen = { } # (id -> True) - self.channel_counter = 1 + self._channel_counter = 1 self.window_size = 65536 self.max_packet_size = 34816 + self._x11_handler = None + self._tcp_handler = None + self.saved_exception = None self.clear_to_send = threading.Event() self.clear_to_send_lock = threading.Lock() @@ -271,9 +332,10 @@ class Transport (threading.Thread): self.logger = util.get_logger(self.log_name) self.packetizer.set_log(self.logger) self.auth_handler = None - self.authenticated = False - # user-defined event callbacks: - self.completion_event = None + self.global_response = None # response Message from an arbitrary global request + self.completion_event = None # user-defined event callbacks + self.banner_timeout = 15 # how long (seconds) to wait for the SSH banner + # server mode: self.server_mode = False self.server_object = None @@ -282,9 +344,6 @@ class Transport (threading.Thread): self.server_accept_cv = threading.Condition(self.lock) self.subsystem_table = { } - def __del__(self): - self.close() - def __repr__(self): """ Returns a string representation of this object, for debugging. @@ -299,16 +358,26 @@ class Transport (threading.Thread): out += ' (cipher %s, %d bits)' % (self.local_cipher, self._cipher_info[self.local_cipher]['key-size'] * 8) if self.is_authenticated(): - if len(self.channels) == 1: - out += ' (active; 1 open channel)' - else: - out += ' (active; %d open channels)' % len(self.channels) + out += ' (active; %d open channel(s))' % len(self._channels) elif self.initial_kex_done: out += ' (connected; awaiting auth)' else: out += ' (connecting)' out += '>' return out + + def atfork(self): + """ + Terminate this Transport without closing the session. On posix + systems, if a Transport is open during process forking, both parent + and child will share the underlying socket, but only one process can + use the connection (without corrupting the session). Use this method + to clean up a Transport object without disrupting the other process. + + @since: 1.5.3 + """ + self.sock.close() + self.close() def get_security_options(self): """ @@ -319,8 +388,6 @@ class Transport (threading.Thread): @return: an object that can be used to change the preferred algorithms for encryption, digest (hash), public key, and key exchange. @rtype: L{SecurityOptions} - - @since: ivysaur """ return SecurityOptions(self) @@ -471,7 +538,8 @@ class Transport (threading.Thread): try: return self.server_key_dict[self.host_key_type] except KeyError: - return None + pass + return None def load_server_moduli(filename=None): """ @@ -496,8 +564,6 @@ class Transport (threading.Thread): @return: True if a moduli file was successfully loaded; False otherwise. @rtype: bool - - @since: doduo @note: This has no effect when used in client mode. """ @@ -521,14 +587,13 @@ class Transport (threading.Thread): """ Close this session, and any open channels that are tied to it. """ + if not self.active: + return self.active = False - # since this may be called from __del__, can't assume any attributes exist - try: - self.packetizer.close() - for chan in self.channels.values(): - chan._unlink() - except AttributeError: - pass + self.packetizer.close() + self.join() + for chan in self._channels.values(): + chan._unlink() def get_remote_server_key(self): """ @@ -541,7 +606,7 @@ class Transport (threading.Thread): @raise SSHException: if no session is currently active. - @return: public key of the remote server. + @return: public key of the remote server @rtype: L{PKey <pkey.PKey>} """ if (not self.active) or (not self.initial_kex_done): @@ -553,7 +618,7 @@ class Transport (threading.Thread): Return true if this session is active (open). @return: True if the session is still active (open); False if the - session is closed. + session is closed @rtype: bool """ return self.active @@ -563,12 +628,43 @@ class Transport (threading.Thread): Request a new channel to the server, of type C{"session"}. This is just an alias for C{open_channel('session')}. - @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + @return: a new L{Channel} @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely """ return self.open_channel('session') + def open_x11_channel(self, src_addr=None): + """ + Request a new channel to the client, of type C{"x11"}. This + is just an alias for C{open_channel('x11', src_addr=src_addr)}. + + @param src_addr: the source address of the x11 server (port is the + x11 port, ie. 6010) + @type src_addr: (str, int) + @return: a new L{Channel} + @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely + """ + return self.open_channel('x11', src_addr=src_addr) + + def open_forwarded_tcpip_channel(self, (src_addr, src_port), (dest_addr, dest_port)): + """ + Request a new channel back to the client, of type C{"forwarded-tcpip"}. + This is used after a client has requested port forwarding, for sending + incoming connections back to the client. + + @param src_addr: originator's address + @param src_port: originator's port + @param dest_addr: local (server) connected address + @param dest_port: local (server) connected port + """ + return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port)) + def open_channel(self, kind, dest_addr=None, src_addr=None): """ Request a new channel to the server. L{Channel}s are socket-like @@ -577,18 +673,20 @@ class Transport (threading.Thread): L{connect} or L{start_client}) and authenticating. @param kind: the kind of channel requested (usually C{"session"}, - C{"forwarded-tcpip"} or C{"direct-tcpip"}). + C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"}) @type kind: str @param dest_addr: the destination address of this port forwarding, if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored - for other channel types). + for other channel types) @type dest_addr: (str, int) @param src_addr: the source address of this port forwarding, if - C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}. + C{kind} is C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"} @type src_addr: (str, int) - @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + @return: a new L{Channel} on success @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely """ chan = None if not self.active: @@ -596,11 +694,7 @@ class Transport (threading.Thread): return None self.lock.acquire() try: - chanid = self.channel_counter - while self.channels.has_key(chanid): - self.channel_counter = (self.channel_counter + 1) & 0xffffff - chanid = self.channel_counter - self.channel_counter = (self.channel_counter + 1) & 0xffffff + chanid = self._next_channel() m = Message() m.add_byte(chr(MSG_CHANNEL_OPEN)) m.add_string(kind) @@ -612,7 +706,11 @@ class Transport (threading.Thread): m.add_int(dest_addr[1]) m.add_string(src_addr[0]) m.add_int(src_addr[1]) - self.channels[chanid] = chan = Channel(chanid) + elif kind == 'x11': + m.add_string(src_addr[0]) + m.add_int(src_addr[1]) + chan = Channel(chanid) + self._channels.put(chanid, chan) self.channel_events[chanid] = event = threading.Event() self.channels_seen[chanid] = True chan._set_transport(self) @@ -620,20 +718,84 @@ class Transport (threading.Thread): finally: self.lock.release() self._send_user_message(m) - while 1: + while True: event.wait(0.1); if not self.active: - return None + e = self.get_exception() + if e is None: + e = SSHException('Unable to open channel.') + raise e if event.isSet(): break - try: - self.lock.acquire() - if not self.channels.has_key(chanid): - chan = None - finally: - self.lock.release() - return chan - + chan = self._channels.get(chanid) + if chan is not None: + return chan + e = self.get_exception() + if e is None: + e = SSHException('Unable to open channel.') + raise e + + def request_port_forward(self, address, port, handler=None): + """ + Ask the server to forward TCP connections from a listening port on + the server, across this SSH session. + + If a handler is given, that handler is called from a different thread + whenever a forwarded connection arrives. The handler parameters are:: + + handler(channel, (origin_addr, origin_port), (server_addr, server_port)) + + where C{server_addr} and C{server_port} are the address and port that + the server was listening on. + + If no handler is set, the default behavior is to send new incoming + forwarded connections into the accept queue, to be picked up via + L{accept}. + + @param address: the address to bind when forwarding + @type address: str + @param port: the port to forward, or 0 to ask the server to allocate + any port + @type port: int + @param handler: optional handler for incoming forwarded connections + @type handler: function(Channel, (str, int), (str, int)) + @return: the port # allocated by the server + @rtype: int + + @raise SSHException: if the server refused the TCP forward request + """ + if not self.active: + raise SSHException('SSH session not active') + address = str(address) + port = int(port) + response = self.global_request('tcpip-forward', (address, port), wait=True) + if response is None: + raise SSHException('TCP forwarding request denied') + if port == 0: + port = response.get_int() + if handler is None: + def default_handler(channel, (src_addr, src_port), (dest_addr, dest_port)): + self._queue_incoming_channel(channel) + handler = default_handler + self._tcp_handler = handler + return port + + def cancel_port_forward(self, address, port): + """ + Ask the server to cancel a previous port-forwarding request. No more + connections to the given address & port will be forwarded across this + ssh connection. + + @param address: the address to stop forwarding + @type address: str + @param port: the port to stop forwarding + @type port: int + """ + if not self.active: + return + self._tcp_handler = None + self.global_request('cancel-tcpip-forward', (address, port), wait=True) + def open_sftp_client(self): """ Create an SFTP client channel from an open transport. On success, @@ -656,8 +818,6 @@ class Transport (threading.Thread): @param bytes: the number of random bytes to send in the payload of the ignored packet -- defaults to a random number from 10 to 41. @type bytes: int - - @since: fearow """ m = Message() m.add_byte(chr(MSG_IGNORE)) @@ -674,22 +834,23 @@ class Transport (threading.Thread): bytes sent or received, but this method gives you the option of forcing new keys whenever you want. Negotiating new keys causes a pause in traffic both ways as the two sides swap keys and do computations. This - method returns when the session has switched to new keys, or the - session has died mid-negotiation. + method returns when the session has switched to new keys. - @return: True if the renegotiation was successful, and the link is - using new keys; False if the session dropped during renegotiation. - @rtype: bool + @raise SSHException: if the key renegotiation failed (which causes the + session to end) """ self.completion_event = threading.Event() self._send_kex_init() - while 1: - self.completion_event.wait(0.1); + while True: + self.completion_event.wait(0.1) if not self.active: - return False + e = self.get_exception() + if e is not None: + raise e + raise SSHException('Negotiation failed.') if self.completion_event.isSet(): break - return True + return def set_keepalive(self, interval): """ @@ -701,11 +862,9 @@ class Transport (threading.Thread): @param interval: seconds to wait before sending a keepalive packet (or 0 to disable keepalives). @type interval: int - - @since: fearow """ self.packetizer.set_keepalive(interval, - lambda x=self: x.global_request('keepalive@lag.net', wait=False)) + lambda x=weakref.proxy(self): x.global_request('keepalive@lag.net', wait=False)) def global_request(self, kind, data=None, wait=True): """ @@ -724,8 +883,6 @@ class Transport (threading.Thread): request was successful (or an empty L{Message} if C{wait} was C{False}); C{None} if the request was denied. @rtype: L{Message} - - @since: fearow """ if wait: self.completion_event = threading.Event() @@ -807,8 +964,6 @@ class Transport (threading.Thread): @raise SSHException: if the SSH2 negotiation fails, the host key supplied by the server is incorrect, or authentication fails. - - @since: doduo """ if hostkey is not None: self._preferred_keys = [ hostkey.get_name() ] @@ -896,8 +1051,6 @@ class Transport (threading.Thread): @return: username that was authenticated, or C{None}. @rtype: string - - @since: fearow """ if not self.active or (self.auth_handler is None): return None @@ -958,9 +1111,9 @@ class Transport (threading.Thread): step. Otherwise, in the normal case, an empty list is returned. @param username: the username to authenticate as - @type username: string + @type username: str @param password: the password to authenticate with - @type password: string + @type password: str or unicode @param event: an event to trigger when the authentication attempt is complete (whether it was successful or not) @type event: threading.Event @@ -974,8 +1127,9 @@ class Transport (threading.Thread): @raise BadAuthenticationType: if password authentication isn't allowed by the server for this user (and no event was passed in) - @raise SSHException: if the authentication failed (and no event was - passed in) + @raise AuthenticationException: if the authentication failed (and no + event was passed in) + @raise SSHException: if there was a network error """ if (not self.active) or (not self.initial_kex_done): # we should never try to send the password unless we're on a secure link @@ -993,7 +1147,7 @@ class Transport (threading.Thread): return self.auth_handler.wait_for_response(my_event) except BadAuthenticationType, x: # if password auth isn't allowed, but keyboard-interactive *is*, try to fudge it - if not fallback or not 'keyboard-interactive' in x.allowed_types: + if not fallback or ('keyboard-interactive' not in x.allowed_types): raise try: def handler(title, instructions, fields): @@ -1010,6 +1164,7 @@ class Transport (threading.Thread): except SSHException, ignored: # attempt failed; just raise the original exception raise x + return None def auth_publickey(self, username, key, event=None): """ @@ -1037,13 +1192,14 @@ class Transport (threading.Thread): complete (whether it was successful or not) @type event: threading.Event @return: list of auth types permissible for the next stage of - authentication (normally empty). + authentication (normally empty) @rtype: list @raise BadAuthenticationType: if public-key authentication isn't - allowed by the server for this user (and no event was passed in). - @raise SSHException: if the authentication failed (and no event was - passed in). + allowed by the server for this user (and no event was passed in) + @raise AuthenticationException: if the authentication failed (and no + event was passed in) + @raise SSHException: if there was a network error """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link @@ -1100,7 +1256,8 @@ class Transport (threading.Thread): @raise BadAuthenticationType: if public-key authentication isn't allowed by the server for this user - @raise SSHException: if the authentication failed + @raise AuthenticationException: if the authentication failed + @raise SSHException: if there was a network error @since: 1.5 """ @@ -1119,13 +1276,14 @@ class Transport (threading.Thread): (See the C{logging} module for more info.) SSH Channels will log to a sub-channel of the one specified. - @param name: new channel name for logging. + @param name: new channel name for logging @type name: str @since: 1.1 """ self.log_name = name self.logger = util.get_logger(name) + self.packetizer.set_log(self.logger) def get_log_channel(self): """ @@ -1166,8 +1324,7 @@ class Transport (threading.Thread): """ Turn on/off compression. This will only have an affect before starting the transport (ie before calling L{connect}, etc). By default, - compression is off since it negatively affects interactive sessions - and is not fully tested. + compression is off since it negatively affects interactive sessions. @param compress: C{True} to ask the remote client/server to compress traffic; C{False} to refuse compression @@ -1179,6 +1336,21 @@ class Transport (threading.Thread): self._preferred_compression = ( 'zlib@openssh.com', 'zlib', 'none' ) else: self._preferred_compression = ( 'none', ) + + def getpeername(self): + """ + Return the address of the remote side of this Transport, if possible. + This is effectively a wrapper around C{'getpeername'} on the underlying + socket. If the socket-like object has no C{'getpeername'} method, + then C{("unknown", 0)} is returned. + + @return: the address if the remote host, if known + @rtype: tuple(str, int) + """ + gp = getattr(self.sock, 'getpeername', None) + if gp is None: + return ('unknown', 0) + return gp() def stop_thread(self): self.active = False @@ -1188,25 +1360,29 @@ class Transport (threading.Thread): ### internals... - def _log(self, level, msg): + def _log(self, level, msg, *args): if issubclass(type(msg), list): for m in msg: self.logger.log(level, m) else: - self.logger.log(level, msg) + self.logger.log(level, msg, *args) def _get_modulus_pack(self): "used by KexGex to find primes for group exchange" return self._modulus_pack + def _next_channel(self): + "you are holding the lock" + chanid = self._channel_counter + while self._channels.get(chanid) is not None: + self._channel_counter = (self._channel_counter + 1) & 0xffffff + chanid = self._channel_counter + self._channel_counter = (self._channel_counter + 1) & 0xffffff + return chanid + def _unlink_channel(self, chanid): "used by a Channel to remove itself from the active channel list" - try: - self.lock.acquire() - if self.channels.has_key(chanid): - del self.channels[chanid] - finally: - self.lock.release() + self._channels.delete(chanid) def _send_message(self, data): self.packetizer.send_message(data) @@ -1237,9 +1413,9 @@ class Transport (threading.Thread): if self.session_id == None: self.session_id = h - def _expect_packet(self, type): + def _expect_packet(self, *ptypes): "used by a kex object to register the next packet type it expects to see" - self.expected_packet = type + self._expected_packet = tuple(ptypes) def _verify_key(self, host_key, sig): key = self._key_info[self.host_key_type](Message(host_key)) @@ -1262,16 +1438,34 @@ class Transport (threading.Thread): m.add_mpint(self.K) m.add_bytes(self.H) m.add_bytes(sofar) - hash = SHA.new(str(m)).digest() - out += hash - sofar += hash + digest = SHA.new(str(m)).digest() + out += digest + sofar += digest return out[:nbytes] def _get_cipher(self, name, key, iv): - if not self._cipher_info.has_key(name): + if name not in self._cipher_info: raise SSHException('Unknown client cipher ' + name) return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv) + def _set_x11_handler(self, handler): + # only called if a channel has turned on x11 forwarding + if handler is None: + # by default, use the same mechanism as accept() + def default_handler(channel, (src_addr, src_port)): + self._queue_incoming_channel(channel) + self._x11_handler = default_handler + else: + self._x11_handler = handler + + def _queue_incoming_channel(self, channel): + self.lock.acquire() + try: + self.server_accepts.append(channel) + self.server_accept_cv.notify() + finally: + self.lock.release() + def run(self): # (use the exposed "run" method, because if we specify a thread target # of a private method, threading.Thread will keep a reference to it @@ -1288,7 +1482,7 @@ class Transport (threading.Thread): self.packetizer.write_all(self.local_version + '\r\n') self._check_banner() self._send_kex_init() - self.expected_packet = MSG_KEXINIT + self._expect_packet(MSG_KEXINIT) while self.active: if self.packetizer.need_rekey() and not self.in_kex: @@ -1307,27 +1501,28 @@ class Transport (threading.Thread): elif ptype == MSG_DEBUG: self._parse_debug(m) continue - if self.expected_packet != 0: - if ptype != self.expected_packet: - raise SSHException('Expecting packet %d, got %d' % (self.expected_packet, ptype)) - self.expected_packet = 0 + if len(self._expected_packet) > 0: + if ptype not in self._expected_packet: + raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype)) + self._expected_packet = tuple() if (ptype >= 30) and (ptype <= 39): self.kex_engine.parse_next(ptype, m) continue - if self._handler_table.has_key(ptype): + if ptype in self._handler_table: self._handler_table[ptype](self, m) - elif self._channel_handler_table.has_key(ptype): + elif ptype in self._channel_handler_table: chanid = m.get_int() - if self.channels.has_key(chanid): - self._channel_handler_table[ptype](self.channels[chanid], m) - elif self.channels_seen.has_key(chanid): + chan = self._channels.get(chanid) + if chan is not None: + self._channel_handler_table[ptype](chan, m) + elif chanid in self.channels_seen: self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid) else: self._log(ERROR, 'Channel request for unknown channel %d' % chanid) self.active = False self.packetizer.close() - elif (self.auth_handler is not None) and self.auth_handler._handler_table.has_key(ptype): + elif (self.auth_handler is not None) and (ptype in self.auth_handler._handler_table): self.auth_handler._handler_table[ptype](self.auth_handler, m) else: self._log(WARNING, 'Oops, unhandled type %d' % ptype) @@ -1355,7 +1550,7 @@ class Transport (threading.Thread): self._log(ERROR, util.tb_strings()) self.saved_exception = e _active_threads.remove(self) - for chan in self.channels.values(): + for chan in self._channels.values(): chan._unlink() if self.active: self.active = False @@ -1366,6 +1561,11 @@ class Transport (threading.Thread): self.auth_handler.abort() for event in self.channel_events.values(): event.set() + try: + self.lock.acquire() + self.server_accept_cv.notify() + finally: + self.lock.release() self.sock.close() @@ -1388,30 +1588,31 @@ class Transport (threading.Thread): def _check_banner(self): # this is slow, but we only have to do it once for i in range(5): - # give them 5 seconds for the first line, then just 2 seconds each additional line + # give them 15 seconds for the first line, then just 2 seconds + # each additional line. (some sites have very high latency.) if i == 0: - timeout = 5 + timeout = self.banner_timeout else: timeout = 2 try: - buffer = self.packetizer.readline(timeout) + buf = self.packetizer.readline(timeout) except Exception, x: raise SSHException('Error reading SSH protocol banner' + str(x)) - if buffer[:4] == 'SSH-': + if buf[:4] == 'SSH-': break - self._log(DEBUG, 'Banner: ' + buffer) - if buffer[:4] != 'SSH-': - raise SSHException('Indecipherable protocol version "' + buffer + '"') + self._log(DEBUG, 'Banner: ' + buf) + if buf[:4] != 'SSH-': + raise SSHException('Indecipherable protocol version "' + buf + '"') # save this server version string for later - self.remote_version = buffer + self.remote_version = buf # pull off any attached comment comment = '' - i = string.find(buffer, ' ') + i = string.find(buf, ' ') if i >= 0: - comment = buffer[i+1:] - buffer = buffer[:i] + comment = buf[i+1:] + buf = buf[:i] # parse out version string and make sure it matches - segs = buffer.split('-', 2) + segs = buf.split('-', 2) if len(segs) < 3: raise SSHException('Invalid SSH banner') version = segs[1] @@ -1612,7 +1813,7 @@ class Transport (threading.Thread): if not self.packetizer.need_rekey(): self.in_kex = False # we always expect to receive NEWKEYS now - self.expected_packet = MSG_NEWKEYS + self._expect_packet(MSG_NEWKEYS) def _auth_trigger(self): self.authenticated = True @@ -1661,7 +1862,22 @@ class Transport (threading.Thread): kind = m.get_string() self._log(DEBUG, 'Received global request "%s"' % kind) want_reply = m.get_boolean() - ok = self.server_object.check_global_request(kind, m) + if not self.server_mode: + self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind) + ok = False + elif kind == 'tcpip-forward': + address = m.get_string() + port = m.get_int() + ok = self.server_object.check_port_forward_request(address, port) + if ok != False: + ok = (ok,) + elif kind == 'cancel-tcpip-forward': + address = m.get_string() + port = m.get_int() + self.server_object.cancel_port_forward_request(address, port) + ok = True + else: + ok = self.server_object.check_global_request(kind, m) extra = () if type(ok) is tuple: extra = ok @@ -1692,15 +1908,15 @@ class Transport (threading.Thread): server_chanid = m.get_int() server_window_size = m.get_int() server_max_packet_size = m.get_int() - if not self.channels.has_key(chanid): + chan = self._channels.get(chanid) + if chan is None: self._log(WARNING, 'Success for unrequested channel! [??]') return self.lock.acquire() try: - chan = self.channels[chanid] chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size) self._log(INFO, 'Secsh channel %d opened.' % chanid) - if self.channel_events.has_key(chanid): + if chanid in self.channel_events: self.channel_events[chanid].set() del self.channel_events[chanid] finally: @@ -1712,16 +1928,14 @@ class Transport (threading.Thread): reason = m.get_int() reason_str = m.get_string() lang = m.get_string() - if CONNECTION_FAILED_CODE.has_key(reason): - reason_text = CONNECTION_FAILED_CODE[reason] - else: - reason_text = '(unknown code)' + reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)') self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) + self.lock.acquire() try: - self.lock.aquire() - if self.channels.has_key(chanid): - del self.channels[chanid] - if self.channel_events.has_key(chanid): + self.saved_exception = ChannelException(reason, reason_text) + if chanid in self.channel_events: + self._channels.delete(chanid) + if chanid in self.channel_events: self.channel_events[chanid].set() del self.channel_events[chanid] finally: @@ -1734,21 +1948,47 @@ class Transport (threading.Thread): initial_window_size = m.get_int() max_packet_size = m.get_int() reject = False - if not self.server_mode: + if (kind == 'x11') and (self._x11_handler is not None): + origin_addr = m.get_string() + origin_port = m.get_int() + self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port)) + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None): + server_addr = m.get_string() + server_port = m.get_int() + origin_addr = m.get_string() + origin_port = m.get_int() + self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port)) + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + elif not self.server_mode: self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) reject = True reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: self.lock.acquire() try: - my_chanid = self.channel_counter - while self.channels.has_key(my_chanid): - self.channel_counter = (self.channel_counter + 1) & 0xffffff - my_chanid = self.channel_counter - self.channel_counter = (self.channel_counter + 1) & 0xffffff + my_chanid = self._next_channel() finally: self.lock.release() - reason = self.server_object.check_channel_request(kind, my_chanid) + if kind == 'direct-tcpip': + # handle direct-tcpip requests comming from the client + dest_addr = m.get_string() + dest_port = m.get_int() + origin_addr = m.get_string() + origin_port = m.get_int() + reason = self.server_object.check_channel_direct_tcpip_request( + my_chanid, (origin_addr, origin_port), + (dest_addr, dest_port)) + else: + reason = self.server_object.check_channel_request(kind, my_chanid) if reason != OPEN_SUCCEEDED: self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) reject = True @@ -1761,10 +2001,11 @@ class Transport (threading.Thread): msg.add_string('en') self._send_message(msg) return + chan = Channel(my_chanid) + self.lock.acquire() try: - self.lock.acquire() - self.channels[my_chanid] = chan + self._channels.put(my_chanid, chan) self.channels_seen[my_chanid] = True chan._set_transport(self) chan._set_window(self.window_size, self.max_packet_size) @@ -1778,13 +2019,14 @@ class Transport (threading.Thread): m.add_int(self.window_size) m.add_int(self.max_packet_size) self._send_message(m) - self._log(INFO, 'Secsh channel %d opened.' % my_chanid) - try: - self.lock.acquire() - self.server_accepts.append(chan) - self.server_accept_cv.notify() - finally: - self.lock.release() + self._log(INFO, 'Secsh channel %d (%s) opened.', my_chanid, kind) + if kind == 'x11': + self._x11_handler(chan, (origin_addr, origin_port)) + elif kind == 'forwarded-tcpip': + chan.origin_addr = (origin_addr, origin_port) + self._tcp_handler(chan, (origin_addr, origin_port), (server_addr, server_port)) + else: + self._queue_incoming_channel(chan) def _parse_debug(self, m): always_display = m.get_boolean() @@ -1795,7 +2037,7 @@ class Transport (threading.Thread): def _get_subsystem_handler(self, name): try: self.lock.acquire() - if not self.subsystem_table.has_key(name): + if name not in self.subsystem_table: return (None, [], {}) return self.subsystem_table[name] finally: diff --git a/paramiko/util.py b/paramiko/util.py index abab825..8abdc0c 100644 --- a/paramiko/util.py +++ b/paramiko/util.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -22,13 +22,14 @@ Useful functions used by the rest of paramiko. from __future__ import generators -import fnmatch +from binascii import hexlify, unhexlify import sys import struct import traceback import threading from paramiko.common import * +from paramiko.config import SSHConfig # Change by RogerB - python < 2.3 doesn't have enumerate so we implement it @@ -115,12 +116,10 @@ def format_binary_line(data): return '%-50s %s' % (left, right) def hexify(s): - "turn a string into a hex sequence" - return ''.join(['%02X' % ord(c) for c in s]) + return hexlify(s).upper() def unhexify(s): - "turn a hex sequence back into a string" - return ''.join([chr(int(s[i:i+2], 16)) for i in range(0, len(s), 2)]) + return unhexlify(s) def safe_string(s): out = '' @@ -168,12 +167,12 @@ def generate_key_bytes(hashclass, salt, key, nbytes): if len(salt) > 8: salt = salt[:8] while nbytes > 0: - hash = hashclass.new() + hash_obj = hashclass.new() if len(digest) > 0: - hash.update(digest) - hash.update(key) - hash.update(salt) - digest = hash.digest() + hash_obj.update(digest) + hash_obj.update(key) + hash_obj.update(salt) + digest = hash_obj.digest() size = min(nbytes, len(digest)) keydata += digest[:size] nbytes -= size @@ -189,117 +188,29 @@ def load_host_keys(filename): This type of file unfortunately doesn't exist on Windows, but on posix, it will usually be stored in C{os.path.expanduser("~/.ssh/known_hosts")}. + Since 1.5.3, this is just a wrapper around L{HostKeys}. + @param filename: name of the file to read host keys from @type filename: str @return: dict of host keys, indexed by hostname and then keytype @rtype: dict(hostname, dict(keytype, L{PKey <paramiko.pkey.PKey>})) """ - import base64 - from rsakey import RSAKey - from dsskey import DSSKey - - keys = {} - f = file(filename, 'r') - for line in f: - line = line.strip() - if (len(line) == 0) or (line[0] == '#'): - continue - keylist = line.split(' ') - if len(keylist) != 3: - continue - hostlist, keytype, key = keylist - hosts = hostlist.split(',') - for host in hosts: - if not keys.has_key(host): - keys[host] = {} - if keytype == 'ssh-rsa': - keys[host][keytype] = RSAKey(data=base64.decodestring(key)) - elif keytype == 'ssh-dss': - keys[host][keytype] = DSSKey(data=base64.decodestring(key)) - f.close() - return keys + from paramiko.hostkeys import HostKeys + return HostKeys(filename) def parse_ssh_config(file_obj): """ - Parse a config file of the format used by OpenSSH, and return an object - that can be used to make queries to L{lookup_ssh_host_config}. The - format is described in OpenSSH's C{ssh_config} man page. This method is - provided primarily as a convenience to posix users (since the OpenSSH - format is a de-facto standard on posix) but should work fine on Windows - too. - - The return value is currently a list of dictionaries, each containing - host-specific configuration, but this is considered an implementation - detail and may be subject to change in later versions. - - @param file_obj: a file-like object to read the config file from - @type file_obj: file - @return: opaque configuration object - @rtype: object + Provided only as a backward-compatible wrapper around L{SSHConfig}. """ - ret = [] - config = { 'host': '*' } - ret.append(config) - - for line in file_obj: - line = line.rstrip('\n').lstrip() - if (line == '') or (line[0] == '#'): - continue - if '=' in line: - key, value = line.split('=', 1) - key = key.strip().lower() - else: - # find first whitespace, and split there - i = 0 - while (i < len(line)) and not line[i].isspace(): - i += 1 - if i == len(line): - raise Exception('Unparsable line: %r' % line) - key = line[:i].lower() - value = line[i:].lstrip() - - if key == 'host': - # do we have a pre-existing host config to append to? - matches = [c for c in ret if c['host'] == value] - if len(matches) > 0: - config = matches[0] - else: - config = { 'host': value } - ret.append(config) - else: - config[key] = value - - return ret + config = SSHConfig() + config.parse(file_obj) + return config def lookup_ssh_host_config(hostname, config): """ - Return a dict of config options for a given hostname. The C{config} object - must come from L{parse_ssh_config}. - - The host-matching rules of OpenSSH's C{ssh_config} man page are used, which - means that all configuration options from matching host specifications are - merged, with more specific hostmasks taking precedence. In other words, if - C{"Port"} is set under C{"Host *"} and also C{"Host *.example.com"}, and - the lookup is for C{"ssh.example.com"}, then the port entry for - C{"Host *.example.com"} will win out. - - The keys in the returned dict are all normalized to lowercase (look for - C{"port"}, not C{"Port"}. No other processing is done to the keys or - values. - - @param hostname: the hostname to lookup - @type hostname: str - @param config: the config object to search - @type config: object + Provided only as a backward-compatible wrapper around L{SSHConfig}. """ - matches = [x for x in config if fnmatch.fnmatch(hostname, x['host'])] - # sort in order of shortest match (usually '*') to longest - matches.sort(key=lambda x: len(x['host'])) - ret = {} - for m in matches: - ret.update(m) - del ret['host'] - return ret + return config.lookup(hostname) def mod_inverse(x, m): # it's crazy how small python can make this function. @@ -355,3 +266,5 @@ def get_logger(name): l = logging.getLogger(name) l.addFilter(_pfilter) return l + + diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py new file mode 100644 index 0000000..787032b --- /dev/null +++ b/paramiko/win_pageant.py @@ -0,0 +1,148 @@ +# Copyright (C) 2005 John Arbash-Meinel <john@arbash-meinel.com> +# Modified up by: Todd Whiteman <ToddW@ActiveState.com> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Functions for communicating with Pageant, the basic windows ssh agent program. +""" + +import os +import struct +import tempfile +import mmap +import array + +# if you're on windows, you should have one of these, i guess? +# ctypes is part of standard library since Python 2.5 +_has_win32all = False +_has_ctypes = False +try: + # win32gui is preferred over win32ui to avoid MFC dependencies + import win32gui + _has_win32all = True +except ImportError: + try: + import ctypes + _has_ctypes = True + except ImportError: + pass + + +_AGENT_COPYDATA_ID = 0x804e50ba +_AGENT_MAX_MSGLEN = 8192 +# Note: The WM_COPYDATA value is pulled from win32con, as a workaround +# so we do not have to import this huge library just for this one variable. +win32con_WM_COPYDATA = 74 + + +def _get_pageant_window_object(): + if _has_win32all: + try: + hwnd = win32gui.FindWindow('Pageant', 'Pageant') + return hwnd + except win32gui.error: + pass + elif _has_ctypes: + # Return 0 if there is no Pageant window. + return ctypes.windll.user32.FindWindowA('Pageant', 'Pageant') + return None + + +def can_talk_to_agent(): + """ + Check to see if there is a "Pageant" agent we can talk to. + + This checks both if we have the required libraries (win32all or ctypes) + and if there is a Pageant currently running. + """ + if (_has_win32all or _has_ctypes) and _get_pageant_window_object(): + return True + return False + + +def _query_pageant(msg): + hwnd = _get_pageant_window_object() + if not hwnd: + # Raise a failure to connect exception, pageant isn't running anymore! + return None + + # Write our pageant request string into the file (pageant will read this to determine what to do) + filename = tempfile.mktemp('.pag') + map_filename = os.path.basename(filename) + + f = open(filename, 'w+b') + f.write(msg ) + # Ensure the rest of the file is empty, otherwise pageant will read this + f.write('\0' * (_AGENT_MAX_MSGLEN - len(msg))) + # Create the shared file map that pageant will use to read from + pymap = mmap.mmap(f.fileno(), _AGENT_MAX_MSGLEN, tagname=map_filename, access=mmap.ACCESS_WRITE) + try: + # Create an array buffer containing the mapped filename + char_buffer = array.array("c", map_filename + '\0') + char_buffer_address, char_buffer_size = char_buffer.buffer_info() + # Create a string to use for the SendMessage function call + cds = struct.pack("LLP", _AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address) + + if _has_win32all: + # win32gui.SendMessage should also allow the same pattern as + # ctypes, but let's keep it like this for now... + response = win32gui.SendMessage(hwnd, win32con_WM_COPYDATA, len(cds), cds) + elif _has_ctypes: + _buf = array.array('B', cds) + _addr, _size = _buf.buffer_info() + response = ctypes.windll.user32.SendMessageA(hwnd, win32con_WM_COPYDATA, _size, _addr) + else: + response = 0 + + if response > 0: + datalen = pymap.read(4) + retlen = struct.unpack('>I', datalen)[0] + return datalen + pymap.read(retlen) + return None + finally: + pymap.close() + f.close() + # Remove the file, it was temporary only + os.unlink(filename) + + +class PageantConnection (object): + """ + Mock "connection" to an agent which roughly approximates the behavior of + a unix local-domain socket (as used by Agent). Requests are sent to the + pageant daemon via special Windows magick, and responses are buffered back + for subsequent reads. + """ + + def __init__(self): + self._response = None + + def send(self, data): + self._response = _query_pageant(data) + + def recv(self, n): + if self._response is None: + return '' + ret = self._response[:n] + self._response = self._response[n:] + if self._response == '': + self._response = None + return ret + + def close(self): + pass |