author     Jeremy T. Bouse <jbouse@debian.org>  2009-11-27 16:20:12 -0500
committer  Jeremy T. Bouse <jbouse@debian.org>  2009-11-27 16:20:12 -0500
commit     ed280d5ac360e2af796e9bd973d7b4df89f0c449 (patch)
tree       ce892d6ce9dad8c0ecbc9cbe73f8095195bef0b4 /paramiko
parent     176c6caf4ea7918e1698438634b237fab8456471 (diff)
Imported Upstream version 1.7.4 (upstream/1.7.4)
Diffstat (limited to 'paramiko')
-rw-r--r--  paramiko/__init__.py       67
-rw-r--r--  paramiko/agent.py          39
-rw-r--r--  paramiko/auth_handler.py   42
-rw-r--r--  paramiko/ber.py             7
-rw-r--r--  paramiko/buffered_pipe.py 200
-rw-r--r--  paramiko/channel.py       471
-rw-r--r--  paramiko/client.py        474
-rw-r--r--  paramiko/common.py         26
-rw-r--r--  paramiko/compress.py        2
-rw-r--r--  paramiko/config.py        105
-rw-r--r--  paramiko/dsskey.py         37
-rw-r--r--  paramiko/file.py           82
-rw-r--r--  paramiko/hostkeys.py      315
-rw-r--r--  paramiko/kex_gex.py        68
-rw-r--r--  paramiko/kex_group1.py      2
-rw-r--r--  paramiko/logging22.py       2
-rw-r--r--  paramiko/message.py         4
-rw-r--r--  paramiko/packet.py        116
-rw-r--r--  paramiko/pipe.py           52
-rw-r--r--  paramiko/pkey.py           93
-rw-r--r--  paramiko/primes.py         17
-rw-r--r--  paramiko/resource.py       72
-rw-r--r--  paramiko/rng.py           112
-rw-r--r--  paramiko/rng_posix.py      97
-rw-r--r--  paramiko/rng_win32.py     121
-rw-r--r--  paramiko/rsakey.py         41
-rw-r--r--  paramiko/server.py        111
-rw-r--r--  paramiko/sftp.py           48
-rw-r--r--  paramiko/sftp_attr.py      75
-rw-r--r--  paramiko/sftp_client.py   229
-rw-r--r--  paramiko/sftp_file.py     241
-rw-r--r--  paramiko/sftp_handle.py    44
-rw-r--r--  paramiko/sftp_server.py    76
-rw-r--r--  paramiko/sftp_si.py        13
-rw-r--r--  paramiko/ssh_exception.py  58
-rw-r--r--  paramiko/transport.py     574
-rw-r--r--  paramiko/util.py          131
-rw-r--r--  paramiko/win_pageant.py   148
38 files changed, 3501 insertions, 911 deletions
diff --git a/paramiko/__init__.py b/paramiko/__init__.py
index 0a312cb..9a8caec 100644
--- a/paramiko/__init__.py
+++ b/paramiko/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2008 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -26,8 +26,9 @@ replaced C{telnet} and C{rsh} for secure access to remote shells, but the
protocol also includes the ability to open arbitrary channels to remote
services across an encrypted tunnel. (This is how C{sftp} works, for example.)
-To use this package, pass a socket (or socket-like object) to a L{Transport},
-and use L{start_server <Transport.start_server>} or
+The high-level client API starts with creation of an L{SSHClient} object.
+For more direct control, pass a socket (or socket-like object) to a
+L{Transport}, and use L{start_server <Transport.start_server>} or
L{start_client <Transport.start_client>} to negotiate
with the remote host as either a server or client. As a client, you are
responsible for authenticating using a password or private key, and checking
@@ -46,7 +47,7 @@ released under the GNU Lesser General Public License (LGPL).
Website: U{http://www.lag.net/paramiko/}
-@version: 1.5.2 (rhydon)
+@version: 1.7.4 (Desmond)
@author: Robey Pointer
@contact: robey@lag.net
@license: GNU Lesser General Public License (LGPL)
@@ -59,20 +60,19 @@ if sys.version_info < (2, 2):
__author__ = "Robey Pointer <robey@lag.net>"
-__date__ = "04 Dec 2005"
-__version__ = "1.5.2 (rhydon)"
-__version_info__ = (1, 5, 2)
+__date__ = "06 Jul 2008"
+__version__ = "1.7.4 (Desmond)"
+__version_info__ = (1, 7, 4)
__license__ = "GNU Lesser General Public License (LGPL)"
-import transport, auth_handler, channel, rsakey, dsskey, message
-import ssh_exception, file, packet, agent, server, util
-import sftp_client, sftp_attr, sftp_handle, sftp_server, sftp_si
-
from transport import randpool, SecurityOptions, Transport
+from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, WarningPolicy
from auth_handler import AuthHandler
from channel import Channel, ChannelFile
-from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType
+from ssh_exception import SSHException, PasswordRequiredException, \
+ BadAuthenticationType, ChannelException, BadHostKeyException, \
+ AuthenticationException
from server import ServerInterface, SubsystemHandler, InteractiveQuery
from rsakey import RSAKey
from dsskey import DSSKey
@@ -88,15 +88,15 @@ from packet import Packetizer
from file import BufferedFile
from agent import Agent, AgentKey
from pkey import PKey
+from hostkeys import HostKeys
+from config import SSHConfig
# fix module names for epydoc
-for x in [Transport, SecurityOptions, Channel, SFTPServer, SSHException, \
- PasswordRequiredException, BadAuthenticationType, ChannelFile, \
- SubsystemHandler, AuthHandler, RSAKey, DSSKey, SFTPError, \
- SFTP, SFTPClient, SFTPServer, Message, Packetizer, SFTPAttributes, \
- SFTPHandle, SFTPServerInterface, BufferedFile, Agent, AgentKey, \
- PKey, BaseSFTP, SFTPFile, ServerInterface]:
- x.__module__ = 'paramiko'
+for c in locals().values():
+ if issubclass(type(c), type) or type(c).__name__ == 'classobj':
+ # classobj for exceptions :/
+ c.__module__ = __name__
+del c
from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \
OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, \
@@ -106,16 +106,24 @@ from sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, S
SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED
__all__ = [ 'Transport',
+ 'SSHClient',
+ 'MissingHostKeyPolicy',
+ 'AutoAddPolicy',
+ 'RejectPolicy',
+ 'WarningPolicy',
'SecurityOptions',
'SubsystemHandler',
'Channel',
+ 'PKey',
'RSAKey',
'DSSKey',
- 'Agent',
'Message',
'SSHException',
+ 'AuthenticationException',
'PasswordRequiredException',
'BadAuthenticationType',
+ 'ChannelException',
+ 'BadHostKeyException',
'SFTP',
'SFTPFile',
'SFTPHandle',
@@ -123,24 +131,11 @@ __all__ = [ 'Transport',
'SFTPServer',
'SFTPError',
'SFTPAttributes',
- 'SFTPServerInterface'
+ 'SFTPServerInterface',
'ServerInterface',
'BufferedFile',
'Agent',
'AgentKey',
- 'rsakey',
- 'dsskey',
- 'pkey',
- 'message',
- 'transport',
- 'sftp',
- 'sftp_client',
- 'sftp_server',
- 'sftp_attr',
- 'sftp_file',
- 'sftp_si',
- 'sftp_handle',
- 'server',
- 'file',
- 'agent',
+ 'HostKeys',
+ 'SSHConfig',
'util' ]
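
The __init__.py hunk above re-exports the new high-level API (SSHClient, the host-key policy classes, HostKeys, SSHConfig) at the package top level. A minimal sketch of the resulting calling pattern, assuming a reachable host and credentials (both hypothetical):

    import paramiko

    client = paramiko.SSHClient()                                  # new high-level entry point
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())   # policy classes exported above
    client.connect('ssh.example.com', username='user', password='secret')
    stdin, stdout, stderr = client.exec_command('uptime')
    print(stdout.read())
    client.close()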
diff --git a/paramiko/agent.py b/paramiko/agent.py
index 3555512..71de8b8 100644
--- a/paramiko/agent.py
+++ b/paramiko/agent.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 John Rochester <john@jrochester.org>
+# Copyright (C) 2003-2007 John Rochester <john@jrochester.org>
#
# This file is part of paramiko.
#
@@ -55,20 +55,33 @@ class Agent:
@raise SSHException: if an SSH agent is found, but speaks an
incompatible protocol
"""
+ self.keys = ()
if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'):
conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- conn.connect(os.environ['SSH_AUTH_SOCK'])
+ try:
+ conn.connect(os.environ['SSH_AUTH_SOCK'])
+ except:
+ # probably a dangling env var: the ssh agent is gone
+ return
self.conn = conn
- type, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES))
- if type != SSH2_AGENT_IDENTITIES_ANSWER:
- raise SSHException('could not get keys from ssh-agent')
- keys = []
- for i in range(result.get_int()):
- keys.append(AgentKey(self, result.get_string()))
- result.get_string()
- self.keys = tuple(keys)
+ elif sys.platform == 'win32':
+ import win_pageant
+ if win_pageant.can_talk_to_agent():
+ self.conn = win_pageant.PageantConnection()
+ else:
+ return
else:
- self.keys = ()
+ # no agent support
+ return
+
+ ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES))
+ if ptype != SSH2_AGENT_IDENTITIES_ANSWER:
+ raise SSHException('could not get keys from ssh-agent')
+ keys = []
+ for i in range(result.get_int()):
+ keys.append(AgentKey(self, result.get_string()))
+ result.get_string()
+ self.keys = tuple(keys)
def close(self):
"""
@@ -132,7 +145,7 @@ class AgentKey(PKey):
msg.add_string(self.blob)
msg.add_string(data)
msg.add_int(0)
- type, result = self.agent._send_message(msg)
- if type != SSH2_AGENT_SIGN_RESPONSE:
+ ptype, result = self.agent._send_message(msg)
+ if ptype != SSH2_AGENT_SIGN_RESPONSE:
raise SSHException('key cannot be used for signing')
return result.get_string()
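
With the agent.py change above, constructing Agent no longer fails when no agent socket is reachable: it falls back to an empty key tuple, and on win32 it tries Pageant instead. A short sketch of listing whatever keys an agent (if any) offers:

    from binascii import hexlify
    import paramiko

    agent = paramiko.Agent()      # returns quietly even if no agent/Pageant is available
    for key in agent.get_keys():  # empty tuple when nothing was found
        print('%s %s' % (key.get_name(), hexlify(key.get_fingerprint())))
    agent.close()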
diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py
index 59aa376..39a0194 100644
--- a/paramiko/auth_handler.py
+++ b/paramiko/auth_handler.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -21,6 +21,7 @@ L{AuthHandler}
"""
import threading
+import weakref
# this helps freezing utils
import encodings.utf_8
@@ -28,7 +29,8 @@ import encodings.utf_8
from paramiko.common import *
from paramiko import util
from paramiko.message import Message
-from paramiko.ssh_exception import SSHException, BadAuthenticationType, PartialAuthentication
+from paramiko.ssh_exception import SSHException, AuthenticationException, \
+ BadAuthenticationType, PartialAuthentication
from paramiko.server import InteractiveQuery
@@ -38,13 +40,15 @@ class AuthHandler (object):
"""
def __init__(self, transport):
- self.transport = transport
+ self.transport = weakref.proxy(transport)
self.username = None
self.authenticated = False
self.auth_event = None
self.auth_method = ''
self.password = None
self.private_key = None
+ self.interactive_handler = None
+ self.submethods = None
# for server mode:
self.auth_username = None
self.auth_fail_count = 0
@@ -154,15 +158,15 @@ class AuthHandler (object):
event.wait(0.1)
if not self.transport.is_active():
e = self.transport.get_exception()
- if e is None:
- e = SSHException('Authentication failed.')
+ if (e is None) or issubclass(e.__class__, EOFError):
+ e = AuthenticationException('Authentication failed.')
raise e
if event.isSet():
break
if not self.is_authenticated():
e = self.transport.get_exception()
if e is None:
- e = SSHException('Authentication failed.')
+ e = AuthenticationException('Authentication failed.')
# this is horrible. python Exception isn't yet descended from
# object, so type(e) won't work. :(
if issubclass(e.__class__, PartialAuthentication):
@@ -193,7 +197,10 @@ class AuthHandler (object):
m.add_string(self.auth_method)
if self.auth_method == 'password':
m.add_boolean(False)
- m.add_string(self.password.encode('UTF-8'))
+ password = self.password
+ if isinstance(password, unicode):
+ password = password.encode('UTF-8')
+ m.add_string(password)
elif self.auth_method == 'publickey':
m.add_boolean(True)
m.add_string(self.private_key.get_name())
@@ -276,12 +283,22 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_none(username)
elif method == 'password':
changereq = m.get_boolean()
- password = m.get_string().decode('UTF-8', 'replace')
+ password = m.get_string()
+ try:
+ password = password.decode('UTF-8')
+ except UnicodeError:
+ # some clients/servers expect non-utf-8 passwords!
+ # in this case, just return the raw byte string.
+ pass
if changereq:
# always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway
self.transport._log(DEBUG, 'Auth request to change passwords (rejected)')
- newpassword = m.get_string().decode('UTF-8', 'replace')
+ newpassword = m.get_string()
+ try:
+ newpassword = newpassword.decode('UTF-8', 'replace')
+ except UnicodeError:
+ pass
result = AUTH_FAILED
else:
result = self.transport.server_object.check_auth_password(username, password)
@@ -332,7 +349,7 @@ class AuthHandler (object):
self._send_auth_result(username, method, result)
def _parse_userauth_success(self, m):
- self.transport._log(INFO, 'Authentication successful!')
+ self.transport._log(INFO, 'Authentication (%s) successful!' % self.auth_method)
self.authenticated = True
self.transport._auth_trigger()
if self.auth_event != None:
@@ -346,11 +363,11 @@ class AuthHandler (object):
self.transport._log(DEBUG, 'Methods: ' + str(authlist))
self.transport.saved_exception = PartialAuthentication(authlist)
elif self.auth_method not in authlist:
- self.transport._log(INFO, 'Authentication type not permitted.')
+ self.transport._log(INFO, 'Authentication type (%s) not permitted.' % self.auth_method)
self.transport._log(DEBUG, 'Allowed methods: ' + str(authlist))
self.transport.saved_exception = BadAuthenticationType('Bad authentication type', authlist)
else:
- self.transport._log(INFO, 'Authentication failed.')
+ self.transport._log(INFO, 'Authentication (%s) failed.' % self.auth_method)
self.authenticated = False
self.username = None
if self.auth_event != None:
@@ -407,4 +424,3 @@ class AuthHandler (object):
MSG_USERAUTH_INFO_RESPONSE: _parse_userauth_info_response,
}
-
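
The handler now raises AuthenticationException for plain authentication failures, while BadAuthenticationType still signals that the server disallows the attempted method. A sketch of telling the two apart from calling code, assuming an already-connected Transport named t; the allowed_types attribute comes from the upstream exception class and is not shown in this hunk:

    import paramiko

    try:
        t.auth_password('user', 'wrong-password')
    except paramiko.BadAuthenticationType, e:
        # the server does not permit password auth at all
        print('allowed methods: %r' % (e.allowed_types,))
    except paramiko.AuthenticationException:
        # the credentials themselves were rejected
        print('authentication failed')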
diff --git a/paramiko/ber.py b/paramiko/ber.py
index 6a7823d..9d8ddfa 100644
--- a/paramiko/ber.py
+++ b/paramiko/ber.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -16,7 +16,7 @@
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
-import struct
+
import util
@@ -91,8 +91,9 @@ class BER(object):
while True:
x = b.decode_next()
if x is None:
- return out
+ break
out.append(x)
+ return out
decode_sequence = staticmethod(decode_sequence)
def encode_tlv(self, ident, val):
diff --git a/paramiko/buffered_pipe.py b/paramiko/buffered_pipe.py
new file mode 100644
index 0000000..ae3d9d6
--- /dev/null
+++ b/paramiko/buffered_pipe.py
@@ -0,0 +1,200 @@
+# Copyright (C) 2006-2007 Robey Pointer <robey@lag.net>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+
+"""
+Attempt to generalize the "feeder" part of a Channel: an object which can be
+read from and closed, but is reading from a buffer fed by another thread. The
+read operations are blocking and can have a timeout set.
+"""
+
+import array
+import threading
+import time
+
+
+class PipeTimeout (IOError):
+ """
+ Indicates that a timeout was reached on a read from a L{BufferedPipe}.
+ """
+ pass
+
+
+class BufferedPipe (object):
+ """
+ A buffer that obeys normal read (with timeout) & close semantics for a
+ file or socket, but is fed data from another thread. This is used by
+ L{Channel}.
+ """
+
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._cv = threading.Condition(self._lock)
+ self._event = None
+ self._buffer = array.array('B')
+ self._closed = False
+
+ def set_event(self, event):
+ """
+ Set an event on this buffer. When data is ready to be read (or the
+ buffer has been closed), the event will be set. When no data is
+ ready, the event will be cleared.
+
+ @param event: the event to set/clear
+ @type event: Event
+ """
+ self._event = event
+ if len(self._buffer) > 0:
+ event.set()
+ else:
+ event.clear()
+
+ def feed(self, data):
+ """
+ Feed new data into this pipe. This method is assumed to be called
+ from a separate thread, so synchronization is done.
+
+ @param data: the data to add
+ @type data: str
+ """
+ self._lock.acquire()
+ try:
+ if self._event is not None:
+ self._event.set()
+ self._buffer.fromstring(data)
+ self._cv.notifyAll()
+ finally:
+ self._lock.release()
+
+ def read_ready(self):
+ """
+ Returns true if data is buffered and ready to be read from this
+ feeder. A C{False} result does not mean that the feeder has closed;
+ it means you may need to wait before more data arrives.
+
+ @return: C{True} if a L{read} call would immediately return at least
+ one byte; C{False} otherwise.
+ @rtype: bool
+ """
+ self._lock.acquire()
+ try:
+ if len(self._buffer) == 0:
+ return False
+ return True
+ finally:
+ self._lock.release()
+
+ def read(self, nbytes, timeout=None):
+ """
+ Read data from the pipe. The return value is a string representing
+ the data received. The maximum amount of data to be received at once
+ is specified by C{nbytes}. If a string of length zero is returned,
+ the pipe has been closed.
+
+ The optional C{timeout} argument can be a nonnegative float expressing
+ seconds, or C{None} for no timeout. If a float is given, a
+ C{PipeTimeout} will be raised if the timeout period value has
+ elapsed before any data arrives.
+
+ @param nbytes: maximum number of bytes to read
+ @type nbytes: int
+ @param timeout: maximum seconds to wait (or C{None}, the default, to
+ wait forever)
+ @type timeout: float
+ @return: data
+ @rtype: str
+
+ @raise PipeTimeout: if a timeout was specified and no data was ready
+ before that timeout
+ """
+ out = ''
+ self._lock.acquire()
+ try:
+ if len(self._buffer) == 0:
+ if self._closed:
+ return out
+ # should we block?
+ if timeout == 0.0:
+ raise PipeTimeout()
+ # loop here in case we get woken up but a different thread has
+ # grabbed everything in the buffer.
+ while (len(self._buffer) == 0) and not self._closed:
+ then = time.time()
+ self._cv.wait(timeout)
+ if timeout is not None:
+ timeout -= time.time() - then
+ if timeout <= 0.0:
+ raise PipeTimeout()
+
+ # something's in the buffer and we have the lock!
+ if len(self._buffer) <= nbytes:
+ out = self._buffer.tostring()
+ del self._buffer[:]
+ if (self._event is not None) and not self._closed:
+ self._event.clear()
+ else:
+ out = self._buffer[:nbytes].tostring()
+ del self._buffer[:nbytes]
+ finally:
+ self._lock.release()
+
+ return out
+
+ def empty(self):
+ """
+ Clear out the buffer and return all data that was in it.
+
+ @return: any data that was in the buffer prior to clearing it out
+ @rtype: str
+ """
+ self._lock.acquire()
+ try:
+ out = self._buffer.tostring()
+ del self._buffer[:]
+ if (self._event is not None) and not self._closed:
+ self._event.clear()
+ return out
+ finally:
+ self._lock.release()
+
+ def close(self):
+ """
+ Close this pipe object. Future calls to L{read} after the buffer
+ has been emptied will return immediately with an empty string.
+ """
+ self._lock.acquire()
+ try:
+ self._closed = True
+ self._cv.notifyAll()
+ if self._event is not None:
+ self._event.set()
+ finally:
+ self._lock.release()
+
+ def __len__(self):
+ """
+ Return the number of bytes buffered.
+
+ @return: number of bytes buffered
+ @rtype: int
+ """
+ self._lock.acquire()
+ try:
+ return len(self._buffer)
+ finally:
+ self._lock.release()
+
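
BufferedPipe replaces the ad-hoc string buffers that Channel previously managed by hand (see the channel.py hunks below). A small stand-alone sketch of its feed/read/close semantics:

    import threading
    from paramiko.buffered_pipe import BufferedPipe, PipeTimeout

    p = BufferedPipe()

    def producer():
        p.feed('hello')   # normally called from the transport's reader thread
        p.close()         # later reads drain the buffer, then return ''

    threading.Thread(target=producer).start()
    try:
        data = p.read(5, timeout=1.0)    # blocks up to one second for data
    except PipeTimeout:
        data = ''
    print(repr(data))                    # 'hello' once the producer has run
    print(repr(p.read(5, timeout=0.1)))  # '' because the pipe is closed and drained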
diff --git a/paramiko/channel.py b/paramiko/channel.py
index 8a00233..910a03c 100644
--- a/paramiko/channel.py
+++ b/paramiko/channel.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -20,6 +20,7 @@
Abstraction for an SSH2 channel.
"""
+import binascii
import sys
import time
import threading
@@ -31,9 +32,14 @@ from paramiko import util
from paramiko.message import Message
from paramiko.ssh_exception import SSHException
from paramiko.file import BufferedFile
+from paramiko.buffered_pipe import BufferedPipe, PipeTimeout
from paramiko import pipe
+# lower bound on the max packet size we'll accept from the remote host
+MIN_PACKET_SIZE = 1024
+
+
class Channel (object):
"""
A secure tunnel across an SSH L{Transport}. A Channel is meant to behave
@@ -49,9 +55,6 @@ class Channel (object):
is exactly like a normal network socket, so it shouldn't be too surprising.
"""
- # lower bound on the max packet size we'll accept from the remote host
- MIN_PACKET_SIZE = 1024
-
def __init__(self, chanid):
"""
Create a new channel. The channel is not associated with any
@@ -69,14 +72,12 @@ class Channel (object):
self.active = False
self.eof_received = 0
self.eof_sent = 0
- self.in_buffer = ''
- self.in_stderr_buffer = ''
+ self.in_buffer = BufferedPipe()
+ self.in_stderr_buffer = BufferedPipe()
self.timeout = None
self.closed = False
self.ultra_debug = False
self.lock = threading.Lock()
- self.in_buffer_cv = threading.Condition(self.lock)
- self.in_stderr_buffer_cv = threading.Condition(self.lock)
self.out_buffer_cv = threading.Condition(self.lock)
self.in_window_size = 0
self.out_window_size = 0
@@ -85,15 +86,19 @@ class Channel (object):
self.in_window_threshold = 0
self.in_window_sofar = 0
self.status_event = threading.Event()
- self.name = str(chanid)
- self.logger = util.get_logger('paramiko.chan.' + str(chanid))
- self.pipe = None
+ self._name = str(chanid)
+ self.logger = util.get_logger('paramiko.transport')
+ self._pipe = None
self.event = threading.Event()
self.combine_stderr = False
self.exit_status = -1
+ self.origin_addr = None
def __del__(self):
- self.close()
+ try:
+ self.close()
+ except:
+ pass
def __repr__(self):
"""
@@ -124,14 +129,15 @@ class Channel (object):
It isn't necessary (or desirable) to call this method if you're going
to execute a single command with L{exec_command}.
- @param term: the terminal type to emulate (for example, C{'vt100'}).
+ @param term: the terminal type to emulate (for example, C{'vt100'})
@type term: str
@param width: width (in characters) of the terminal screen
@type width: int
@param height: height (in characters) of the terminal screen
@type height: int
- @return: C{True} if the operation succeeded; C{False} if not.
- @rtype: bool
+
+ @raise SSHException: if the request was rejected or the channel was
+ closed
"""
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
@@ -148,12 +154,7 @@ class Channel (object):
m.add_string('')
self.event.clear()
self.transport._send_user_message(m)
- while True:
- self.event.wait(0.1)
- if self.closed:
- return False
- if self.event.isSet():
- return True
+ self._wait_for_event()
def invoke_shell(self):
"""
@@ -168,8 +169,8 @@ class Channel (object):
When the shell exits, the channel will be closed and can't be reused.
You must open a new channel if you wish to open another shell.
- @return: C{True} if the operation succeeded; C{False} if not.
- @rtype: bool
+ @raise SSHException: if the request was rejected or the channel was
+ closed
"""
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
@@ -180,12 +181,7 @@ class Channel (object):
m.add_boolean(1)
self.event.clear()
self.transport._send_user_message(m)
- while True:
- self.event.wait(0.1)
- if self.closed:
- return False
- if self.event.isSet():
- return True
+ self._wait_for_event()
def exec_command(self, command):
"""
@@ -199,8 +195,9 @@ class Channel (object):
@param command: a shell command to execute.
@type command: str
- @return: C{True} if the operation succeeded; C{False} if not.
- @rtype: bool
+
+ @raise SSHException: if the request was rejected or the channel was
+ closed
"""
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
@@ -208,16 +205,11 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.remote_chanid)
m.add_string('exec')
- m.add_boolean(1)
+ m.add_boolean(True)
m.add_string(command)
self.event.clear()
self.transport._send_user_message(m)
- while True:
- self.event.wait(0.1)
- if self.closed:
- return False
- if self.event.isSet():
- return True
+ self._wait_for_event()
def invoke_subsystem(self, subsystem):
"""
@@ -230,8 +222,9 @@ class Channel (object):
@param subsystem: name of the subsystem being requested.
@type subsystem: str
- @return: C{True} if the operation succeeded; C{False} if not.
- @rtype: bool
+
+ @raise SSHException: if the request was rejected or the channel was
+ closed
"""
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
@@ -239,16 +232,11 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.remote_chanid)
m.add_string('subsystem')
- m.add_boolean(1)
+ m.add_boolean(True)
m.add_string(subsystem)
self.event.clear()
self.transport._send_user_message(m)
- while True:
- self.event.wait(0.1)
- if self.closed:
- return False
- if self.event.isSet():
- return True
+ self._wait_for_event()
def resize_pty(self, width=80, height=24):
"""
@@ -259,8 +247,9 @@ class Channel (object):
@type width: int
@param height: new height (in characters) of the terminal screen
@type height: int
- @return: C{True} if the operation succeeded; C{False} if not.
- @rtype: bool
+
+ @raise SSHException: if the request was rejected or the channel was
+ closed
"""
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
@@ -268,19 +257,27 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.remote_chanid)
m.add_string('window-change')
- m.add_boolean(1)
+ m.add_boolean(True)
m.add_int(width)
m.add_int(height)
m.add_int(0).add_int(0)
self.event.clear()
self.transport._send_user_message(m)
- while True:
- self.event.wait(0.1)
- if self.closed:
- return False
- if self.event.isSet():
- return True
+ self._wait_for_event()
+ def exit_status_ready(self):
+ """
+ Return true if the remote process has exited and returned an exit
+ status. You may use this to poll the process status if you don't
+ want to block in L{recv_exit_status}. Note that the server may not
+ return an exit status in some cases (like bad servers).
+
+ @return: True if L{recv_exit_status} will return immediately
+ @rtype: bool
+ @since: 1.7.3
+ """
+ return self.closed or self.status_event.isSet()
+
def recv_exit_status(self):
"""
Return the exit status from the process on the server. This is
@@ -296,8 +293,9 @@ class Channel (object):
"""
while True:
if self.closed or self.status_event.isSet():
- return self.exit_status
+ break
self.status_event.wait(0.1)
+ return self.exit_status
def send_exit_status(self, status):
"""
@@ -317,10 +315,73 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.remote_chanid)
m.add_string('exit-status')
- m.add_boolean(0)
+ m.add_boolean(False)
m.add_int(status)
self.transport._send_user_message(m)
+
+ def request_x11(self, screen_number=0, auth_protocol=None, auth_cookie=None,
+ single_connection=False, handler=None):
+ """
+ Request an x11 session on this channel. If the server allows it,
+ further x11 requests can be made from the server to the client,
+ when an x11 application is run in a shell session.
+
+ From RFC4254::
+
+ It is RECOMMENDED that the 'x11 authentication cookie' that is
+ sent be a fake, random cookie, and that the cookie be checked and
+ replaced by the real cookie when a connection request is received.
+ If you omit the auth_cookie, a new secure random 128-bit value will be
+ generated, used, and returned. You will need to use this value to
+ verify incoming x11 requests and replace them with the actual local
+ x11 cookie (which requires some knowledge of the x11 protocol).
+
+ If a handler is passed in, the handler is called from another thread
+ whenever a new x11 connection arrives. The default handler queues up
+ incoming x11 connections, which may be retrieved using
+ L{Transport.accept}. The handler's calling signature is::
+
+ handler(channel: Channel, (address: str, port: int))
+
+ @param screen_number: the x11 screen number (0, 10, etc)
+ @type screen_number: int
+ @param auth_protocol: the name of the X11 authentication method used;
+ if none is given, C{"MIT-MAGIC-COOKIE-1"} is used
+ @type auth_protocol: str
+ @param auth_cookie: hexadecimal string containing the x11 auth cookie;
+ if none is given, a secure random 128-bit value is generated
+ @type auth_cookie: str
+ @param single_connection: if True, only a single x11 connection will be
+ forwarded (by default, any number of x11 connections can arrive
+ over this session)
+ @type single_connection: bool
+ @param handler: an optional handler to use for incoming X11 connections
+ @type handler: function
+ @return: the auth_cookie used
+ """
+ if self.closed or self.eof_received or self.eof_sent or not self.active:
+ raise SSHException('Channel is not open')
+ if auth_protocol is None:
+ auth_protocol = 'MIT-MAGIC-COOKIE-1'
+ if auth_cookie is None:
+ auth_cookie = binascii.hexlify(self.transport.randpool.get_bytes(16))
+
+ m = Message()
+ m.add_byte(chr(MSG_CHANNEL_REQUEST))
+ m.add_int(self.remote_chanid)
+ m.add_string('x11-req')
+ m.add_boolean(True)
+ m.add_boolean(single_connection)
+ m.add_string(auth_protocol)
+ m.add_string(auth_cookie)
+ m.add_int(screen_number)
+ self.event.clear()
+ self.transport._send_user_message(m)
+ self._wait_for_event()
+ self.transport._set_x11_handler(handler)
+ return auth_cookie
+
def get_transport(self):
"""
Return the L{Transport} associated with this channel.
@@ -333,14 +394,13 @@ class Channel (object):
def set_name(self, name):
"""
Set a name for this channel. Currently it's only used to set the name
- of the log level used for debugging. The name can be fetched with the
+ of the channel in logfile entries. The name can be fetched with the
L{get_name} method.
- @param name: new channel name.
+ @param name: new channel name
@type name: str
"""
- self.name = name
- self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self.name)
+ self._name = name
def get_name(self):
"""
@@ -349,7 +409,7 @@ class Channel (object):
@return: the name of this channel.
@rtype: str
"""
- return self.name
+ return self._name
def get_id(self):
"""
@@ -360,8 +420,6 @@ class Channel (object):
@return: the ID of this channel.
@rtype: int
-
- @since: ivysaur
"""
return self.chanid
@@ -394,8 +452,7 @@ class Channel (object):
self.combine_stderr = combine
if combine and not old:
# copy old stderr buffer into primary buffer
- data = self.in_stderr_buffer
- self.in_stderr_buffer = ''
+ data = self.in_stderr_buffer.empty()
finally:
self.lock.release()
if len(data) > 0:
@@ -419,7 +476,7 @@ class Channel (object):
C{chan.settimeout(None)} is equivalent to C{chan.setblocking(1)}.
@param timeout: seconds to wait for a pending read/write operation
- before raising C{socket.timeout}, or C{None} for no timeout.
+ before raising C{socket.timeout}, or C{None} for no timeout.
@type timeout: float
"""
self.timeout = timeout
@@ -439,17 +496,19 @@ class Channel (object):
"""
Set blocking or non-blocking mode of the channel: if C{blocking} is 0,
the channel is set to non-blocking mode; otherwise it's set to blocking
- mode. Initially all channels are in blocking mode.
+ mode. Initially all channels are in blocking mode.
In non-blocking mode, if a L{recv} call doesn't find any data, or if a
L{send} call can't immediately dispose of the data, an error exception
- is raised. In blocking mode, the calls block until they can proceed.
+ is raised. In blocking mode, the calls block until they can proceed. An
+ EOF condition is considered "immediate data" for L{recv}, so if the
+ channel is closed in the read direction, it will never block.
C{chan.setblocking(0)} is equivalent to C{chan.settimeout(0)};
C{chan.setblocking(1)} is equivalent to C{chan.settimeout(None)}.
@param blocking: 0 to set non-blocking mode; non-0 to set blocking
- mode.
+ mode.
@type blocking: int
"""
if blocking:
@@ -457,6 +516,18 @@ class Channel (object):
else:
self.settimeout(0.0)
+ def getpeername(self):
+ """
+ Return the address of the remote side of this Channel, if possible.
+ This is just a wrapper around C{'getpeername'} on the Transport, used
+ to provide enough of a socket-like interface to allow asyncore to work.
+ (asyncore likes to call C{'getpeername'}.)
+
+ @return: the address of the remote host, if known
+ @rtype: tuple(str, int)
+ """
+ return self.transport.getpeername()
+
def close(self):
"""
Close the channel. All future read/write operations on the channel
@@ -466,15 +537,17 @@ class Channel (object):
"""
self.lock.acquire()
try:
+ # only close the pipe when the user explicitly closes the channel.
+ # otherwise they will get unpleasant surprises. (and do it before
+ # checking self.closed, since the remote host may have already
+ # closed the connection.)
+ if self._pipe is not None:
+ self._pipe.close()
+ self._pipe = None
+
if not self.active or self.closed:
return
msgs = self._close_internal()
-
- # only close the pipe when the user explicitly closes the channel.
- # otherwise they will get unpleasant surprises.
- if self.pipe is not None:
- self.pipe.close()
- self.pipe = None
finally:
self.lock.release()
for m in msgs:
@@ -491,13 +564,7 @@ class Channel (object):
return at least one byte; C{False} otherwise.
@rtype: boolean
"""
- self.lock.acquire()
- try:
- if len(self.in_buffer) == 0:
- return False
- return True
- finally:
- self.lock.release()
+ return self.in_buffer.read_ready()
def recv(self, nbytes):
"""
@@ -514,38 +581,12 @@ class Channel (object):
@raise socket.timeout: if no data is ready before the timeout set by
L{settimeout}.
"""
- out = ''
- self.lock.acquire()
try:
- if len(self.in_buffer) == 0:
- if self.closed or self.eof_received:
- return out
- # should we block?
- if self.timeout == 0.0:
- raise socket.timeout()
- # loop here in case we get woken up but a different thread has grabbed everything in the buffer
- timeout = self.timeout
- while (len(self.in_buffer) == 0) and not self.closed and not self.eof_received:
- then = time.time()
- self.in_buffer_cv.wait(timeout)
- if timeout != None:
- timeout -= time.time() - then
- if timeout <= 0.0:
- raise socket.timeout()
- # something in the buffer and we have the lock
- if len(self.in_buffer) <= nbytes:
- out = self.in_buffer
- self.in_buffer = ''
- if self.pipe is not None:
- # clear the pipe, since no more data is buffered
- self.pipe.clear()
- else:
- out = self.in_buffer[:nbytes]
- self.in_buffer = self.in_buffer[nbytes:]
- ack = self._check_add_window(len(out))
- finally:
- self.lock.release()
+ out = self.in_buffer.read(nbytes, self.timeout)
+ except PipeTimeout, e:
+ raise socket.timeout()
+ ack = self._check_add_window(len(out))
# no need to hold the channel lock when sending this
if ack > 0:
m = Message()
@@ -569,13 +610,7 @@ class Channel (object):
@since: 1.1
"""
- self.lock.acquire()
- try:
- if len(self.in_stderr_buffer) == 0:
- return False
- return True
- finally:
- self.lock.release()
+ return self.in_stderr_buffer.read_ready()
def recv_stderr(self, nbytes):
"""
@@ -596,36 +631,43 @@ class Channel (object):
@since: 1.1
"""
- out = ''
+ try:
+ out = self.in_stderr_buffer.read(nbytes, self.timeout)
+ except PipeTimeout, e:
+ raise socket.timeout()
+
+ ack = self._check_add_window(len(out))
+ # no need to hold the channel lock when sending this
+ if ack > 0:
+ m = Message()
+ m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
+ m.add_int(self.remote_chanid)
+ m.add_int(ack)
+ self.transport._send_user_message(m)
+
+ return out
+
+ def send_ready(self):
+ """
+ Returns true if data can be written to this channel without blocking.
+ This means the channel is either closed (so any write attempt would
+ return immediately) or there is at least one byte of space in the
+ outbound buffer. If there is at least one byte of space in the
+ outbound buffer, a L{send} call will succeed immediately and return
+ the number of bytes actually written.
+
+ @return: C{True} if a L{send} call on this channel would immediately
+ succeed or fail
+ @rtype: boolean
+ """
self.lock.acquire()
try:
- if len(self.in_stderr_buffer) == 0:
- if self.closed or self.eof_received:
- return out
- # should we block?
- if self.timeout == 0.0:
- raise socket.timeout()
- # loop here in case we get woken up but a different thread has grabbed everything in the buffer
- timeout = self.timeout
- while (len(self.in_stderr_buffer) == 0) and not self.closed and not self.eof_received:
- then = time.time()
- self.in_stderr_buffer_cv.wait(timeout)
- if timeout != None:
- timeout -= time.time() - then
- if timeout <= 0.0:
- raise socket.timeout()
- # something in the buffer and we have the lock
- if len(self.in_stderr_buffer) <= nbytes:
- out = self.in_stderr_buffer
- self.in_stderr_buffer = ''
- else:
- out = self.in_stderr_buffer[:nbytes]
- self.in_stderr_buffer = self.in_stderr_buffer[nbytes:]
- self._check_add_window(len(out))
+ if self.closed or self.eof_sent:
+ return True
+ return self.out_window_size > 0
finally:
self.lock.release()
- return out
-
+
def send(self, s):
"""
Send data to the channel. Returns the number of bytes sent, or 0 if
@@ -634,9 +676,9 @@ class Channel (object):
transmitted, the application needs to attempt delivery of the remaining
data.
- @param s: data to send.
+ @param s: data to send
@type s: str
- @return: number of bytes actually sent.
+ @return: number of bytes actually sent
@rtype: int
@raise socket.timeout: if no data could be sent before the timeout set
@@ -653,9 +695,11 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_DATA))
m.add_int(self.remote_chanid)
m.add_string(s[:size])
- self.transport._send_user_message(m)
finally:
self.lock.release()
+ # Note: We release self.lock before calling _send_user_message.
+ # Otherwise, we can deadlock during re-keying.
+ self.transport._send_user_message(m)
return size
def send_stderr(self, s):
@@ -689,9 +733,11 @@ class Channel (object):
m.add_int(self.remote_chanid)
m.add_int(1)
m.add_string(s[:size])
- self.transport._send_user_message(m)
finally:
self.lock.release()
+ # Note: We release self.lock before calling _send_user_message.
+ # Otherwise, we can deadlock during re-keying.
+ self.transport._send_user_message(m)
return size
def sendall(self, s):
@@ -776,14 +822,14 @@ class Channel (object):
def fileno(self):
"""
Returns an OS-level file descriptor which can be used for polling, but
- but I{not} for reading or writing). This is primaily to allow python's
+ I{not} for reading or writing. This is primarily to allow python's
C{select} module to work.
The first time C{fileno} is called on a channel, a pipe is created to
simulate real OS-level file descriptor (FD) behavior. Because of this,
two OS-level FDs are created, which will use up FDs faster than normal.
- You won't notice this effect unless you open hundreds or thousands of
- channels simultaneously, but it's still notable.
+ (You won't notice this effect unless you have hundreds of channels
+ open at the same time.)
@return: an OS-level file descriptor
@rtype: int
@@ -793,13 +839,14 @@ class Channel (object):
"""
self.lock.acquire()
try:
- if self.pipe is not None:
- return self.pipe.fileno()
+ if self._pipe is not None:
+ return self._pipe.fileno()
# create the pipe and feed in any existing data
- self.pipe = pipe.make_pipe()
- if len(self.in_buffer) > 0:
- self.pipe.set()
- return self.pipe.fileno()
+ self._pipe = pipe.make_pipe()
+ p1, p2 = pipe.make_or_pipe(self._pipe)
+ self.in_buffer.set_event(p1)
+ self.in_stderr_buffer.set_event(p2)
+ return self._pipe.fileno()
finally:
self.lock.release()
@@ -856,7 +903,7 @@ class Channel (object):
def _set_transport(self, transport):
self.transport = transport
- self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self.name)
+ self.logger = util.get_logger(self.transport.get_log_channel())
def _set_window(self, window_size, max_packet_size):
self.in_window_size = window_size
@@ -869,7 +916,7 @@ class Channel (object):
def _set_remote_channel(self, chanid, window_size, max_packet_size):
self.remote_chanid = chanid
self.out_window_size = window_size
- self.out_max_packet_size = max(max_packet_size, self.MIN_PACKET_SIZE)
+ self.out_max_packet_size = max(max_packet_size, MIN_PACKET_SIZE)
self.active = 1
self._log(DEBUG, 'Max packet out: %d bytes' % max_packet_size)
@@ -894,16 +941,7 @@ class Channel (object):
s = m
else:
s = m.get_string()
- self.lock.acquire()
- try:
- if self.ultra_debug:
- self._log(DEBUG, 'fed %d bytes' % len(s))
- if self.pipe is not None:
- self.pipe.set()
- self.in_buffer += s
- self.in_buffer_cv.notifyAll()
- finally:
- self.lock.release()
+ self.in_buffer.feed(s)
def _feed_extended(self, m):
code = m.get_int()
@@ -912,15 +950,9 @@ class Channel (object):
self._log(ERROR, 'unknown extended_data type %d; discarding' % code)
return
if self.combine_stderr:
- return self._feed(s)
- self.lock.acquire()
- try:
- if self.ultra_debug:
- self._log(DEBUG, 'fed %d stderr bytes' % len(s))
- self.in_stderr_buffer += s
- self.in_stderr_buffer_cv.notifyAll()
- finally:
- self.lock.release()
+ self._feed(s)
+ else:
+ self.in_stderr_buffer.feed(s)
def _window_adjust(self, m):
nbytes = m.get_int()
@@ -984,6 +1016,16 @@ class Channel (object):
else:
ok = server.check_channel_window_change_request(self, width, height, pixelwidth,
pixelheight)
+ elif key == 'x11-req':
+ single_connection = m.get_boolean()
+ auth_proto = m.get_string()
+ auth_cookie = m.get_string()
+ screen_number = m.get_int()
+ if server is None:
+ ok = False
+ else:
+ ok = server.check_channel_x11_request(self, single_connection,
+ auth_proto, auth_cookie, screen_number)
else:
self._log(DEBUG, 'Unhandled channel request "%s"' % key)
ok = False
@@ -1001,13 +1043,13 @@ class Channel (object):
try:
if not self.eof_received:
self.eof_received = True
- self.in_buffer_cv.notifyAll()
- self.in_stderr_buffer_cv.notifyAll()
- if self.pipe is not None:
- self.pipe.set_forever()
+ self.in_buffer.close()
+ self.in_stderr_buffer.close()
+ if self._pipe is not None:
+ self._pipe.set_forever()
finally:
self.lock.release()
- self._log(DEBUG, 'EOF received')
+ self._log(DEBUG, 'EOF received (%s)', self._name)
def _handle_close(self, m):
self.lock.acquire()
@@ -1024,17 +1066,29 @@ class Channel (object):
### internals...
- def _log(self, level, msg):
- self.logger.log(level, msg)
+ def _log(self, level, msg, *args):
+ self.logger.log(level, "[chan " + self._name + "] " + msg, *args)
+
+ def _wait_for_event(self):
+ while True:
+ self.event.wait(0.1)
+ if self.event.isSet():
+ return
+ if self.closed:
+ e = self.transport.get_exception()
+ if e is None:
+ e = SSHException('Channel closed.')
+ raise e
+ return
def _set_closed(self):
# you are holding the lock.
self.closed = True
- self.in_buffer_cv.notifyAll()
- self.in_stderr_buffer_cv.notifyAll()
+ self.in_buffer.close()
+ self.in_stderr_buffer.close()
self.out_buffer_cv.notifyAll()
- if self.pipe is not None:
- self.pipe.set_forever()
+ if self._pipe is not None:
+ self._pipe.set_forever()
def _send_eof(self):
# you are holding the lock.
@@ -1044,7 +1098,7 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_EOF))
m.add_int(self.remote_chanid)
self.eof_sent = True
- self._log(DEBUG, 'EOF sent')
+ self._log(DEBUG, 'EOF sent (%s)', self._name)
return m
def _close_internal(self):
@@ -1072,19 +1126,22 @@ class Channel (object):
self.lock.release()
def _check_add_window(self, n):
- # already holding the lock!
- if self.closed or self.eof_received or not self.active:
- return 0
- if self.ultra_debug:
- self._log(DEBUG, 'addwindow %d' % n)
- self.in_window_sofar += n
- if self.in_window_sofar <= self.in_window_threshold:
- return 0
- if self.ultra_debug:
- self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar)
- out = self.in_window_sofar
- self.in_window_sofar = 0
- return out
+ self.lock.acquire()
+ try:
+ if self.closed or self.eof_received or not self.active:
+ return 0
+ if self.ultra_debug:
+ self._log(DEBUG, 'addwindow %d' % n)
+ self.in_window_sofar += n
+ if self.in_window_sofar <= self.in_window_threshold:
+ return 0
+ if self.ultra_debug:
+ self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar)
+ out = self.in_window_sofar
+ self.in_window_sofar = 0
+ return out
+ finally:
+ self.lock.release()
def _wait_for_send_window(self, size):
"""
@@ -1155,8 +1212,6 @@ class ChannelFile (BufferedFile):
def _write(self, data):
self.channel.sendall(data)
return len(data)
-
- seek = BufferedFile.seek
class ChannelStderrFile (ChannelFile):
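
Two behavioural changes in channel.py are worth reading together: the request methods (get_pty, exec_command, invoke_shell, and so on) now raise SSHException on failure instead of returning False, and the new exit_status_ready() allows polling rather than blocking in recv_exit_status(). A sketch against a hypothetical, already-authenticated Transport named t:

    import time

    chan = t.open_session()
    chan.exec_command('ls -l')           # raises SSHException if the request is rejected

    while not chan.exit_status_ready():  # new in 1.7.3: non-blocking status poll
        time.sleep(0.1)                  # real code would also drain chan.recv() here
    print('exit status: %d' % chan.recv_exit_status())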
diff --git a/paramiko/client.py b/paramiko/client.py
new file mode 100644
index 0000000..7870ea9
--- /dev/null
+++ b/paramiko/client.py
@@ -0,0 +1,474 @@
+# Copyright (C) 2006-2007 Robey Pointer <robey@lag.net>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+
+"""
+L{SSHClient}.
+"""
+
+from binascii import hexlify
+import getpass
+import os
+import socket
+import warnings
+
+from paramiko.agent import Agent
+from paramiko.common import *
+from paramiko.dsskey import DSSKey
+from paramiko.hostkeys import HostKeys
+from paramiko.resource import ResourceManager
+from paramiko.rsakey import RSAKey
+from paramiko.ssh_exception import SSHException, BadHostKeyException
+from paramiko.transport import Transport
+
+
+class MissingHostKeyPolicy (object):
+ """
+ Interface for defining the policy that L{SSHClient} should use when the
+ SSH server's hostname is not in either the system host keys or the
+ application's keys. Pre-made classes implement policies for automatically
+ adding the key to the application's L{HostKeys} object (L{AutoAddPolicy}),
+ and for automatically rejecting the key (L{RejectPolicy}).
+
+ This function may be used to ask the user to verify the key, for example.
+ """
+
+ def missing_host_key(self, client, hostname, key):
+ """
+ Called when an L{SSHClient} receives a server key for a server that
+ isn't in either the system or local L{HostKeys} object. To accept
+ the key, simply return. To reject, raise an exception (which will
+ be passed to the calling application).
+ """
+ pass
+
+
+class AutoAddPolicy (MissingHostKeyPolicy):
+ """
+ Policy for automatically adding the hostname and new host key to the
+ local L{HostKeys} object, and saving it. This is used by L{SSHClient}.
+ """
+
+ def missing_host_key(self, client, hostname, key):
+ client._host_keys.add(hostname, key.get_name(), key)
+ if client._host_keys_filename is not None:
+ client.save_host_keys(client._host_keys_filename)
+ client._log(DEBUG, 'Adding %s host key for %s: %s' %
+ (key.get_name(), hostname, hexlify(key.get_fingerprint())))
+
+
+class RejectPolicy (MissingHostKeyPolicy):
+ """
+ Policy for automatically rejecting the unknown hostname & key. This is
+ used by L{SSHClient}.
+ """
+
+ def missing_host_key(self, client, hostname, key):
+ client._log(DEBUG, 'Rejecting %s host key for %s: %s' %
+ (key.get_name(), hostname, hexlify(key.get_fingerprint())))
+ raise SSHException('Unknown server %s' % hostname)
+
+
+class WarningPolicy (MissingHostKeyPolicy):
+ """
+ Policy for logging a python-style warning for an unknown host key, but
+ accepting it. This is used by L{SSHClient}.
+ """
+ def missing_host_key(self, client, hostname, key):
+ warnings.warn('Unknown %s host key for %s: %s' %
+ (key.get_name(), hostname, hexlify(key.get_fingerprint())))
+
+
+class SSHClient (object):
+ """
+ A high-level representation of a session with an SSH server. This class
+ wraps L{Transport}, L{Channel}, and L{SFTPClient} to take care of most
+ aspects of authenticating and opening channels. A typical use case is::
+
+ client = SSHClient()
+ client.load_system_host_keys()
+ client.connect('ssh.example.com')
+ stdin, stdout, stderr = client.exec_command('ls -l')
+
+ You may pass in explicit overrides for authentication and server host key
+ checking. The default mechanism is to try to use local key files or an
+ SSH agent (if one is running).
+
+ @since: 1.6
+ """
+
+ def __init__(self):
+ """
+ Create a new SSHClient.
+ """
+ self._system_host_keys = HostKeys()
+ self._host_keys = HostKeys()
+ self._host_keys_filename = None
+ self._log_channel = None
+ self._policy = RejectPolicy()
+ self._transport = None
+
+ def load_system_host_keys(self, filename=None):
+ """
+ Load host keys from a system (read-only) file. Host keys read with
+ this method will not be saved back by L{save_host_keys}.
+
+ This method can be called multiple times. Each new set of host keys
+ will be merged with the existing set (new replacing old if there are
+ conflicts).
+
+ If C{filename} is left as C{None}, an attempt will be made to read
+ keys from the user's local "known hosts" file, as used by OpenSSH,
+ and no exception will be raised if the file can't be read. This is
+ probably only useful on posix.
+
+ @param filename: the filename to read, or C{None}
+ @type filename: str
+
+ @raise IOError: if a filename was provided and the file could not be
+ read
+ """
+ if filename is None:
+ # try the user's .ssh key file, and mask exceptions
+ filename = os.path.expanduser('~/.ssh/known_hosts')
+ try:
+ self._system_host_keys.load(filename)
+ except IOError:
+ pass
+ return
+ self._system_host_keys.load(filename)
+
+ def load_host_keys(self, filename):
+ """
+ Load host keys from a local host-key file. Host keys read with this
+ method will be checked I{after} keys loaded via L{load_system_host_keys},
+ but will be saved back by L{save_host_keys} (so they can be modified).
+ The missing host key policy L{AutoAddPolicy} adds keys to this set and
+ saves them, when connecting to a previously-unknown server.
+
+ This method can be called multiple times. Each new set of host keys
+ will be merged with the existing set (new replacing old if there are
+ conflicts). When automatically saving, the last hostname is used.
+
+ @param filename: the filename to read
+ @type filename: str
+
+ @raise IOError: if the filename could not be read
+ """
+ self._host_keys_filename = filename
+ self._host_keys.load(filename)
+
+ def save_host_keys(self, filename):
+ """
+ Save the host keys back to a file. Only the host keys loaded with
+ L{load_host_keys} (plus any added directly) will be saved -- not any
+ host keys loaded with L{load_system_host_keys}.
+
+ @param filename: the filename to save to
+ @type filename: str
+
+ @raise IOError: if the file could not be written
+ """
+ f = open(filename, 'w')
+ f.write('# SSH host keys collected by paramiko\n')
+ for hostname, keys in self._host_keys.iteritems():
+ for keytype, key in keys.iteritems():
+ f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
+ f.close()
+
+ def get_host_keys(self):
+ """
+ Get the local L{HostKeys} object. This can be used to examine the
+ local host keys or change them.
+
+ @return: the local host keys
+ @rtype: L{HostKeys}
+ """
+ return self._host_keys
+
+ def set_log_channel(self, name):
+ """
+ Set the channel for logging. The default is C{"paramiko.transport"}
+ but it can be set to anything you want.
+
+ @param name: new channel name for logging
+ @type name: str
+ """
+ self._log_channel = name
+
+ def set_missing_host_key_policy(self, policy):
+ """
+ Set the policy to use when connecting to a server that doesn't have a
+ host key in either the system or local L{HostKeys} objects. The
+ default policy is to reject all unknown servers (using L{RejectPolicy}).
+ You may substitute L{AutoAddPolicy} or write your own policy class.
+
+ @param policy: the policy to use when receiving a host key from a
+ previously-unknown server
+ @type policy: L{MissingHostKeyPolicy}
+ """
+ self._policy = policy
+
+ def connect(self, hostname, port=22, username=None, password=None, pkey=None,
+ key_filename=None, timeout=None, allow_agent=True, look_for_keys=True):
+ """
+ Connect to an SSH server and authenticate to it. The server's host key
+ is checked against the system host keys (see L{load_system_host_keys})
+ and any local host keys (L{load_host_keys}). If the server's hostname
+ is not found in either set of host keys, the missing host key policy
+ is used (see L{set_missing_host_key_policy}). The default policy is
+ to reject the key and raise an L{SSHException}.
+
+ Authentication is attempted in the following order of priority:
+
+ - The C{pkey} or C{key_filename} passed in (if any)
+ - Any key we can find through an SSH agent
+ - Any "id_rsa" or "id_dsa" key discoverable in C{~/.ssh/}
+ - Plain username/password auth, if a password was given
+
+ If a private key requires a password to unlock it, and a password is
+ passed in, that password will be used to attempt to unlock the key.
+
+ @param hostname: the server to connect to
+ @type hostname: str
+ @param port: the server port to connect to
+ @type port: int
+ @param username: the username to authenticate as (defaults to the
+ current local username)
+ @type username: str
+ @param password: a password to use for authentication or for unlocking
+ a private key
+ @type password: str
+ @param pkey: an optional private key to use for authentication
+ @type pkey: L{PKey}
+ @param key_filename: the filename, or list of filenames, of optional
+ private key(s) to try for authentication
+ @type key_filename: str or list(str)
+ @param timeout: an optional timeout (in seconds) for the TCP connect
+ @type timeout: float
+ @param allow_agent: set to False to disable connecting to the SSH agent
+ @type allow_agent: bool
+ @param look_for_keys: set to False to disable searching for discoverable
+ private key files in C{~/.ssh/}
+ @type look_for_keys: bool
+
+ @raise BadHostKeyException: if the server's host key could not be
+ verified
+ @raise AuthenticationException: if authentication failed
+ @raise SSHException: if there was any other error connecting or
+ establishing an SSH session
+ @raise socket.error: if a socket error occurred while connecting
+ """
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if timeout is not None:
+ try:
+ sock.settimeout(timeout)
+ except:
+ pass
+
+ sock.connect((hostname, port))
+ t = self._transport = Transport(sock)
+
+ if self._log_channel is not None:
+ t.set_log_channel(self._log_channel)
+ t.start_client()
+ ResourceManager.register(self, t)
+
+ server_key = t.get_remote_server_key()
+ keytype = server_key.get_name()
+
+ our_server_key = self._system_host_keys.get(hostname, {}).get(keytype, None)
+ if our_server_key is None:
+ our_server_key = self._host_keys.get(hostname, {}).get(keytype, None)
+ if our_server_key is None:
+ # will raise exception if the key is rejected; let that fall out
+ self._policy.missing_host_key(self, hostname, server_key)
+ # if the callback returns, assume the key is ok
+ our_server_key = server_key
+
+ if server_key != our_server_key:
+ raise BadHostKeyException(hostname, server_key, our_server_key)
+
+ if username is None:
+ username = getpass.getuser()
+
+ if key_filename is None:
+ key_filenames = []
+ elif isinstance(key_filename, (str, unicode)):
+ key_filenames = [ key_filename ]
+ else:
+ key_filenames = key_filename
+ self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys)
+
+ def close(self):
+ """
+ Close this SSHClient and its underlying L{Transport}.
+ """
+ if self._transport is None:
+ return
+ self._transport.close()
+ self._transport = None
+
+ def exec_command(self, command, bufsize=-1):
+ """
+ Execute a command on the SSH server. A new L{Channel} is opened and
+ the requested command is executed. The command's input and output
+ streams are returned as python C{file}-like objects representing
+ stdin, stdout, and stderr.
+
+ @param command: the command to execute
+ @type command: str
+ @param bufsize: interpreted the same way as by the built-in C{file()} function in python
+ @type bufsize: int
+ @return: the stdin, stdout, and stderr of the executing command
+ @rtype: tuple(L{ChannelFile}, L{ChannelFile}, L{ChannelFile})
+
+ @raise SSHException: if the server fails to execute the command
+ """
+ chan = self._transport.open_session()
+ chan.exec_command(command)
+ stdin = chan.makefile('wb', bufsize)
+ stdout = chan.makefile('rb', bufsize)
+ stderr = chan.makefile_stderr('rb', bufsize)
+ return stdin, stdout, stderr
+
+ def invoke_shell(self, term='vt100', width=80, height=24):
+ """
+ Start an interactive shell session on the SSH server. A new L{Channel}
+ is opened and connected to a pseudo-terminal using the requested
+ terminal type and size.
+
+ @param term: the terminal type to emulate (for example, C{"vt100"})
+ @type term: str
+ @param width: the width (in characters) of the terminal window
+ @type width: int
+ @param height: the height (in characters) of the terminal window
+ @type height: int
+ @return: a new channel connected to the remote shell
+ @rtype: L{Channel}
+
+ @raise SSHException: if the server fails to invoke a shell
+ """
+ chan = self._transport.open_session()
+ chan.get_pty(term, width, height)
+ chan.invoke_shell()
+ return chan
+
+ def open_sftp(self):
+ """
+ Open an SFTP session on the SSH server.
+
+ @return: a new SFTP session object
+ @rtype: L{SFTPClient}
+ """
+ return self._transport.open_sftp_client()
+
+ def get_transport(self):
+ """
+ Return the underlying L{Transport} object for this SSH connection.
+ This can be used to perform lower-level tasks, like opening specific
+ kinds of channels.
+
+ @return: the Transport for this connection
+ @rtype: L{Transport}
+ """
+ return self._transport
+
+ def _auth(self, username, password, pkey, key_filenames, allow_agent, look_for_keys):
+ """
+ Try, in order:
+
+ - The key passed in, if one was passed in.
+ - Any key we can find through an SSH agent (if allowed).
+ - Any "id_rsa" or "id_dsa" key discoverable in ~/.ssh/ (if allowed).
+ - Plain username/password auth, if a password was given.
+
+ (The password might be needed to unlock a private key.)
+ """
+ saved_exception = None
+
+ if pkey is not None:
+ try:
+ self._log(DEBUG, 'Trying SSH key %s' % hexlify(pkey.get_fingerprint()))
+ self._transport.auth_publickey(username, pkey)
+ return
+ except SSHException, e:
+ saved_exception = e
+
+ for key_filename in key_filenames:
+ for pkey_class in (RSAKey, DSSKey):
+ try:
+ key = pkey_class.from_private_key_file(key_filename, password)
+ self._log(DEBUG, 'Trying key %s from %s' % (hexlify(key.get_fingerprint()), key_filename))
+ self._transport.auth_publickey(username, key)
+ return
+ except SSHException, e:
+ saved_exception = e
+
+ if allow_agent:
+ for key in Agent().get_keys():
+ try:
+ self._log(DEBUG, 'Trying SSH agent key %s' % hexlify(key.get_fingerprint()))
+ self._transport.auth_publickey(username, key)
+ return
+ except SSHException, e:
+ saved_exception = e
+
+ keyfiles = []
+ rsa_key = os.path.expanduser('~/.ssh/id_rsa')
+ dsa_key = os.path.expanduser('~/.ssh/id_dsa')
+ if os.path.isfile(rsa_key):
+ keyfiles.append((RSAKey, rsa_key))
+ if os.path.isfile(dsa_key):
+ keyfiles.append((DSSKey, dsa_key))
+ # look in ~/ssh/ for windows users:
+ rsa_key = os.path.expanduser('~/ssh/id_rsa')
+ dsa_key = os.path.expanduser('~/ssh/id_dsa')
+ if os.path.isfile(rsa_key):
+ keyfiles.append((RSAKey, rsa_key))
+ if os.path.isfile(dsa_key):
+ keyfiles.append((DSSKey, dsa_key))
+
+ if not look_for_keys:
+ keyfiles = []
+
+ for pkey_class, filename in keyfiles:
+ try:
+ key = pkey_class.from_private_key_file(filename, password)
+ self._log(DEBUG, 'Trying discovered key %s in %s' % (hexlify(key.get_fingerprint()), filename))
+ self._transport.auth_publickey(username, key)
+ return
+ except SSHException, e:
+ saved_exception = e
+ except IOError, e:
+ saved_exception = e
+
+ if password is not None:
+ try:
+ self._transport.auth_password(username, password)
+ return
+ except SSHException, e:
+ saved_exception = e
+
+ # if we got an auth-failed exception earlier, re-raise it
+ if saved_exception is not None:
+ raise saved_exception
+ raise SSHException('No authentication methods available')
+
+ def _log(self, level, msg):
+ self._transport._log(level, msg)
+
diff --git a/paramiko/common.py b/paramiko/common.py
index c5999e6..f4a4d81 100644
--- a/paramiko/common.py
+++ b/paramiko/common.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -21,7 +21,7 @@ Common constants and global variables.
"""
MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, MSG_SERVICE_REQUEST, \
- MSG_SERVICE_ACCEPT = range(1, 7)
+ MSG_SERVICE_ACCEPT = range(1, 7)
MSG_KEXINIT, MSG_NEWKEYS = range(20, 22)
MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \
MSG_USERAUTH_BANNER = range(50, 54)
@@ -29,9 +29,9 @@ MSG_USERAUTH_PK_OK = 60
MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62)
MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83)
MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
- MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \
- MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
- MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
+ MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \
+ MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
+ MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
# for debugging:
@@ -95,21 +95,10 @@ CONNECTION_FAILED_CODE = {
DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \
DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14
-
-from Crypto.Util.randpool import PersistentRandomPool, RandomPool
+from rng import StrongLockingRandomPool
# keep a crypto-strong PRNG nearby
-try:
- randpool = PersistentRandomPool(os.path.join(os.path.expanduser('~'), '/.randpool'))
-except:
- # the above will likely fail on Windows - fall back to non-persistent random pool
- randpool = RandomPool()
-
-try:
- randpool.randomize()
-except:
- # earlier versions of pyCrypto (pre-2.0) don't have randomize()
- pass
+randpool = StrongLockingRandomPool()
import sys
if sys.version_info < (2, 3):
@@ -129,6 +118,7 @@ else:
import logging
PY22 = False
+
DEBUG = logging.DEBUG
INFO = logging.INFO
WARNING = logging.WARNING
diff --git a/paramiko/compress.py b/paramiko/compress.py
index bdf4b42..08fffb1 100644
--- a/paramiko/compress.py
+++ b/paramiko/compress.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
diff --git a/paramiko/config.py b/paramiko/config.py
new file mode 100644
index 0000000..1e3d680
--- /dev/null
+++ b/paramiko/config.py
@@ -0,0 +1,105 @@
+# Copyright (C) 2006-2007 Robey Pointer <robey@lag.net>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+
+"""
+L{SSHConfig}.
+"""
+
+import fnmatch
+
+
+class SSHConfig (object):
+ """
+ Representation of config information as stored in the format used by
+ OpenSSH. Queries can be made via L{lookup}. The format is described in
+ OpenSSH's C{ssh_config} man page. This class is provided primarily as a
+ convenience to posix users (since the OpenSSH format is a de-facto
+ standard on posix) but should work fine on Windows too.
+
+ @since: 1.6
+ """
+
+ def __init__(self):
+ """
+ Create a new OpenSSH config object.
+ """
+ self._config = [ { 'host': '*' } ]
+
+ def parse(self, file_obj):
+ """
+ Read an OpenSSH config from the given file object.
+
+ @param file_obj: a file-like object to read the config file from
+ @type file_obj: file
+ """
+ config = self._config[0]
+ for line in file_obj:
+ line = line.rstrip('\n').lstrip()
+ if (line == '') or (line[0] == '#'):
+ continue
+ if '=' in line:
+ key, value = line.split('=', 1)
+ key = key.strip().lower()
+ else:
+ # find first whitespace, and split there
+ i = 0
+ while (i < len(line)) and not line[i].isspace():
+ i += 1
+ if i == len(line):
+ raise Exception('Unparsable line: %r' % line)
+ key = line[:i].lower()
+ value = line[i:].lstrip()
+
+ if key == 'host':
+ # do we have a pre-existing host config to append to?
+ matches = [c for c in self._config if c['host'] == value]
+ if len(matches) > 0:
+ config = matches[0]
+ else:
+ config = { 'host': value }
+ self._config.append(config)
+ else:
+ config[key] = value
+
+ def lookup(self, hostname):
+ """
+ Return a dict of config options for a given hostname.
+
+ The host-matching rules of OpenSSH's C{ssh_config} man page are used,
+ which means that all configuration options from matching host
+ specifications are merged, with more specific hostmasks taking
+ precedence. In other words, if C{"Port"} is set under C{"Host *"}
+ and also C{"Host *.example.com"}, and the lookup is for
+ C{"ssh.example.com"}, then the port entry for C{"Host *.example.com"}
+ will win out.
+
+ The keys in the returned dict are all normalized to lowercase (look for
+ C{"port"}, not C{"Port"}). No other processing is done to the keys or
+ values.
+
+ @param hostname: the hostname to look up
+ @type hostname: str
+ """
+ matches = [x for x in self._config if fnmatch.fnmatch(hostname, x['host'])]
+ # sort in order of shortest match (usually '*') to longest
+ matches.sort(lambda x,y: cmp(len(x['host']), len(y['host'])))
+ ret = {}
+ for m in matches:
+ ret.update(m)
+ del ret['host']
+ return ret
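A short sketch of the parse/lookup cycle documented above; the config text is made up, and any file-like object works.

from StringIO import StringIO
from paramiko.config import SSHConfig

# hypothetical config: a global Port plus a more specific host mask
sample = StringIO("Host *\n    Port 22\nHost *.example.com\n    Port 2200\n    User robey\n")
config = SSHConfig()
config.parse(sample)
opts = config.lookup('ssh.example.com')
# keys come back lowercased; the longer "*.example.com" mask wins over "*"
print opts['port']       # prints 2200 (values stay strings)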
diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py
index 2b31372..9f381d2 100644
--- a/paramiko/dsskey.py
+++ b/paramiko/dsskey.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -37,7 +37,15 @@ class DSSKey (PKey):
data.
"""
- def __init__(self, msg=None, data=None, filename=None, password=None, vals=None):
+ def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None):
+ self.p = None
+ self.q = None
+ self.g = None
+ self.y = None
+ self.x = None
+ if file_obj is not None:
+ self._from_private_key(file_obj, password)
+ return
if filename is not None:
self._from_private_key_file(filename, password)
return
@@ -81,7 +89,7 @@ class DSSKey (PKey):
return self.size
def can_sign(self):
- return hasattr(self, 'x')
+ return self.x is not None
def sign_ssh_data(self, rpool, data):
digest = SHA.new(data).digest()
@@ -123,14 +131,22 @@ class DSSKey (PKey):
dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q)))
return dss.verify(sigM, (sigR, sigS))
- def write_private_key_file(self, filename, password=None):
+ def _encode_key(self):
+ if self.x is None:
+ raise SSHException('Not enough key information')
keylist = [ 0, self.p, self.q, self.g, self.y, self.x ]
try:
b = BER()
b.encode(keylist)
except BERException:
raise SSHException('Unable to create ber encoding of key')
- self._write_private_key_file('DSA', filename, str(b), password)
+ return str(b)
+
+ def write_private_key_file(self, filename, password=None):
+ self._write_private_key_file('DSA', filename, self._encode_key(), password)
+
+ def write_private_key(self, file_obj, password=None):
+ self._write_private_key('DSA', file_obj, self._encode_key(), password)
def generate(bits=1024, progress_func=None):
"""
@@ -144,8 +160,6 @@ class DSSKey (PKey):
@type progress_func: function
@return: new private key
@rtype: L{DSSKey}
-
- @since: fearow
"""
randpool.stir()
dsa = DSA.generate(bits, randpool.get_bytes, progress_func)
@@ -159,9 +173,16 @@ class DSSKey (PKey):
def _from_private_key_file(self, filename, password):
+ data = self._read_private_key_file('DSA', filename, password)
+ self._decode_key(data)
+
+ def _from_private_key(self, file_obj, password):
+ data = self._read_private_key('DSA', file_obj, password)
+ self._decode_key(data)
+
+ def _decode_key(self, data):
# private key file contains:
# DSAPrivateKey = { version = 0, p, q, g, y, x }
- data = self._read_private_key_file('DSA', filename, password)
try:
keylist = BER(data).decode()
except BERException, x:
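The file_obj and write_private_key additions above let a DSS key live entirely in memory. A hedged sketch of the round trip (key generation can take a few seconds and draws on the random pool):

from StringIO import StringIO
from paramiko import DSSKey

key = DSSKey.generate(bits=1024)
buf = StringIO()
key.write_private_key(buf, password='sekrit')      # encrypt and serialize to the buffer
buf.seek(0)
key2 = DSSKey.from_private_key(buf, password='sekrit')
assert key.get_fingerprint() == key2.get_fingerprint()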
diff --git a/paramiko/file.py b/paramiko/file.py
index c29e7c4..7db4401 100644
--- a/paramiko/file.py
+++ b/paramiko/file.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -23,15 +23,6 @@ BufferedFile.
from cStringIO import StringIO
-_FLAG_READ = 0x1
-_FLAG_WRITE = 0x2
-_FLAG_APPEND = 0x4
-_FLAG_BINARY = 0x10
-_FLAG_BUFFERED = 0x20
-_FLAG_LINE_BUFFERED = 0x40
-_FLAG_UNIVERSAL_NEWLINE = 0x80
-
-
class BufferedFile (object):
"""
Reusable base class to implement python-style file buffering around a
@@ -44,7 +35,16 @@ class BufferedFile (object):
SEEK_CUR = 1
SEEK_END = 2
+ FLAG_READ = 0x1
+ FLAG_WRITE = 0x2
+ FLAG_APPEND = 0x4
+ FLAG_BINARY = 0x10
+ FLAG_BUFFERED = 0x20
+ FLAG_LINE_BUFFERED = 0x40
+ FLAG_UNIVERSAL_NEWLINE = 0x80
+
def __init__(self):
+ self.newlines = None
self._flags = 0
self._bufsize = self._DEFAULT_BUFSIZE
self._wbuffer = StringIO()
@@ -55,6 +55,8 @@ class BufferedFile (object):
# realpos - position according the OS
# (these may be different because we buffer for line reading)
self._pos = self._realpos = 0
+ # size only matters for seekable files
+ self._size = 0
def __del__(self):
self.close()
@@ -112,16 +114,16 @@ class BufferedFile (object):
file first). If the C{size} argument is negative or omitted, read all
the remaining data in the file.
- @param size: maximum number of bytes to read.
+ @param size: maximum number of bytes to read
@type size: int
@return: data read from the file, or an empty string if EOF was
- encountered immediately.
+ encountered immediately
@rtype: str
"""
if self._closed:
raise IOError('File is closed')
- if not (self._flags & _FLAG_READ):
- raise IOError('File not open for reading')
+ if not (self._flags & self.FLAG_READ):
+ raise IOError('File is not open for reading')
if (size is None) or (size < 0):
# go for broke
result = self._rbuffer
@@ -144,8 +146,11 @@ class BufferedFile (object):
self._pos += len(result)
return result
while len(self._rbuffer) < size:
+ read_size = size - len(self._rbuffer)
+ if self._flags & self.FLAG_BUFFERED:
+ read_size = max(self._bufsize, read_size)
try:
- new_data = self._read(max(self._bufsize, size - len(self._rbuffer)))
+ new_data = self._read(read_size)
except EOFError:
new_data = None
if (new_data is None) or (len(new_data) == 0):
@@ -178,11 +183,11 @@ class BufferedFile (object):
# it's almost silly how complex this function is.
if self._closed:
raise IOError('File is closed')
- if not (self._flags & _FLAG_READ):
+ if not (self._flags & self.FLAG_READ):
raise IOError('File not open for reading')
line = self._rbuffer
while True:
- if self._at_trailing_cr and (self._flags & _FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0):
+ if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0):
# edge case: the newline may be '\r\n' and we may have read
# only the first '\r' last time.
if line[0] == '\n':
@@ -202,8 +207,8 @@ class BufferedFile (object):
return line
n = size - len(line)
else:
- n = self._DEFAULT_BUFSIZE
- if ('\n' in line) or ((self._flags & _FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)):
+ n = self._bufsize
+ if ('\n' in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)):
break
try:
new_data = self._read(n)
@@ -217,7 +222,7 @@ class BufferedFile (object):
self._realpos += len(new_data)
# find the newline
pos = line.find('\n')
- if self._flags & _FLAG_UNIVERSAL_NEWLINE:
+ if self._flags & self.FLAG_UNIVERSAL_NEWLINE:
rpos = line.find('\r')
if (rpos >= 0) and ((rpos < pos) or (pos < 0)):
pos = rpos
@@ -250,7 +255,7 @@ class BufferedFile (object):
"""
lines = []
bytes = 0
- while 1:
+ while True:
line = self.readline()
if len(line) == 0:
break
@@ -303,13 +308,13 @@ class BufferedFile (object):
"""
if self._closed:
raise IOError('File is closed')
- if not (self._flags & _FLAG_WRITE):
+ if not (self._flags & self.FLAG_WRITE):
raise IOError('File not open for writing')
- if not (self._flags & _FLAG_BUFFERED):
+ if not (self._flags & self.FLAG_BUFFERED):
self._write_all(data)
return
self._wbuffer.write(data)
- if self._flags & _FLAG_LINE_BUFFERED:
+ if self._flags & self.FLAG_LINE_BUFFERED:
# only scan the new data for linefeed, to avoid wasting time.
last_newline_pos = data.rfind('\n')
if last_newline_pos >= 0:
@@ -387,26 +392,37 @@ class BufferedFile (object):
"""
Subclasses call this method to initialize the BufferedFile.
"""
+ # set bufsize in any event, because it's used for readline().
+ self._bufsize = self._DEFAULT_BUFSIZE
+ if bufsize < 0:
+ # do no buffering by default, because otherwise writes will get
+ # buffered in a way that will probably confuse people.
+ bufsize = 0
if bufsize == 1:
# apparently, line buffering only affects writes. reads are only
# buffered if you call readline (directly or indirectly: iterating
# over a file will indirectly call readline).
- self._flags |= _FLAG_BUFFERED | _FLAG_LINE_BUFFERED
+ self._flags |= self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED
elif bufsize > 1:
self._bufsize = bufsize
- self._flags |= _FLAG_BUFFERED
+ self._flags |= self.FLAG_BUFFERED
+ self._flags &= ~self.FLAG_LINE_BUFFERED
+ elif bufsize == 0:
+ # unbuffered
+ self._flags &= ~(self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED)
+
if ('r' in mode) or ('+' in mode):
- self._flags |= _FLAG_READ
+ self._flags |= self.FLAG_READ
if ('w' in mode) or ('+' in mode):
- self._flags |= _FLAG_WRITE
+ self._flags |= self.FLAG_WRITE
if ('a' in mode):
- self._flags |= _FLAG_WRITE | _FLAG_APPEND
+ self._flags |= self.FLAG_WRITE | self.FLAG_APPEND
self._size = self._get_size()
self._pos = self._realpos = self._size
if ('b' in mode):
- self._flags |= _FLAG_BINARY
+ self._flags |= self.FLAG_BINARY
if ('U' in mode):
- self._flags |= _FLAG_UNIVERSAL_NEWLINE
+ self._flags |= self.FLAG_UNIVERSAL_NEWLINE
# built-in file objects have this attribute to store which kinds of
# line terminations they've seen:
# <http://www.python.org/doc/current/lib/built-in-funcs.html>
@@ -418,7 +434,7 @@ class BufferedFile (object):
while len(data) > 0:
count = self._write(data)
data = data[count:]
- if self._flags & _FLAG_APPEND:
+ if self._flags & self.FLAG_APPEND:
self._size += count
self._pos = self._realpos = self._size
else:
@@ -430,7 +446,7 @@ class BufferedFile (object):
# silliness about tracking what kinds of newlines we've seen.
# i don't understand why it can be None, a string, or a tuple, instead
# of just always being a tuple, but we'll emulate that behavior anyway.
- if not (self._flags & _FLAG_UNIVERSAL_NEWLINE):
+ if not (self._flags & self.FLAG_UNIVERSAL_NEWLINE):
return
if self.newlines is None:
self.newlines = newline
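BufferedFile above is an abstract base: subclasses like ChannelFile and SFTPFile supply _read/_write and call the mode-setup method (named _set_mode in file.py, just above this hunk). A toy in-memory subclass, as a sketch of that contract:

from cStringIO import StringIO
from paramiko.file import BufferedFile

class MemoryFile (BufferedFile):
    """Toy subclass: reads come from a fixed in-memory buffer."""
    def __init__(self, mode='r', bufsize=-1):
        BufferedFile.__init__(self)
        self._backing = StringIO('hello\nworld\n')
        self._set_mode(mode, bufsize)

    def _read(self, size):
        data = self._backing.read(size)
        return data or None     # None signals EOF to the base class

    def _write(self, data):
        self._backing.write(data)
        return len(data)

f = MemoryFile('r')
print f.readline()      # 'hello\n'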
diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py
new file mode 100644
index 0000000..0c0ac8c
--- /dev/null
+++ b/paramiko/hostkeys.py
@@ -0,0 +1,315 @@
+# Copyright (C) 2006-2007 Robey Pointer <robey@lag.net>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+
+"""
+L{HostKeys}
+"""
+
+import base64
+from Crypto.Hash import SHA, HMAC
+import UserDict
+
+from paramiko.common import *
+from paramiko.dsskey import DSSKey
+from paramiko.rsakey import RSAKey
+
+
+class HostKeyEntry:
+ """
+ Representation of a line in an OpenSSH-style "known hosts" file.
+ """
+
+ def __init__(self, hostnames=None, key=None):
+ self.valid = (hostnames is not None) and (key is not None)
+ self.hostnames = hostnames
+ self.key = key
+
+ def from_line(cls, line):
+ """
+ Parses the given line of text to find the names for the host,
+ the type of key, and the key data. The line is expected to be in the
+ format used by the openssh known_hosts file.
+
+ Lines are expected to not have leading or trailing whitespace.
+ We don't bother to check for comments or empty lines. All of
+ that should be taken care of before sending the line to us.
+
+ @param line: a line from an OpenSSH known_hosts file
+ @type line: str
+ """
+ fields = line.split(' ')
+ if len(fields) != 3:
+ # Bad number of fields
+ return None
+
+ names, keytype, key = fields
+ names = names.split(',')
+
+ # Decide what kind of key we're looking at and create an object
+ # to hold it accordingly.
+ if keytype == 'ssh-rsa':
+ key = RSAKey(data=base64.decodestring(key))
+ elif keytype == 'ssh-dss':
+ key = DSSKey(data=base64.decodestring(key))
+ else:
+ return None
+
+ return cls(names, key)
+ from_line = classmethod(from_line)
+
+ def to_line(self):
+ """
+ Returns a string in OpenSSH known_hosts file format, or None if
+ the object is not in a valid state. A trailing newline is
+ included.
+ """
+ if self.valid:
+ return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(),
+ self.key.get_base64())
+ return None
+
+ def __repr__(self):
+ return '<HostKeyEntry %r: %r>' % (self.hostnames, self.key)
+
+
+class HostKeys (UserDict.DictMixin):
+ """
+ Representation of an openssh-style "known hosts" file. Host keys can be
+ read from one or more files, and then individual hosts can be looked up to
+ verify server keys during SSH negotiation.
+
+ A HostKeys object can be treated like a dict; any dict lookup is equivalent
+ to calling L{lookup}.
+
+ @since: 1.5.3
+ """
+
+ def __init__(self, filename=None):
+ """
+ Create a new HostKeys object, optionally loading keys from an openssh
+ style host-key file.
+
+ @param filename: filename to load host keys from, or C{None}
+ @type filename: str
+ """
+ # emulate a dict of { hostname: { keytype: PKey } }
+ self._entries = []
+ if filename is not None:
+ self.load(filename)
+
+ def add(self, hostname, keytype, key):
+ """
+ Add a host key entry to the table. Any existing entry for a
+ C{(hostname, keytype)} pair will be replaced.
+
+ @param hostname: the hostname (or IP) to add
+ @type hostname: str
+ @param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"})
+ @type keytype: str
+ @param key: the key to add
+ @type key: L{PKey}
+ """
+ for e in self._entries:
+ if (hostname in e.hostnames) and (e.key.get_name() == keytype):
+ e.key = key
+ return
+ self._entries.append(HostKeyEntry([hostname], key))
+
+ def load(self, filename):
+ """
+ Read a file of known SSH host keys, in the format used by openssh.
+ This type of file unfortunately doesn't exist on Windows, but on
+ posix, it will usually be stored in
+ C{os.path.expanduser("~/.ssh/known_hosts")}.
+
+ If this method is called multiple times, the host keys are merged,
+ not cleared. So multiple calls to C{load} will just call L{add},
+ replacing any existing entries and adding new ones.
+
+ @param filename: name of the file to read host keys from
+ @type filename: str
+
+ @raise IOError: if there was an error reading the file
+ """
+ f = open(filename, 'r')
+ for line in f:
+ line = line.strip()
+ if (len(line) == 0) or (line[0] == '#'):
+ continue
+ e = HostKeyEntry.from_line(line)
+ if e is not None:
+ self._entries.append(e)
+ f.close()
+
+ def save(self, filename):
+ """
+ Save host keys into a file, in the format used by openssh. The order of
+ keys in the file will be preserved when possible (if these keys were
+ loaded from a file originally). The single exception is that combined
+ lines will be split into individual key lines, which is arguably a bug.
+
+ @param filename: name of the file to write
+ @type filename: str
+
+ @raise IOError: if there was an error writing the file
+
+ @since: 1.6.1
+ """
+ f = open(filename, 'w')
+ for e in self._entries:
+ line = e.to_line()
+ if line:
+ f.write(line)
+ f.close()
+
+ def lookup(self, hostname):
+ """
+ Find a hostkey entry for a given hostname or IP. If no entry is found,
+ C{None} is returned. Otherwise a dictionary of keytype to key is
+ returned. The keytype will be either C{"ssh-rsa"} or C{"ssh-dss"}.
+
+ @param hostname: the hostname (or IP) to look up
+ @type hostname: str
+ @return: keys associated with this host (or C{None})
+ @rtype: dict(str, L{PKey})
+ """
+ class SubDict (UserDict.DictMixin):
+ def __init__(self, hostname, entries, hostkeys):
+ self._hostname = hostname
+ self._entries = entries
+ self._hostkeys = hostkeys
+
+ def __getitem__(self, key):
+ for e in self._entries:
+ if e.key.get_name() == key:
+ return e.key
+ raise KeyError(key)
+
+ def __setitem__(self, key, val):
+ for e in self._entries:
+ if e.key is None:
+ continue
+ if e.key.get_name() == key:
+ # replace
+ e.key = val
+ break
+ else:
+ # add a new one
+ e = HostKeyEntry([hostname], val)
+ self._entries.append(e)
+ self._hostkeys._entries.append(e)
+
+ def keys(self):
+ return [e.key.get_name() for e in self._entries if e.key is not None]
+
+ entries = []
+ for e in self._entries:
+ for h in e.hostnames:
+ if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname):
+ entries.append(e)
+ if len(entries) == 0:
+ return None
+ return SubDict(hostname, entries, self)
+
+ def check(self, hostname, key):
+ """
+ Return True if the given key is associated with the given hostname
+ in this dictionary.
+
+ @param hostname: hostname (or IP) of the SSH server
+ @type hostname: str
+ @param key: the key to check
+ @type key: L{PKey}
+ @return: C{True} if the key is associated with the hostname; C{False}
+ if not
+ @rtype: bool
+ """
+ k = self.lookup(hostname)
+ if k is None:
+ return False
+ host_key = k.get(key.get_name(), None)
+ if host_key is None:
+ return False
+ return str(host_key) == str(key)
+
+ def clear(self):
+ """
+ Remove all host keys from the dictionary.
+ """
+ self._entries = []
+
+ def __getitem__(self, key):
+ ret = self.lookup(key)
+ if ret is None:
+ raise KeyError(key)
+ return ret
+
+ def __setitem__(self, hostname, entry):
+ # don't use this please.
+ if len(entry) == 0:
+ self._entries.append(HostKeyEntry([hostname], None))
+ return
+ for key_type in entry.keys():
+ found = False
+ for e in self._entries:
+ if (hostname in e.hostnames) and (e.key.get_name() == key_type):
+ # replace
+ e.key = entry[key_type]
+ found = True
+ if not found:
+ self._entries.append(HostKeyEntry([hostname], entry[key_type]))
+
+ def keys(self):
+ # python 2.4 sets would be nice here.
+ ret = []
+ for e in self._entries:
+ for h in e.hostnames:
+ if h not in ret:
+ ret.append(h)
+ return ret
+
+ def values(self):
+ ret = []
+ for k in self.keys():
+ ret.append(self.lookup(k))
+ return ret
+
+ def hash_host(hostname, salt=None):
+ """
+ Return a "hashed" form of the hostname, as used by openssh when storing
+ hashed hostnames in the known_hosts file.
+
+ @param hostname: the hostname to hash
+ @type hostname: str
+ @param salt: optional salt to use when hashing (must be 20 bytes long)
+ @type salt: str
+ @return: the hashed hostname
+ @rtype: str
+ """
+ if salt is None:
+ salt = randpool.get_bytes(SHA.digest_size)
+ else:
+ if salt.startswith('|1|'):
+ salt = salt.split('|')[2]
+ salt = base64.decodestring(salt)
+ assert len(salt) == SHA.digest_size
+ hmac = HMAC.HMAC(salt, hostname, SHA).digest()
+ hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac))
+ return hostkey.replace('\n', '')
+ hash_host = staticmethod(hash_host)
+
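A sketch of the dict-like HostKeys API described above. The known_hosts path below is the usual OpenSSH location and may not exist on every system, so the load is guarded.

import os
from paramiko.hostkeys import HostKeys

path = os.path.expanduser('~/.ssh/known_hosts')
hk = HostKeys()
if os.path.exists(path):
    hk.load(path)
entry = hk.lookup('ssh.example.com')     # None, or a dict-like {keytype: PKey}
if entry is not None:
    print entry.keys()                   # e.g. ['ssh-rsa']
# hashed hostnames, as OpenSSH writes them when HashKnownHosts is enabled:
print HostKeys.hash_host('ssh.example.com')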
diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py
index 994d76c..63a0c99 100644
--- a/paramiko/kex_gex.py
+++ b/paramiko/kex_gex.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -31,7 +31,8 @@ from paramiko.message import Message
from paramiko.ssh_exception import SSHException
-_MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(31, 35)
+_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \
+ _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35)
class KexGex (object):
@@ -43,19 +44,32 @@ class KexGex (object):
def __init__(self, transport):
self.transport = transport
-
- def start_kex(self):
+ self.p = None
+ self.q = None
+ self.g = None
+ self.x = None
+ self.e = None
+ self.f = None
+ self.old_style = False
+
+ def start_kex(self, _test_old_style=False):
if self.transport.server_mode:
- self.transport._expect_packet(_MSG_KEXDH_GEX_REQUEST)
+ self.transport._expect_packet(_MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD)
return
# request a bit range: we accept (min_bits) to (max_bits), but prefer
# (preferred_bits). according to the spec, we shouldn't pull the
# minimum up above 1024.
m = Message()
- m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST))
- m.add_int(self.min_bits)
- m.add_int(self.preferred_bits)
- m.add_int(self.max_bits)
+ if _test_old_style:
+ # only used for unit tests: we shouldn't ever send this
+ m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST_OLD))
+ m.add_int(self.preferred_bits)
+ self.old_style = True
+ else:
+ m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST))
+ m.add_int(self.min_bits)
+ m.add_int(self.preferred_bits)
+ m.add_int(self.max_bits)
self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_GEX_GROUP)
@@ -68,6 +82,8 @@ class KexGex (object):
return self._parse_kexdh_gex_init(m)
elif ptype == _MSG_KEXDH_GEX_REPLY:
return self._parse_kexdh_gex_reply(m)
+ elif ptype == _MSG_KEXDH_GEX_REQUEST_OLD:
+ return self._parse_kexdh_gex_request_old(m)
raise SSHException('KexGex asked to handle packet type %d' % ptype)
@@ -126,6 +142,28 @@ class KexGex (object):
self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_GEX_INIT)
+ def _parse_kexdh_gex_request_old(self, m):
+ # same as above, but without min_bits or max_bits (used by older clients like putty)
+ self.preferred_bits = m.get_int()
+ # smoosh the user's preferred size into our own limits
+ if self.preferred_bits > self.max_bits:
+ self.preferred_bits = self.max_bits
+ if self.preferred_bits < self.min_bits:
+ self.preferred_bits = self.min_bits
+ # generate prime
+ pack = self.transport._get_modulus_pack()
+ if pack is None:
+ raise SSHException('Can\'t do server-side gex with no modulus pack')
+ self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,))
+ self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits)
+ m = Message()
+ m.add_byte(chr(_MSG_KEXDH_GEX_GROUP))
+ m.add_mpint(self.p)
+ m.add_mpint(self.g)
+ self.transport._send_message(m)
+ self.transport._expect_packet(_MSG_KEXDH_GEX_INIT)
+ self.old_style = True
+
def _parse_kexdh_gex_group(self, m):
self.p = m.get_mpint()
self.g = m.get_mpint()
@@ -156,9 +194,11 @@ class KexGex (object):
hm.add(self.transport.remote_version, self.transport.local_version,
self.transport.remote_kex_init, self.transport.local_kex_init,
key)
- hm.add_int(self.min_bits)
+ if not self.old_style:
+ hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
- hm.add_int(self.max_bits)
+ if not self.old_style:
+ hm.add_int(self.max_bits)
hm.add_mpint(self.p)
hm.add_mpint(self.g)
hm.add_mpint(self.e)
@@ -189,9 +229,11 @@ class KexGex (object):
hm.add(self.transport.local_version, self.transport.remote_version,
self.transport.local_kex_init, self.transport.remote_kex_init,
host_key)
- hm.add_int(self.min_bits)
+ if not self.old_style:
+ hm.add_int(self.min_bits)
hm.add_int(self.preferred_bits)
- hm.add_int(self.max_bits)
+ if not self.old_style:
+ hm.add_int(self.max_bits)
hm.add_mpint(self.p)
hm.add_mpint(self.g)
hm.add_mpint(self.e)
diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py
index a13cf3a..843a6d8 100644
--- a/paramiko/kex_group1.py
+++ b/paramiko/kex_group1.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
diff --git a/paramiko/logging22.py b/paramiko/logging22.py
index ac11a73..9bf7656 100644
--- a/paramiko/logging22.py
+++ b/paramiko/logging22.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
diff --git a/paramiko/message.py b/paramiko/message.py
index 1d75a01..1a5151c 100644
--- a/paramiko/message.py
+++ b/paramiko/message.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -285,7 +285,7 @@ class Message (object):
elif type(i) is list:
return self.add_list(i)
else:
- raise exception('Unknown type')
+ raise Exception('Unknown type')
def add(self, *seq):
"""
diff --git a/paramiko/packet.py b/paramiko/packet.py
index 277d68e..4bde2f7 100644
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -20,12 +20,12 @@
Packetizer.
"""
+import errno
import select
import socket
import struct
import threading
import time
-from Crypto.Hash import HMAC
from paramiko.common import *
from paramiko import util
@@ -33,6 +33,19 @@ from paramiko.ssh_exception import SSHException
from paramiko.message import Message
+got_r_hmac = False
+try:
+ import r_hmac
+ got_r_hmac = True
+except ImportError:
+ pass
+def compute_hmac(key, message, digest_class):
+ if got_r_hmac:
+ return r_hmac.HMAC(key, message, digest_class).digest()
+ from Crypto.Hash import HMAC
+ return HMAC.HMAC(key, message, digest_class).digest()
+
+
class NeedRekeyException (Exception):
pass
@@ -54,6 +67,7 @@ class Packetizer (object):
self.__dump_packets = False
self.__need_rekey = False
self.__init_count = 0
+ self.__remainder = ''
# used for noticing when to re-key:
self.__sent_bytes = 0
@@ -86,13 +100,6 @@ class Packetizer (object):
self.__keepalive_last = time.time()
self.__keepalive_callback = None
- def __del__(self):
- # this is not guaranteed to be called, but we should try.
- try:
- self.__socket.close()
- except:
- pass
-
def set_log(self, log):
"""
Set the python log object to use for logging.
@@ -142,6 +149,7 @@ class Packetizer (object):
def close(self):
self.__closed = True
+ self.__socket.close()
def set_hexdump(self, hexdump):
self.__dump_packets = hexdump
@@ -186,10 +194,16 @@ class Packetizer (object):
@raise EOFError: if the socket was closed before all the bytes could
be read
"""
- if PY22:
- return self._py22_read_all(n)
out = ''
+ # handle over-reading from reading the banner line
+ if len(self.__remainder) > 0:
+ out = self.__remainder[:n]
+ self.__remainder = self.__remainder[n:]
+ n -= len(out)
+ if PY22:
+ return self._py22_read_all(n, out)
while n > 0:
+ got_timeout = False
try:
x = self.__socket.recv(n)
if len(x) == 0:
@@ -197,6 +211,21 @@ class Packetizer (object):
out += x
n -= len(x)
except socket.timeout:
+ got_timeout = True
+ except socket.error, e:
+ # on Linux, sometimes instead of socket.timeout, we get
+ # EAGAIN. this is a bug in recent (> 2.6.9) kernels but
+ # we need to work around it.
+ if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
+ got_timeout = True
+ elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
+ # syscall interrupted; try again
+ pass
+ elif self.__closed:
+ raise EOFError()
+ else:
+ raise
+ if got_timeout:
if self.__closed:
raise EOFError()
if check_rekey and (len(out) == 0) and self.__need_rekey:
@@ -207,32 +236,44 @@ class Packetizer (object):
def write_all(self, out):
self.__keepalive_last = time.time()
while len(out) > 0:
+ got_timeout = False
try:
n = self.__socket.send(out)
except socket.timeout:
- n = 0
- if self.__closed:
+ got_timeout = True
+ except socket.error, e:
+ if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
+ got_timeout = True
+ elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
+ # syscall interrupted; try again
+ pass
+ else:
n = -1
except Exception:
# could be: (32, 'Broken pipe')
n = -1
+ if got_timeout:
+ n = 0
+ if self.__closed:
+ n = -1
if n < 0:
raise EOFError()
if n == len(out):
- return
+ break
out = out[n:]
return
def readline(self, timeout):
"""
- Read a line from the socket. This is done in a fairly inefficient
- way, but is only used for initial banner negotiation so it's not worth
- optimising.
+ Read a line from the socket. We assume no data is pending after the
+ line, so it's okay to attempt large reads.
"""
buf = ''
while not '\n' in buf:
buf += self._read_timeout(timeout)
- buf = buf[:-1]
+ n = buf.index('\n')
+ self.__remainder += buf[n+1:]
+ buf = buf[:n]
if (len(buf) > 0) and (buf[-1] == '\r'):
buf = buf[:-1]
return buf
@@ -242,21 +283,21 @@ class Packetizer (object):
Write a block of data using the current cipher, as an SSH block.
"""
# encrypt this sucka
- randpool.stir()
data = str(data)
cmd = ord(data[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
- self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data)))
- if self.__compress_engine_out is not None:
- data = self.__compress_engine_out(data)
- packet = self._build_packet(data)
- if self.__dump_packets:
- self._log(DEBUG, util.format_binary(packet, 'OUT: '))
+ orig_len = len(data)
self.__write_lock.acquire()
try:
+ if self.__compress_engine_out is not None:
+ data = self.__compress_engine_out(data)
+ packet = self._build_packet(data)
+ if self.__dump_packets:
+ self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, orig_len))
+ self._log(DEBUG, util.format_binary(packet, 'OUT: '))
if self.__block_engine_out != None:
out = self.__block_engine_out.encrypt(packet)
else:
@@ -264,12 +305,15 @@ class Packetizer (object):
# + mac
if self.__block_engine_out != None:
payload = struct.pack('>I', self.__sequence_number_out) + packet
- out += HMAC.HMAC(self.__mac_key_out, payload, self.__mac_engine_out).digest()[:self.__mac_size_out]
+ out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out]
self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL
self.write_all(out)
self.__sent_bytes += len(out)
self.__sent_packets += 1
+ if (self.__sent_packets % 100) == 0:
+ # stirring the randpool takes 30ms on my ibook!!
+ randpool.stir()
if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \
and not self.__need_rekey:
# only ask once for rekeying
@@ -310,12 +354,12 @@ class Packetizer (object):
if self.__mac_size_in > 0:
mac = post_packet[:self.__mac_size_in]
mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet
- my_mac = HMAC.HMAC(self.__mac_key_in, mac_payload, self.__mac_engine_in).digest()[:self.__mac_size_in]
+ my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in]
if my_mac != mac:
raise SSHException('Mismatched MAC')
padding = ord(packet[0])
payload = packet[1:packet_size - padding]
- randpool.add_event(packet[packet_size - padding])
+ randpool.add_event()
if self.__dump_packets:
self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
@@ -348,7 +392,8 @@ class Packetizer (object):
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
- self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
+ if self.__dump_packets:
+ self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
return cmd, msg
@@ -374,8 +419,7 @@ class Packetizer (object):
self.__keepalive_callback()
self.__keepalive_last = now
- def _py22_read_all(self, n):
- out = ''
+ def _py22_read_all(self, n, out):
while n > 0:
r, w, e = select.select([self.__socket], [], [], 0.1)
if self.__socket not in r:
@@ -398,23 +442,24 @@ class Packetizer (object):
x = self.__socket.recv(1)
if len(x) == 0:
raise EOFError()
- return x
+ break
if self.__closed:
raise EOFError()
now = time.time()
if now - start >= timeout:
raise socket.timeout()
+ return x
def _read_timeout(self, timeout):
if PY22:
- return self._py22_read_timeout(n)
+ return self._py22_read_timeout(timeout)
start = time.time()
while True:
try:
- x = self.__socket.recv(1)
+ x = self.__socket.recv(128)
if len(x) == 0:
raise EOFError()
- return x
+ break
except socket.timeout:
pass
if self.__closed:
@@ -422,6 +467,7 @@ class Packetizer (object):
now = time.time()
if now - start >= timeout:
raise socket.timeout()
+ return x
def _build_packet(self, payload):
# pad up at least 4 bytes, to nearest block-size (usually 8)
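compute_hmac above is what the Packetizer uses for per-packet MACs: the digest covers the 32-bit outgoing sequence number followed by the plaintext packet. An illustrative sketch with made-up key and packet bytes:

import struct
from Crypto.Hash import SHA
from paramiko.packet import compute_hmac

key = 'k' * 20                              # placeholder MAC key
packet = '\x05hello world, plus padding.'   # placeholder packet bytes
seq = 3
mac = compute_hmac(key, struct.pack('>I', seq) + packet, SHA)
print len(mac)      # 20: full SHA-1 digest; the Packetizer truncates it to the negotiated MAC size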
diff --git a/paramiko/pipe.py b/paramiko/pipe.py
index cc28f43..1cfed2d 100644
--- a/paramiko/pipe.py
+++ b/paramiko/pipe.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -19,6 +19,9 @@
"""
Abstraction of a one-way pipe where the read end can be used in select().
Normally this is trivial, but Windows makes it nearly impossible.
+
+The pipe acts like an Event, which can be set or cleared. When set, the pipe
+will trigger as readable in select().
"""
import sys
@@ -28,8 +31,10 @@ import socket
def make_pipe ():
if sys.platform[:3] != 'win':
- return PosixPipe()
- return WindowsPipe()
+ p = PosixPipe()
+ else:
+ p = WindowsPipe()
+ return p
class PosixPipe (object):
@@ -37,10 +42,13 @@ class PosixPipe (object):
self._rfd, self._wfd = os.pipe()
self._set = False
self._forever = False
+ self._closed = False
def close (self):
os.close(self._rfd)
os.close(self._wfd)
+ # used for unit tests:
+ self._closed = True
def fileno (self):
return self._rfd
@@ -52,7 +60,7 @@ class PosixPipe (object):
self._set = False
def set (self):
- if self._set:
+ if self._set or self._closed:
return
self._set = True
os.write(self._wfd, '*')
@@ -80,10 +88,13 @@ class WindowsPipe (object):
serv.close()
self._set = False
self._forever = False
+ self._closed = False
def close (self):
self._rsock.close()
self._wsock.close()
+ # used for unit tests:
+ self._closed = True
def fileno (self):
return self._rsock.fileno()
@@ -95,7 +106,7 @@ class WindowsPipe (object):
self._set = False
def set (self):
- if self._set:
+ if self._set or self._closed:
return
self._set = True
self._wsock.send('*')
@@ -103,3 +114,34 @@ class WindowsPipe (object):
def set_forever (self):
self._forever = True
self.set()
+
+
+class OrPipe (object):
+ def __init__(self, pipe):
+ self._set = False
+ self._partner = None
+ self._pipe = pipe
+
+ def set(self):
+ self._set = True
+ if not self._partner._set:
+ self._pipe.set()
+
+ def clear(self):
+ self._set = False
+ if not self._partner._set:
+ self._pipe.clear()
+
+
+def make_or_pipe(pipe):
+ """
+ wraps a pipe into two pipe-like objects which are "or"d together to
+ affect the real pipe. if either returned pipe is set, the wrapped pipe
+ is set. when both are cleared, the wrapped pipe is cleared.
+ """
+ p1 = OrPipe(pipe)
+ p2 = OrPipe(pipe)
+ p1._partner = p2
+ p2._partner = p1
+ return p1, p2
+
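A small sketch of the event-like pipes described above: the read end can sit in select(), and make_or_pipe hands back two handles that are OR'd onto one underlying pipe.

import select
from paramiko import pipe

p = pipe.make_pipe()
p1, p2 = pipe.make_or_pipe(p)
p1.set()
r, w, e = select.select([p.fileno()], [], [], 0)
print bool(r)        # True: setting either half makes the pipe readable
p1.clear()
p2.clear()           # only when both halves are clear does the pipe go quiet
p.close()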
diff --git a/paramiko/pkey.py b/paramiko/pkey.py
index 75db8e5..4e8b26b 100644
--- a/paramiko/pkey.py
+++ b/paramiko/pkey.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -21,6 +21,7 @@ Common API for all public keys.
"""
import base64
+from binascii import hexlify, unhexlify
import os
from Crypto.Hash import MD5
@@ -139,8 +140,6 @@ class PKey (object):
@return: a base64 string containing the public part of the key.
@rtype: str
-
- @since: fearow
"""
return base64.encodestring(str(self)).replace('\n', '')
@@ -173,7 +172,7 @@ class PKey (object):
"""
return False
- def from_private_key_file(cl, filename, password=None):
+ def from_private_key_file(cls, filename, password=None):
"""
Create a key object by reading a private key file. If the private
key is encrypted and C{password} is not C{None}, the given password
@@ -182,41 +181,76 @@ class PKey (object):
exist in all subclasses of PKey (such as L{RSAKey} or L{DSSKey}), but
is useless on the abstract PKey class.
- @param filename: name of the file to read.
+ @param filename: name of the file to read
@type filename: str
@param password: an optional password to use to decrypt the key file,
if it's encrypted
@type password: str
- @return: a new key object based on the given private key.
+ @return: a new key object based on the given private key
@rtype: L{PKey}
- @raise IOError: if there was an error reading the file.
+ @raise IOError: if there was an error reading the file
@raise PasswordRequiredException: if the private key file is
- encrypted, and C{password} is C{None}.
- @raise SSHException: if the key file is invalid.
-
- @since: fearow
+ encrypted, and C{password} is C{None}
+ @raise SSHException: if the key file is invalid
"""
- key = cl(filename=filename, password=password)
+ key = cls(filename=filename, password=password)
return key
from_private_key_file = classmethod(from_private_key_file)
+ def from_private_key(cls, file_obj, password=None):
+ """
+ Create a key object by reading a private key from a file (or file-like)
+ object. If the private key is encrypted and C{password} is not C{None},
+ the given password will be used to decrypt the key (otherwise
+ L{PasswordRequiredException} is thrown).
+
+ @param file_obj: the file to read from
+ @type file_obj: file
+ @param password: an optional password to use to decrypt the key, if it's
+ encrypted
+ @type password: str
+ @return: a new key object based on the given private key
+ @rtype: L{PKey}
+
+ @raise IOError: if there was an error reading the key
+ @raise PasswordRequiredException: if the private key file is encrypted,
+ and C{password} is C{None}
+ @raise SSHException: if the key file is invalid
+ """
+ key = cls(file_obj=file_obj, password=password)
+ return key
+ from_private_key = classmethod(from_private_key)
+
def write_private_key_file(self, filename, password=None):
"""
Write private key contents into a file. If the password is not
C{None}, the key is encrypted before writing.
- @param filename: name of the file to write.
+ @param filename: name of the file to write
@type filename: str
- @param password: an optional password to use to encrypt the key file.
+ @param password: an optional password to use to encrypt the key file
@type password: str
- @raise IOError: if there was an error writing the file.
- @raise SSHException: if the key is invalid.
-
- @since: fearow
+ @raise IOError: if there was an error writing the file
+ @raise SSHException: if the key is invalid
+ """
+ raise Exception('Not implemented in PKey')
+
+ def write_private_key(self, file_obj, password=None):
"""
- raise exception('Not implemented in PKey')
+ Write private key contents into a file (or file-like) object. If the
+ password is not C{None}, the key is encrypted before writing.
+
+ @param file_obj: the file object to write into
+ @type file_obj: file
+ @param password: an optional password to use to encrypt the key
+ @type password: str
+
+ @raise IOError: if there was an error writing to the file
+ @raise SSHException: if the key is invalid
+ """
+ raise Exception('Not implemented in PKey')
def _read_private_key_file(self, tag, filename, password=None):
"""
@@ -242,8 +276,12 @@ class PKey (object):
@raise SSHException: if the key file is invalid.
"""
f = open(filename, 'r')
- lines = f.readlines()
+ data = self._read_private_key(tag, f, password)
f.close()
+ return data
+
+ def _read_private_key(self, tag, f, password=None):
+ lines = f.readlines()
start = 0
while (start < len(lines)) and (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----'):
start += 1
@@ -265,9 +303,9 @@ class PKey (object):
# if we trudged to the end of the file, just try to cope.
try:
data = base64.decodestring(''.join(lines[start:end]))
- except binascii.Error, e:
+ except base64.binascii.Error, e:
raise SSHException('base64 decoding error: ' + str(e))
- if not headers.has_key('proc-type'):
+ if 'proc-type' not in headers:
# unencrypted: done
return data
# encrypted keyfile: will need a password
@@ -277,7 +315,7 @@ class PKey (object):
encryption_type, saltstr = headers['dek-info'].split(',')
except:
raise SSHException('Can\'t parse DEK-info in private key file')
- if not self._CIPHER_TABLE.has_key(encryption_type):
+ if encryption_type not in self._CIPHER_TABLE:
raise SSHException('Unknown private key cipher "%s"' % encryption_type)
# if no password was passed in, raise an exception pointing out that we need one
if password is None:
@@ -285,7 +323,7 @@ class PKey (object):
cipher = self._CIPHER_TABLE[encryption_type]['cipher']
keysize = self._CIPHER_TABLE[encryption_type]['keysize']
mode = self._CIPHER_TABLE[encryption_type]['mode']
- salt = util.unhexify(saltstr)
+ salt = unhexlify(saltstr)
key = util.generate_key_bytes(MD5, salt, password, keysize)
return cipher.new(key, mode, salt).decrypt(data)
@@ -310,6 +348,10 @@ class PKey (object):
f = open(filename, 'w', 0600)
# grrr... the mode doesn't always take hold
os.chmod(filename, 0600)
+ self._write_private_key(tag, f, data, password)
+ f.close()
+
+ def _write_private_key(self, tag, f, data, password=None):
f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag)
if password is not None:
# since we only support one cipher here, use it
@@ -327,7 +369,7 @@ class PKey (object):
data += '\0' * n
data = cipher.new(key, mode, salt).encrypt(data)
f.write('Proc-Type: 4,ENCRYPTED\n')
- f.write('DEK-Info: %s,%s\n' % (cipher_name, util.hexify(salt)))
+ f.write('DEK-Info: %s,%s\n' % (cipher_name, hexlify(salt).upper()))
f.write('\n')
s = base64.encodestring(data)
# re-wrap to 64-char lines
@@ -336,4 +378,3 @@ class PKey (object):
f.write(s)
f.write('\n')
f.write('-----END %s PRIVATE KEY-----\n' % tag)
- f.close()
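The password handling documented above means an encrypted key file raises PasswordRequiredException until a password is supplied. A sketch; the filename is a placeholder for an encrypted PEM key on disk:

from paramiko import RSAKey
from paramiko.ssh_exception import PasswordRequiredException

try:
    key = RSAKey.from_private_key_file('id_rsa_encrypted')
except PasswordRequiredException:
    # the key is encrypted; retry with the passphrase
    key = RSAKey.from_private_key_file('id_rsa_encrypted', password='sekrit')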
diff --git a/paramiko/primes.py b/paramiko/primes.py
index 3677394..7b35736 100644
--- a/paramiko/primes.py
+++ b/paramiko/primes.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -23,6 +23,7 @@ Utility functions for dealing with primes.
from Crypto.Util import number
from paramiko import util
+from paramiko.ssh_exception import SSHException
def _generate_prime(bits, randpool):
@@ -39,7 +40,8 @@ def _generate_prime(bits, randpool):
while not number.isPrime(n):
n += 2
if util.bit_length(n) == bits:
- return n
+ break
+ return n
def _roll_random(rpool, n):
"returns a random # from 0 to N-1"
@@ -59,7 +61,8 @@ def _roll_random(rpool, n):
x = chr(ord(x[0]) & hbyte_mask) + x[1:]
num = util.inflate_long(x, 1)
if num < n:
- return num
+ break
+ return num
class ModulusPack (object):
@@ -75,8 +78,8 @@ class ModulusPack (object):
self.randpool = rpool
def _parse_modulus(self, line):
- timestamp, type, tests, tries, size, generator, modulus = line.split()
- type = int(type)
+ timestamp, mod_type, tests, tries, size, generator, modulus = line.split()
+ mod_type = int(mod_type)
tests = int(tests)
tries = int(tries)
size = int(size)
@@ -87,7 +90,7 @@ class ModulusPack (object):
# type 2 (meets basic structural requirements)
# test 4 (more than just a small-prime sieve)
# tries < 100 if test & 4 (at least 100 tries of miller-rabin)
- if (type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)):
+ if (mod_type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)):
self.discarded.append((modulus, 'does not meet basic requirements'))
return
if generator == 0:
@@ -100,7 +103,7 @@ class ModulusPack (object):
if (bl != size) and (bl != size + 1):
self.discarded.append((modulus, 'incorrectly reported bit length %d' % size))
return
- if not self.pack.has_key(bl):
+ if bl not in self.pack:
self.pack[bl] = []
self.pack[bl].append((generator, modulus))
diff --git a/paramiko/resource.py b/paramiko/resource.py
new file mode 100644
index 0000000..a089754
--- /dev/null
+++ b/paramiko/resource.py
@@ -0,0 +1,72 @@
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+
+"""
+Resource manager.
+"""
+
+import weakref
+
+
+class ResourceManager (object):
+ """
+ A registry of objects and resources that should be closed when those
+ objects are deleted.
+
+ This is meant to be a safer alternative to python's C{__del__} method,
+ which can cause reference cycles to never be collected. Objects registered
+ with the ResourceManager can be collected but still free resources when
+ they die.
+
+ Resources are registered using L{register}, and when an object is garbage
+ collected, each registered resource is closed by having its C{close()}
+ method called. Multiple resources may be registered per object, but a
+ resource will only be closed once, even if multiple objects register it.
+ (The last object to register it wins.)
+ """
+
+ def __init__(self):
+ self._table = {}
+
+ def register(self, obj, resource):
+ """
+ Register a resource to be closed when an object is collected.
+
+ When the given C{obj} is garbage-collected by the python interpreter,
+ the C{resource} will be closed by having its C{close()} method called.
+ Any exceptions are ignored.
+
+ @param obj: the object to track
+ @type obj: object
+ @param resource: the resource to close when the object is collected
+ @type resource: object
+ """
+ def callback(ref):
+ try:
+ resource.close()
+ except:
+ pass
+ del self._table[id(resource)]
+
+ # keep the weakref in a table so it sticks around long enough to get
+ # its callback called. :)
+ self._table[id(resource)] = weakref.ref(obj, callback)
+
+
+# singleton
+ResourceManager = ResourceManager()
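A minimal sketch of the registry above: the resource's close() runs once the tracked object is collected (immediately on CPython, thanks to reference counting).

from paramiko.resource import ResourceManager

class Resource (object):
    def close(self):
        print 'resource closed'

class Owner (object):
    pass

owner = Owner()
ResourceManager.register(owner, Resource())
del owner        # the weakref callback fires and prints 'resource closed'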
diff --git a/paramiko/rng.py b/paramiko/rng.py
new file mode 100644
index 0000000..46329d1
--- /dev/null
+++ b/paramiko/rng.py
@@ -0,0 +1,112 @@
+#!/usr/bin/python
+# -*- coding: ascii -*-
+# Copyright (C) 2008 Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+import sys
+import threading
+from Crypto.Util.randpool import RandomPool as _RandomPool
+
+try:
+ import platform
+except ImportError:
+ platform = None # Not available using Python 2.2
+
+def _strxor(a, b):
+ assert len(a) == len(b)
+ return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), a, b))
+
+##
+## Find a strong random entropy source, depending on the detected platform.
+## WARNING TO DEVELOPERS: This will fail on some systems, but do NOT use
+## Crypto.Util.randpool.RandomPool as a fall-back. RandomPool will happily run
+## with very little entropy, thus _silently_ defeating any security that
+## Paramiko attempts to provide. (This is current as of PyCrypto 2.0.1).
+## See http://www.lag.net/pipermail/paramiko/2008-January/000599.html
+## and http://www.lag.net/pipermail/paramiko/2008-April/000678.html
+##
+
+if ((platform is not None and platform.system().lower() == 'windows') or
+ sys.platform == 'win32'):
+ # MS Windows
+ from paramiko import rng_win32
+ rng_device = rng_win32.open_rng_device()
+else:
+ # Assume POSIX (any system where /dev/urandom exists)
+ from paramiko import rng_posix
+ rng_device = rng_posix.open_rng_device()
+
+
+class StrongLockingRandomPool(object):
+ """Wrapper around RandomPool guaranteeing strong random numbers.
+
+ Crypto.Util.randpool.RandomPool will silently operate even if it is seeded
+ with little or no entropy, and it provides no prediction resistance if its
+ state is ever compromised throughout its runtime. It is also not thread-safe.
+
+ This wrapper augments RandomPool by XORing its output with random bits from
+ the operating system, and by controlling access to the underlying
+ RandomPool using an exclusive lock.
+ """
+
+ def __init__(self, instance=None):
+ if instance is None:
+ instance = _RandomPool()
+ self.randpool = instance
+ self.randpool_lock = threading.Lock()
+ self.entropy = rng_device
+
+ # Stir 256 bits of entropy from the RNG device into the RandomPool.
+ self.randpool.stir(self.entropy.read(32))
+ self.entropy.randomize()
+
+ def stir(self, s=''):
+ self.randpool_lock.acquire()
+ try:
+ self.randpool.stir(s)
+ finally:
+ self.randpool_lock.release()
+ self.entropy.randomize()
+
+ def randomize(self, N=0):
+ self.randpool_lock.acquire()
+ try:
+ self.randpool.randomize(N)
+ finally:
+ self.randpool_lock.release()
+ self.entropy.randomize()
+
+ def add_event(self, s=''):
+ self.randpool_lock.acquire()
+ try:
+ self.randpool.add_event(s)
+ finally:
+ self.randpool_lock.release()
+
+ def get_bytes(self, N):
+ self.randpool_lock.acquire()
+ try:
+ randpool_data = self.randpool.get_bytes(N)
+ finally:
+ self.randpool_lock.release()
+ entropy_data = self.entropy.read(N)
+ result = _strxor(randpool_data, entropy_data)
+ assert len(randpool_data) == N and len(entropy_data) == N and len(result) == N
+ return result
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
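
The XOR-combining idea is easy to see in isolation: the pool's output is mixed with fresh bytes from the OS RNG, so the result is at least as unpredictable as the stronger of the two sources. A minimal sketch, assuming PyCrypto is installed as the module above requires:

    from paramiko.rng import StrongLockingRandomPool

    pool = StrongLockingRandomPool()    # seeds the RandomPool with 256 bits from the OS RNG
    key_material = pool.get_bytes(32)   # RandomPool output XORed with 32 fresh OS-RNG bytes
    assert len(key_material) == 32
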
diff --git a/paramiko/rng_posix.py b/paramiko/rng_posix.py
new file mode 100644
index 0000000..1e6d72c
--- /dev/null
+++ b/paramiko/rng_posix.py
@@ -0,0 +1,97 @@
+#!/usr/bin/python
+# -*- coding: ascii -*-
+# Copyright (C) 2008 Dwayne C. Litzenberger <dlitz@dlitz.net>
+# Copyright (C) 2008 Open Systems Canada Limited
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+import os
+import stat
+
+class error(Exception):
+ pass
+
+class _RNG(object):
+ def __init__(self, file):
+ self.file = file
+
+ def read(self, bytes):
+ return self.file.read(bytes)
+
+ def close(self):
+ return self.file.close()
+
+ def randomize(self):
+ return
+
+def open_rng_device(device_path=None):
+ """Open /dev/urandom and perform some sanity checks."""
+
+ f = None
+ g = None
+
+ if device_path is None:
+ device_path = "/dev/urandom"
+
+ try:
+ # Try to open /dev/urandom now so that paramiko will be able to access
+ # it even if os.chroot() is invoked later.
+ try:
+ f = open(device_path, "rb", 0)
+ except EnvironmentError:
+ raise error("Unable to open /dev/urandom")
+
+ # Open a second file descriptor for sanity checking later.
+ try:
+ g = open(device_path, "rb", 0)
+ except EnvironmentError:
+ raise error("Unable to open /dev/urandom")
+
+ # Check that /dev/urandom is a character special device, not a regular file.
+ st = os.fstat(f.fileno()) # f
+ if stat.S_ISREG(st.st_mode) or not stat.S_ISCHR(st.st_mode):
+ raise error("/dev/urandom is not a character special device")
+
+ st = os.fstat(g.fileno()) # g
+ if stat.S_ISREG(st.st_mode) or not stat.S_ISCHR(st.st_mode):
+ raise error("/dev/urandom is not a character special device")
+
+ # Check that /dev/urandom always returns the number of bytes requested
+ x = f.read(20)
+ y = g.read(20)
+ if len(x) != 20 or len(y) != 20:
+ raise error("Error reading from /dev/urandom: input truncated")
+
+ # Check that different reads return different data
+ if x == y:
+ raise error("/dev/urandom is broken; returning identical data: %r == %r" % (x, y))
+
+ # Close the duplicate file object
+ g.close()
+
+ # Return the first file object
+ return _RNG(f)
+
+ except error:
+ if f is not None:
+ f.close()
+ if g is not None:
+ g.close()
+ raise
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
+
diff --git a/paramiko/rng_win32.py b/paramiko/rng_win32.py
new file mode 100644
index 0000000..3cb8b84
--- /dev/null
+++ b/paramiko/rng_win32.py
@@ -0,0 +1,121 @@
+#!/usr/bin/python
+# -*- coding: ascii -*-
+# Copyright (C) 2008 Dwayne C. Litzenberger <dlitz@dlitz.net>
+# Copyright (C) 2008 Open Systems Canada Limited
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+class error(Exception):
+ pass
+
+# Try to import the "winrandom" module
+try:
+ from Crypto.Util import winrandom as _winrandom
+except ImportError:
+ _winrandom = None
+
+# Try to import the "urandom" module
+try:
+ from os import urandom as _urandom
+except ImportError:
+ _urandom = None
+
+
+class _RNG(object):
+ def __init__(self, readfunc):
+ self.read = readfunc
+
+ def randomize(self):
+ # According to "Cryptanalysis of the Random Number Generator of the
+ # Windows Operating System", by Leo Dorrendorf and Zvi Gutterman
+ # and Benny Pinkas <http://eprint.iacr.org/2007/419>,
+ # CryptGenRandom only updates its internal state using kernel-provided
+ # random data every 128KiB of output.
+ self.read(128*1024) # discard 128 KiB of output
+
+def _open_winrandom():
+ if _winrandom is None:
+ raise error("Crypto.Util.winrandom module not found")
+
+ # Check that we can open the winrandom module
+ try:
+ r0 = _winrandom.new()
+ r1 = _winrandom.new()
+ except Exception, exc:
+ raise error("winrandom.new() failed: %s" % str(exc), exc)
+
+ # Check that we can read from the winrandom module
+ try:
+ x = r0.get_bytes(20)
+ y = r1.get_bytes(20)
+ except Exception, exc:
+ raise error("winrandom get_bytes failed: %s" % str(exc), exc)
+
+ # Check that the requested number of bytes are returned
+ if len(x) != 20 or len(y) != 20:
+ raise error("Error reading from winrandom: input truncated")
+
+ # Check that different reads return different data
+ if x == y:
+ raise error("winrandom broken: returning identical data")
+
+ return _RNG(r0.get_bytes)
+
+def _open_urandom():
+ if _urandom is None:
+ raise error("os.urandom function not found")
+
+ # Check that we can read from os.urandom()
+ try:
+ x = _urandom(20)
+ y = _urandom(20)
+ except Exception, exc:
+ raise error("os.urandom failed: %s" % str(exc), exc)
+
+ # Check that the requested number of bytes are returned
+ if len(x) != 20 or len(y) != 20:
+ raise error("os.urandom failed: input truncated")
+
+ # Check that different reads return different data
+ if x == y:
+ raise error("os.urandom failed: returning identical data")
+
+ return _RNG(_urandom)
+
+def open_rng_device():
+ # Try using the Crypto.Util.winrandom module
+ try:
+ return _open_winrandom()
+ except error:
+ pass
+
+ # Several versions of PyCrypto do not contain the winrandom module, but
+ # Python >= 2.4 has os.urandom, so try to use that.
+ try:
+ return _open_urandom()
+ except error:
+ pass
+
+ # SECURITY NOTE: DO NOT USE Crypto.Util.randpool.RandomPool HERE!
+ # If we got to this point, RandomPool will silently run with very little
+ # entropy. (This is current as of PyCrypto 2.0.1).
+ # See http://www.lag.net/pipermail/paramiko/2008-January/000599.html
+ # and http://www.lag.net/pipermail/paramiko/2008-April/000678.html
+
+ raise error("Unable to find a strong random entropy source. You cannot run this software securely under the current configuration.")
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
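
From the caller's side the fallback chain is invisible: open_rng_device() either returns a working RNG object or raises the module's error. A short sketch, meaningful mainly on Windows where this module is selected:

    from paramiko import rng_win32

    try:
        rng = rng_win32.open_rng_device()   # winrandom if available, else os.urandom
    except rng_win32.error:
        raise SystemExit('no strong entropy source available')
    sixteen_random_bytes = rng.read(16)
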
diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py
index 780ea1b..d72d175 100644
--- a/paramiko/rsakey.py
+++ b/paramiko/rsakey.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -38,7 +38,15 @@ class RSAKey (PKey):
data.
"""
- def __init__(self, msg=None, data=None, filename=None, password=None, vals=None):
+ def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None):
+ self.n = None
+ self.e = None
+ self.d = None
+ self.p = None
+ self.q = None
+ if file_obj is not None:
+ self._from_private_key(file_obj, password)
+ return
if filename is not None:
self._from_private_key_file(filename, password)
return
@@ -75,7 +83,7 @@ class RSAKey (PKey):
return self.size
def can_sign(self):
- return hasattr(self, 'd')
+ return self.d is not None
def sign_ssh_data(self, rpool, data):
digest = SHA.new(data).digest()
@@ -93,11 +101,13 @@ class RSAKey (PKey):
# verify the signature by SHA'ing the data and encrypting it using the
# public key. some wackiness ensues where we "pkcs1imify" the 20-byte
# hash into a string as long as the RSA key.
- hash = util.inflate_long(self._pkcs1imify(SHA.new(data).digest()), True)
+ hash_obj = util.inflate_long(self._pkcs1imify(SHA.new(data).digest()), True)
rsa = RSA.construct((long(self.n), long(self.e)))
- return rsa.verify(hash, (sig,))
+ return rsa.verify(hash_obj, (sig,))
- def write_private_key_file(self, filename, password=None):
+ def _encode_key(self):
+ if (self.p is None) or (self.q is None):
+ raise SSHException('Not enough key info to write private key file')
keylist = [ 0, self.n, self.e, self.d, self.p, self.q,
self.d % (self.p - 1), self.d % (self.q - 1),
util.mod_inverse(self.q, self.p) ]
@@ -106,7 +116,13 @@ class RSAKey (PKey):
b.encode(keylist)
except BERException:
raise SSHException('Unable to create ber encoding of key')
- self._write_private_key_file('RSA', filename, str(b), password)
+ return str(b)
+
+ def write_private_key_file(self, filename, password=None):
+ self._write_private_key_file('RSA', filename, self._encode_key(), password)
+
+ def write_private_key(self, file_obj, password=None):
+ self._write_private_key('RSA', file_obj, self._encode_key(), password)
def generate(bits, progress_func=None):
"""
@@ -120,8 +136,6 @@ class RSAKey (PKey):
@type progress_func: function
@return: new private key
@rtype: L{RSAKey}
-
- @since: fearow
"""
randpool.stir()
rsa = RSA.generate(bits, randpool.get_bytes, progress_func)
@@ -147,9 +161,16 @@ class RSAKey (PKey):
return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data
def _from_private_key_file(self, filename, password):
+ data = self._read_private_key_file('RSA', filename, password)
+ self._decode_key(data)
+
+ def _from_private_key(self, file_obj, password):
+ data = self._read_private_key('RSA', file_obj, password)
+ self._decode_key(data)
+
+ def _decode_key(self, data):
# private key file contains:
# RSAPrivateKey = { version = 0, n, e, d, p, q, d mod p-1, d mod q-1, q**-1 mod p }
- data = self._read_private_key_file('RSA', filename, password)
try:
keylist = BER(data).decode()
except BERException:
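
The new file_obj plumbing means a key never has to touch disk. A round-trip sketch through an in-memory file (the passphrase is illustrative):

    from StringIO import StringIO
    from paramiko import RSAKey

    key = RSAKey.generate(1024)                       # or load an existing key
    buf = StringIO()
    key.write_private_key(buf, password='secret')     # new: write to any file-like object
    buf.seek(0)
    key2 = RSAKey(file_obj=buf, password='secret')    # new: read from any file-like object
    assert key.get_fingerprint() == key2.get_fingerprint()
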
diff --git a/paramiko/server.py b/paramiko/server.py
index a0e3988..bcaa4be 100644
--- a/paramiko/server.py
+++ b/paramiko/server.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -41,6 +41,8 @@ class InteractiveQuery (object):
@type name: str
@param instructions: user instructions (usually short) about this query
@type instructions: str
+ @param prompts: one or more authentication prompts
+ @type prompts: str
"""
self.name = name
self.instructions = instructions
@@ -90,6 +92,7 @@ class ServerInterface (object):
- L{check_channel_shell_request}
- L{check_channel_subsystem_request}
- L{check_channel_window_change_request}
+ - L{check_channel_x11_request}
The C{chanid} parameter is a small number that uniquely identifies the
channel within a L{Transport}. A L{Channel} object is not created
@@ -273,6 +276,42 @@ class ServerInterface (object):
"""
return AUTH_FAILED
+ def check_port_forward_request(self, address, port):
+ """
+ Handle a request for port forwarding. The client is asking that
+ connections to the given address and port be forwarded back across
+ this ssh connection. An address of C{"0.0.0.0"} indicates a global
+ address (any address associated with this server) and a port of C{0}
+ indicates that no specific port is requested (usually the OS will pick
+ a port).
+
+ The default implementation always returns C{False}, rejecting the
+ port forwarding request. If the request is accepted, you should return
+ the port opened for listening.
+
+ @param address: the requested address
+ @type address: str
+ @param port: the requested port
+ @type port: int
+ @return: the port number that was opened for listening, or C{False} to
+ reject
+ @rtype: int
+ """
+ return False
+
+ def cancel_port_forward_request(self, address, port):
+ """
+ The client would like to cancel a previous port-forwarding request.
+ If the given address and port is being forwarded across this ssh
+ connection, the port should be closed.
+
+ @param address: the forwarded address
+ @type address: str
+ @param port: the forwarded port
+ @type port: int
+ """
+ pass
+
def check_global_request(self, kind, msg):
"""
Handle a global request of the given C{kind}. This method is called
@@ -291,6 +330,9 @@ class ServerInterface (object):
The default implementation always returns C{False}, indicating that it
does not support any global requests.
+
+ @note: Port forwarding requests are handled separately, in
+ L{check_port_forward_request}.
@param kind: the kind of global request being made.
@type kind: str
@@ -426,6 +468,71 @@ class ServerInterface (object):
@rtype: bool
"""
return False
+
+ def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number):
+ """
+ Determine if the client will be provided with an X11 session. If this
+ method returns C{True}, X11 applications should be routed through new
+ SSH channels, using L{Transport.open_x11_channel}.
+
+ The default implementation always returns C{False}.
+
+ @param channel: the L{Channel} the X11 request arrived on
+ @type channel: L{Channel}
+ @param single_connection: C{True} if only a single X11 channel should
+ be opened
+ @type single_connection: bool
+ @param auth_protocol: the protocol used for X11 authentication
+ @type auth_protocol: str
+ @param auth_cookie: the cookie used to authenticate to X11
+ @type auth_cookie: str
+ @param screen_number: the number of the X11 screen to connect to
+ @type screen_number: int
+ @return: C{True} if the X11 session was opened; C{False} if not
+ @rtype: bool
+ """
+ return False
+
+ def check_channel_direct_tcpip_request(self, chanid, origin, destination):
+ """
+ Determine if a local port forwarding channel will be granted, and
+ return C{OPEN_SUCCEEDED} or an error code. This method is
+ called in server mode when the client requests a channel, after
+ authentication is complete.
+
+ The C{chanid} parameter is a small number that uniquely identifies the
+ channel within a L{Transport}. A L{Channel} object is not created
+ unless this method returns C{OPEN_SUCCEEDED} -- once a
+ L{Channel} object is created, you can call L{Channel.get_id} to
+ retrieve the channel ID.
+
+ The origin and destination parameters are (ip_address, port) tuples
+ that correspond to both ends of the TCP connection in the forwarding
+ tunnel.
+
+ The return value should either be C{OPEN_SUCCEEDED} (or
+ C{0}) to allow the channel request, or one of the following error
+ codes to reject it:
+ - C{OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED}
+ - C{OPEN_FAILED_CONNECT_FAILED}
+ - C{OPEN_FAILED_UNKNOWN_CHANNEL_TYPE}
+ - C{OPEN_FAILED_RESOURCE_SHORTAGE}
+
+ The default implementation always returns
+ C{OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED}.
+
+ @param chanid: ID of the channel
+ @type chanid: int
+ @param origin: 2-tuple containing the IP address and port of the
+ originator (client side)
+ @type origin: tuple
+ @param destination: 2-tuple containing the IP address and port of the
+ destination (server side)
+ @type destination: tuple
+ @return: a success or failure code (listed above)
+ @rtype: int
+ """
+ return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
class SubsystemHandler (threading.Thread):
@@ -443,8 +550,6 @@ class SubsystemHandler (threading.Thread):
         authenticated and requests subsystem C{"mp3"}, an object of class
C{MP3Handler} will be created, and L{start_subsystem} will be called on
it from a new thread.
-
- @since: ivysaur
"""
def __init__(self, channel, name, server):
"""
diff --git a/paramiko/sftp.py b/paramiko/sftp.py
index 58d7103..2296d85 100644
--- a/paramiko/sftp.py
+++ b/paramiko/sftp.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -16,6 +16,7 @@
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+import select
import socket
import struct
@@ -113,24 +114,22 @@ class BaseSFTP (object):
return version
def _send_server_version(self):
+ # winscp will freak out if the server sends version info before the
+ # client finishes sending INIT.
+ t, data = self._read_packet()
+ if t != CMD_INIT:
+ raise SFTPError('Incompatible sftp protocol')
+ version = struct.unpack('>I', data[:4])[0]
# advertise that we support "check-file"
extension_pairs = [ 'check-file', 'md5,sha1' ]
msg = Message()
msg.add_int(_VERSION)
msg.add(*extension_pairs)
self._send_packet(CMD_VERSION, str(msg))
- t, data = self._read_packet()
- if t != CMD_INIT:
- raise SFTPError('Incompatible sftp protocol')
- version = struct.unpack('>I', data[:4])[0]
return version
- def _log(self, level, msg):
- if issubclass(type(msg), list):
- for m in msg:
- self.logger.log(level, m)
- else:
- self.logger.log(level, msg)
+ def _log(self, level, msg, *args):
+ self.logger.log(level, msg, *args)
def _write_all(self, out):
while len(out) > 0:
@@ -145,7 +144,20 @@ class BaseSFTP (object):
def _read_all(self, n):
out = ''
while n > 0:
- x = self.sock.recv(n)
+ if isinstance(self.sock, socket.socket):
+ # sometimes sftp is used directly over a socket instead of
+ # through a paramiko channel. in this case, check periodically
+ # if the socket is closed. (for some reason, recv() won't ever
+ # return or raise an exception, but calling select on a closed
+ # socket will.)
+ while True:
+ read, write, err = select.select([ self.sock ], [], [], 0.1)
+ if len(read) > 0:
+ x = self.sock.recv(n)
+ break
+ else:
+ x = self.sock.recv(n)
+
if len(x) == 0:
raise EOFError()
out += x
@@ -153,16 +165,24 @@ class BaseSFTP (object):
return out
def _send_packet(self, t, packet):
+ #self._log(DEBUG2, 'write: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(packet)))
out = struct.pack('>I', len(packet) + 1) + chr(t) + packet
if self.ultra_debug:
self._log(DEBUG, util.format_binary(out, 'OUT: '))
self._write_all(out)
def _read_packet(self):
- size = struct.unpack('>I', self._read_all(4))[0]
+ x = self._read_all(4)
+ # most sftp servers won't accept packets larger than about 32k, so
+ # anything with the high byte set (> 16MB) is just garbage.
+ if x[0] != '\x00':
+ raise SFTPError('Garbage packet received')
+ size = struct.unpack('>I', x)[0]
data = self._read_all(size)
if self.ultra_debug:
self._log(DEBUG, util.format_binary(data, 'IN: '));
if size > 0:
- return ord(data[0]), data[1:]
+ t = ord(data[0])
+ #self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(data)-1))
+ return t, data[1:]
return 0, ''
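
The wire format handled by _send_packet/_read_packet is just a 4-byte big-endian length, one type byte, then the payload; the garbage check above works because any sane packet length leaves the high byte zero. A standalone sketch of the framing:

    import struct

    def frame(packet_type, payload):
        # uint32 length counts the type byte plus the payload
        return struct.pack('>I', len(payload) + 1) + chr(packet_type) + payload

    def unframe(data):
        size = struct.unpack('>I', data[:4])[0]
        body = data[4:4 + size]
        return ord(body[0]), body[1:]

    raw = frame(1, 'payload')            # type 1 is CMD_INIT in the sftp module
    assert unframe(raw) == (1, 'payload')
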
diff --git a/paramiko/sftp_attr.py b/paramiko/sftp_attr.py
index eae7c99..9c92862 100644
--- a/paramiko/sftp_attr.py
+++ b/paramiko/sftp_attr.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2006 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -51,6 +51,12 @@ class SFTPAttributes (object):
Create a new (empty) SFTPAttributes object. All fields will be empty.
"""
self._flags = 0
+ self.st_size = None
+ self.st_uid = None
+ self.st_gid = None
+ self.st_mode = None
+ self.st_atime = None
+ self.st_mtime = None
self.attr = {}
def from_stat(cls, obj, filename=None):
@@ -80,18 +86,17 @@ class SFTPAttributes (object):
def __repr__(self):
return '<SFTPAttributes: %s>' % self._debug_str()
- def __str__(self):
- return self._debug_str()
-
### internals...
- def _from_msg(cls, msg, filename=None):
+ def _from_msg(cls, msg, filename=None, longname=None):
attr = cls()
attr._unpack(msg)
if filename is not None:
attr.filename = filename
+ if longname is not None:
+ attr.longname = longname
return attr
_from_msg = classmethod(_from_msg)
@@ -114,13 +119,13 @@ class SFTPAttributes (object):
def _pack(self, msg):
self._flags = 0
- if hasattr(self, 'st_size'):
+ if self.st_size is not None:
self._flags |= self.FLAG_SIZE
- if hasattr(self, 'st_uid') or hasattr(self, 'st_gid'):
+ if (self.st_uid is not None) and (self.st_gid is not None):
self._flags |= self.FLAG_UIDGID
- if hasattr(self, 'st_mode'):
+ if self.st_mode is not None:
self._flags |= self.FLAG_PERMISSIONS
- if hasattr(self, 'st_atime') or hasattr(self, 'st_mtime'):
+ if (self.st_atime is not None) and (self.st_mtime is not None):
self._flags |= self.FLAG_AMTIME
if len(self.attr) > 0:
self._flags |= self.FLAG_EXTENDED
@@ -128,13 +133,14 @@ class SFTPAttributes (object):
if self._flags & self.FLAG_SIZE:
msg.add_int64(self.st_size)
if self._flags & self.FLAG_UIDGID:
- msg.add_int(getattr(self, 'st_uid', 0))
- msg.add_int(getattr(self, 'st_gid', 0))
+ msg.add_int(self.st_uid)
+ msg.add_int(self.st_gid)
if self._flags & self.FLAG_PERMISSIONS:
msg.add_int(self.st_mode)
if self._flags & self.FLAG_AMTIME:
- msg.add_int(getattr(self, 'st_atime', 0))
- msg.add_int(getattr(self, 'st_mtime', 0))
+ # throw away any fractional seconds
+ msg.add_int(long(self.st_atime))
+ msg.add_int(long(self.st_mtime))
if self._flags & self.FLAG_EXTENDED:
msg.add_int(len(self.attr))
for key, val in self.attr.iteritems():
@@ -144,15 +150,14 @@ class SFTPAttributes (object):
def _debug_str(self):
out = '[ '
- if hasattr(self, 'st_size'):
+ if self.st_size is not None:
out += 'size=%d ' % self.st_size
- if hasattr(self, 'st_uid') or hasattr(self, 'st_gid'):
- out += 'uid=%d gid=%d ' % (getattr(self, 'st_uid', 0), getattr(self, 'st_gid', 0))
- if hasattr(self, 'st_mode'):
+ if (self.st_uid is not None) and (self.st_gid is not None):
+ out += 'uid=%d gid=%d ' % (self.st_uid, self.st_gid)
+ if self.st_mode is not None:
out += 'mode=' + oct(self.st_mode) + ' '
- if hasattr(self, 'st_atime') or hasattr(self, 'st_mtime'):
- out += 'atime=%d mtime=%d ' % (getattr(self, 'st_atime', 0),
- getattr(self, 'st_mtime', 0))
+ if (self.st_atime is not None) and (self.st_mtime is not None):
+ out += 'atime=%d mtime=%d ' % (self.st_atime, self.st_mtime)
for k, v in self.attr.iteritems():
out += '"%s"=%r ' % (str(k), v)
out += ']'
@@ -171,7 +176,7 @@ class SFTPAttributes (object):
def __str__(self):
"create a unix-style long description of the file (like ls -l)"
- if hasattr(self, 'st_mode'):
+ if self.st_mode is not None:
kind = stat.S_IFMT(self.st_mode)
if kind == stat.S_IFIFO:
ks = 'p'
@@ -194,15 +199,25 @@ class SFTPAttributes (object):
ks += self._rwx(self.st_mode & 7, self.st_mode & stat.S_ISVTX, True)
else:
ks = '?---------'
- uid = getattr(self, 'st_uid', -1)
- gid = getattr(self, 'st_gid', -1)
- size = getattr(self, 'st_size', -1)
- mtime = getattr(self, 'st_mtime', 0)
# compute display date
- if abs(time.time() - mtime) > 15552000:
- # (15552000 = 6 months)
- datestr = time.strftime('%d %b %Y', time.localtime(mtime))
+ if (self.st_mtime is None) or (self.st_mtime == 0xffffffff):
+ # shouldn't really happen
+ datestr = '(unknown date)'
else:
- datestr = time.strftime('%d %b %H:%M', time.localtime(mtime))
+ if abs(time.time() - self.st_mtime) > 15552000:
+ # (15552000 = 6 months)
+ datestr = time.strftime('%d %b %Y', time.localtime(self.st_mtime))
+ else:
+ datestr = time.strftime('%d %b %H:%M', time.localtime(self.st_mtime))
filename = getattr(self, 'filename', '?')
- return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, size, datestr, filename)
+
+ # not all servers support uid/gid
+ uid = self.st_uid
+ gid = self.st_gid
+ if uid is None:
+ uid = 0
+ if gid is None:
+ gid = 0
+
+ return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, self.st_size, datestr, filename)
+
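
With the stat fields now initialized to None, which flags get packed follows directly from which fields have been assigned, and the long-listing output degrades gracefully when fields are missing. A tiny sketch:

    from paramiko.sftp_attr import SFTPAttributes

    attr = SFTPAttributes()
    attr.st_size = 1024
    # only st_size is set, so a setstat built from this carries FLAG_SIZE alone;
    # str(attr) still renders an "ls -l"-style line, substituting 0 for the
    # missing uid/gid and "(unknown date)" for the missing mtime.
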
diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py
index 2fe89e9..b3d2d56 100644
--- a/paramiko/sftp_client.py
+++ b/paramiko/sftp_client.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -20,21 +20,32 @@
Client-mode SFTP support.
"""
+from binascii import hexlify
import errno
import os
import threading
+import time
import weakref
+
from paramiko.sftp import *
from paramiko.sftp_attr import SFTPAttributes
+from paramiko.ssh_exception import SSHException
from paramiko.sftp_file import SFTPFile
def _to_unicode(s):
- "if a str is not ascii, decode its utf8 into unicode"
+ """
+ decode a string as ascii or utf8 if possible (as required by the sftp
+ protocol). if neither works, just return a byte string because the server
+ probably doesn't know the filename's encoding.
+ """
try:
return s.encode('ascii')
- except:
- return s.decode('utf-8')
+ except UnicodeError:
+ try:
+ return s.decode('utf-8')
+ except UnicodeError:
+ return s
class SFTPClient (BaseSFTP):
@@ -51,8 +62,11 @@ class SFTPClient (BaseSFTP):
An alternate way to create an SFTP client context is by using
L{from_transport}.
- @param sock: an open L{Channel} using the C{"sftp"} subsystem.
+ @param sock: an open L{Channel} using the C{"sftp"} subsystem
@type sock: L{Channel}
+
+ @raise SSHException: if there's an exception while negotiating
+ sftp
"""
BaseSFTP.__init__(self)
self.sock = sock
@@ -66,31 +80,33 @@ class SFTPClient (BaseSFTP):
if type(sock) is Channel:
# override default logger
transport = self.sock.get_transport()
- self.logger = util.get_logger(transport.get_log_channel() + '.' +
- self.sock.get_name() + '.sftp')
+ self.logger = util.get_logger(transport.get_log_channel() + '.sftp')
self.ultra_debug = transport.get_hexdump()
- self._send_version()
-
- def __del__(self):
- self.close()
+ try:
+ server_version = self._send_version()
+ except EOFError, x:
+ raise SSHException('EOF during negotiation')
+ self._log(INFO, 'Opened sftp connection (server version %d)' % server_version)
- def from_transport(selfclass, t):
+ def from_transport(cls, t):
"""
Create an SFTP client channel from an open L{Transport}.
- @param t: an open L{Transport} which is already authenticated.
+ @param t: an open L{Transport} which is already authenticated
@type t: L{Transport}
@return: a new L{SFTPClient} object, referring to an sftp session
- (channel) across the transport.
+ (channel) across the transport
@rtype: L{SFTPClient}
"""
chan = t.open_session()
if chan is None:
return None
- if not chan.invoke_subsystem('sftp'):
- raise SFTPError('Failed to invoke sftp subsystem')
- return selfclass(chan)
+ chan.invoke_subsystem('sftp')
+ return cls(chan)
from_transport = classmethod(from_transport)
+
+ def _log(self, level, msg, *args):
+ super(SFTPClient, self)._log(level, "[chan %s] " + msg, *([ self.sock.get_name() ] + list(args)))
def close(self):
"""
@@ -98,7 +114,20 @@ class SFTPClient (BaseSFTP):
@since: 1.4
"""
+ self._log(INFO, 'sftp session closed.')
self.sock.close()
+
+ def get_channel(self):
+ """
+ Return the underlying L{Channel} object for this SFTP session. This
+ might be useful for doing things like setting a timeout on the channel.
+
+ @return: the SSH channel
+ @rtype: L{Channel}
+
+ @since: 1.7.1
+ """
+ return self.sock
def listdir(self, path='.'):
"""
@@ -121,6 +150,11 @@ class SFTPClient (BaseSFTP):
files in the given C{path}. The list is in arbitrary order. It does
not include the special entries C{'.'} and C{'..'} even if they are
present in the folder.
+
+ The returned L{SFTPAttributes} objects will each have an additional
+ field: C{longname}, which may contain a formatted string of the file's
+ attributes, in unix format. The content of this string will probably
+ depend on the SFTP server implementation.
@param path: path to list (defaults to C{'.'})
@type path: str
@@ -130,6 +164,7 @@ class SFTPClient (BaseSFTP):
@since: 1.2
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'listdir(%r)' % path)
t, msg = self._request(CMD_OPENDIR, path)
if t != CMD_HANDLE:
raise SFTPError('Expected handle')
@@ -147,13 +182,13 @@ class SFTPClient (BaseSFTP):
for i in range(count):
filename = _to_unicode(msg.get_string())
longname = _to_unicode(msg.get_string())
- attr = SFTPAttributes._from_msg(msg, filename)
+ attr = SFTPAttributes._from_msg(msg, filename, longname)
if (filename != '.') and (filename != '..'):
filelist.append(attr)
self._request(CMD_CLOSE, handle)
return filelist
- def file(self, filename, mode='r', bufsize=-1):
+ def open(self, filename, mode='r', bufsize=-1):
"""
Open a file on the remote server. The arguments are the same as for
python's built-in C{file} (aka C{open}). A file-like object is
@@ -177,18 +212,19 @@ class SFTPClient (BaseSFTP):
buffering, C{1} uses line buffering, and any number greater than 1
(C{>1}) uses that specific buffer size.
- @param filename: name of the file to open.
- @type filename: string
- @param mode: mode (python-style) to open in.
- @type mode: string
+ @param filename: name of the file to open
+ @type filename: str
+ @param mode: mode (python-style) to open in
+ @type mode: str
@param bufsize: desired buffering (-1 = default buffer size)
@type bufsize: int
- @return: a file object representing the open file.
+ @return: a file object representing the open file
@rtype: SFTPFile
@raise IOError: if the file could not be opened.
"""
filename = self._adjust_cwd(filename)
+ self._log(DEBUG, 'open(%r, %r)' % (filename, mode))
imode = 0
if ('r' in mode) or ('+' in mode):
imode |= SFTP_FLAG_READ
@@ -205,23 +241,24 @@ class SFTPClient (BaseSFTP):
if t != CMD_HANDLE:
raise SFTPError('Expected handle')
handle = msg.get_string()
+ self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, hexlify(handle)))
return SFTPFile(self, handle, mode, bufsize)
- # python has migrated toward file() instead of open().
- # and really, that's more easily identifiable.
- open = file
+ # python continues to vacillate about "open" vs "file"...
+ file = open
def remove(self, path):
"""
- Remove the file at the given path.
+ Remove the file at the given path. This only works on files; for
+ removing folders (directories), use L{rmdir}.
- @param path: path (absolute or relative) of the file to remove.
- @type path: string
+ @param path: path (absolute or relative) of the file to remove
+ @type path: str
- @raise IOError: if the path refers to a folder (directory). Use
- L{rmdir} to remove a folder.
+ @raise IOError: if the path refers to a folder (directory)
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'remove(%r)' % path)
self._request(CMD_REMOVE, path)
unlink = remove
@@ -230,16 +267,17 @@ class SFTPClient (BaseSFTP):
"""
Rename a file or folder from C{oldpath} to C{newpath}.
- @param oldpath: existing name of the file or folder.
- @type oldpath: string
- @param newpath: new name for the file or folder.
- @type newpath: string
+ @param oldpath: existing name of the file or folder
+ @type oldpath: str
+ @param newpath: new name for the file or folder
+ @type newpath: str
@raise IOError: if C{newpath} is a folder, or something else goes
- wrong.
+ wrong
"""
oldpath = self._adjust_cwd(oldpath)
newpath = self._adjust_cwd(newpath)
+ self._log(DEBUG, 'rename(%r, %r)' % (oldpath, newpath))
self._request(CMD_RENAME, oldpath, newpath)
def mkdir(self, path, mode=0777):
@@ -248,12 +286,13 @@ class SFTPClient (BaseSFTP):
The default mode is 0777 (octal). On some systems, mode is ignored.
Where it is used, the current umask value is first masked out.
- @param path: name of the folder to create.
- @type path: string
- @param mode: permissions (posix-style) for the newly-created folder.
+ @param path: name of the folder to create
+ @type path: str
+ @param mode: permissions (posix-style) for the newly-created folder
@type mode: int
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'mkdir(%r, %r)' % (path, mode))
attr = SFTPAttributes()
attr.st_mode = mode
self._request(CMD_MKDIR, path, attr)
@@ -262,10 +301,11 @@ class SFTPClient (BaseSFTP):
"""
Remove the folder named C{path}.
- @param path: name of the folder to remove.
- @type path: string
+ @param path: name of the folder to remove
+ @type path: str
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'rmdir(%r)' % path)
self._request(CMD_RMDIR, path)
def stat(self, path):
@@ -282,12 +322,13 @@ class SFTPClient (BaseSFTP):
The fields supported are: C{st_mode}, C{st_size}, C{st_uid}, C{st_gid},
C{st_atime}, and C{st_mtime}.
- @param path: the filename to stat.
- @type path: string
- @return: an object containing attributes about the given file.
+ @param path: the filename to stat
+ @type path: str
+ @return: an object containing attributes about the given file
@rtype: SFTPAttributes
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'stat(%r)' % path)
t, msg = self._request(CMD_STAT, path)
if t != CMD_ATTRS:
raise SFTPError('Expected attributes')
@@ -299,12 +340,13 @@ class SFTPClient (BaseSFTP):
following symbolic links (shortcuts). This otherwise behaves exactly
the same as L{stat}.
- @param path: the filename to stat.
- @type path: string
- @return: an object containing attributes about the given file.
+ @param path: the filename to stat
+ @type path: str
+ @return: an object containing attributes about the given file
@rtype: SFTPAttributes
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'lstat(%r)' % path)
t, msg = self._request(CMD_LSTAT, path)
if t != CMD_ATTRS:
raise SFTPError('Expected attributes')
@@ -315,12 +357,13 @@ class SFTPClient (BaseSFTP):
Create a symbolic link (shortcut) of the C{source} path at
C{destination}.
- @param source: path of the original file.
- @type source: string
- @param dest: path of the newly created symlink.
- @type dest: string
+ @param source: path of the original file
+ @type source: str
+ @param dest: path of the newly created symlink
+ @type dest: str
"""
dest = self._adjust_cwd(dest)
+ self._log(DEBUG, 'symlink(%r, %r)' % (source, dest))
if type(source) is unicode:
source = source.encode('utf-8')
self._request(CMD_SYMLINK, source, dest)
@@ -331,12 +374,13 @@ class SFTPClient (BaseSFTP):
unix-style and identical to those used by python's C{os.chmod}
function.
- @param path: path of the file to change the permissions of.
- @type path: string
- @param mode: new permissions.
+ @param path: path of the file to change the permissions of
+ @type path: str
+ @param mode: new permissions
@type mode: int
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'chmod(%r, %r)' % (path, mode))
attr = SFTPAttributes()
attr.st_mode = mode
self._request(CMD_SETSTAT, path, attr)
@@ -348,14 +392,15 @@ class SFTPClient (BaseSFTP):
only want to change one, use L{stat} first to retrieve the current
owner and group.
- @param path: path of the file to change the owner and group of.
- @type path: string
+ @param path: path of the file to change the owner and group of
+ @type path: str
@param uid: new owner's uid
@type uid: int
@param gid: new group id
@type gid: int
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'chown(%r, %r, %r)' % (path, uid, gid))
attr = SFTPAttributes()
attr.st_uid, attr.st_gid = uid, gid
self._request(CMD_SETSTAT, path, attr)
@@ -369,31 +414,50 @@ class SFTPClient (BaseSFTP):
modified times, respectively. This bizarre API is mimicked from python
for the sake of consistency -- I apologize.
- @param path: path of the file to modify.
- @type path: string
+ @param path: path of the file to modify
+ @type path: str
@param times: C{None} or a tuple of (access time, modified time) in
- standard internet epoch time (seconds since 01 January 1970 GMT).
- @type times: tuple of int
+ standard internet epoch time (seconds since 01 January 1970 GMT)
+ @type times: tuple(int)
"""
path = self._adjust_cwd(path)
if times is None:
times = (time.time(), time.time())
+ self._log(DEBUG, 'utime(%r, %r)' % (path, times))
attr = SFTPAttributes()
attr.st_atime, attr.st_mtime = times
self._request(CMD_SETSTAT, path, attr)
+ def truncate(self, path, size):
+ """
+ Change the size of the file specified by C{path}. This usually extends
+ or shrinks the size of the file, just like the C{truncate()} method on
+ python file objects.
+
+ @param path: path of the file to modify
+ @type path: str
+ @param size: the new size of the file
+ @type size: int or long
+ """
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, 'truncate(%r, %r)' % (path, size))
+ attr = SFTPAttributes()
+ attr.st_size = size
+ self._request(CMD_SETSTAT, path, attr)
+
def readlink(self, path):
"""
Return the target of a symbolic link (shortcut). You can use
L{symlink} to create these. The result may be either an absolute or
relative pathname.
- @param path: path of the symbolic link file.
+ @param path: path of the symbolic link file
@type path: str
- @return: target path.
+ @return: target path
@rtype: str
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'readlink(%r)' % path)
t, msg = self._request(CMD_READLINK, path)
if t != CMD_NAME:
raise SFTPError('Expected name response')
@@ -411,14 +475,15 @@ class SFTPClient (BaseSFTP):
server is considering to be the "current folder" (by passing C{'.'}
as C{path}).
- @param path: path to be normalized.
+ @param path: path to be normalized
@type path: str
- @return: normalized form of the given path.
+ @return: normalized form of the given path
@rtype: str
@raise IOError: if the path can't be resolved on the server
"""
path = self._adjust_cwd(path)
+ self._log(DEBUG, 'normalize(%r)' % path)
t, msg = self._request(CMD_REALPATH, path)
if t != CMD_NAME:
raise SFTPError('Expected name response')
@@ -457,7 +522,7 @@ class SFTPClient (BaseSFTP):
"""
return self._cwd
- def put(self, localpath, remotepath):
+ def put(self, localpath, remotepath, callback=None):
"""
Copy a local file (C{localpath}) to the SFTP server as C{remotepath}.
Any exception raised by operations will be passed through. This
@@ -469,9 +534,17 @@ class SFTPClient (BaseSFTP):
@type localpath: str
@param remotepath: the destination path on the SFTP server
@type remotepath: str
+ @param callback: optional callback function that accepts the bytes
+ transferred so far and the total bytes to be transferred
+ (since 1.7.4)
+ @type callback: function(int, int)
+ @return: an object containing attributes about the given file
+ (since 1.7.4)
+ @rtype: SFTPAttributes
@since: 1.4
"""
+ file_size = os.stat(localpath).st_size
fl = file(localpath, 'rb')
fr = self.file(remotepath, 'wb')
fr.set_pipelined(True)
@@ -482,13 +555,16 @@ class SFTPClient (BaseSFTP):
break
fr.write(data)
size += len(data)
+ if callback is not None:
+ callback(size, file_size)
fl.close()
fr.close()
s = self.stat(remotepath)
if s.st_size != size:
raise IOError('size mismatch in put! %d != %d' % (s.st_size, size))
+ return s
- def get(self, remotepath, localpath):
+ def get(self, remotepath, localpath, callback=None):
"""
Copy a remote file (C{remotepath}) from the SFTP server to the local
host as C{localpath}. Any exception raised by operations will be
@@ -498,10 +574,15 @@ class SFTPClient (BaseSFTP):
@type remotepath: str
@param localpath: the destination path on the local host
@type localpath: str
+ @param callback: optional callback function that accepts the bytes
+ transferred so far and the total bytes to be transferred
+ (since 1.7.4)
+ @type callback: function(int, int)
@since: 1.4
"""
fr = self.file(remotepath, 'rb')
+ file_size = self.stat(remotepath).st_size
fr.prefetch()
fl = file(localpath, 'wb')
size = 0
@@ -511,6 +592,8 @@ class SFTPClient (BaseSFTP):
break
fl.write(data)
size += len(data)
+ if callback is not None:
+ callback(size, file_size)
fl.close()
fr.close()
s = os.stat(localpath)
@@ -552,7 +635,10 @@ class SFTPClient (BaseSFTP):
def _read_response(self, waitfor=None):
while True:
- t, data = self._read_packet()
+ try:
+ t, data = self._read_packet()
+ except EOFError, e:
+ raise SSHException('Server connection dropped: %s' % (str(e),))
msg = Message(data)
num = msg.get_int()
if num not in self._expecting:
@@ -560,7 +646,7 @@ class SFTPClient (BaseSFTP):
self._log(DEBUG, 'Unexpected response #%d' % (num,))
if waitfor is None:
# just doing a single check
- return
+ break
continue
fileobj = self._expecting[num]
del self._expecting[num]
@@ -573,7 +659,8 @@ class SFTPClient (BaseSFTP):
fileobj._async_response(t, msg)
if waitfor is None:
# just doing a single check
- return
+ break
+ return (None, None)
def _finish_responses(self, fileobj):
while fileobj in self._expecting.values():
@@ -610,6 +697,8 @@ class SFTPClient (BaseSFTP):
if (len(path) > 0) and (path[0] == '/'):
# absolute path
return path
+ if self._cwd == '/':
+ return self._cwd + path
return self._cwd + '/' + path
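
Putting the client-side additions together, a typical transfer using the new progress callback and the get_channel() hook for timeouts might look like this (host, credentials and paths are placeholders):

    import paramiko

    def progress(transferred, total):
        pass    # e.g. update a progress bar with transferred/total bytes

    transport = paramiko.Transport(('example.com', 22))       # placeholder host
    transport.connect(username='user', password='secret')     # placeholder credentials
    sftp = paramiko.SFTPClient.from_transport(transport)
    sftp.get_channel().settimeout(30.0)                        # channel access is new in 1.7.1
    attrs = sftp.put('local.txt', '/tmp/local.txt', callback=progress)  # returns SFTPAttributes as of 1.7.4
    sftp.close()
    transport.close()
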
diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py
index f224f02..cfa7db1 100644
--- a/paramiko/sftp_file.py
+++ b/paramiko/sftp_file.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -20,7 +20,11 @@
L{SFTPFile}
"""
+from binascii import hexlify
+import socket
import threading
+import time
+
from paramiko.common import *
from paramiko.sftp import *
from paramiko.file import BufferedFile
@@ -43,12 +47,18 @@ class SFTPFile (BufferedFile):
BufferedFile._set_mode(self, mode, bufsize)
self.pipelined = False
self._prefetching = False
+ self._prefetch_done = False
+ self._prefetch_data = {}
+ self._prefetch_reads = []
self._saved_exception = None
def __del__(self):
- self.close(_async=True)
+ self._close(async=True)
+
+ def close(self):
+ self._close(async=False)
- def close(self, _async=False):
+ def _close(self, async=False):
# We allow double-close without signaling an error, because real
# Python file objects do. However, we must protect against actually
# sending multiple CMD_CLOSE packets, because after we close our
@@ -58,11 +68,12 @@ class SFTPFile (BufferedFile):
# __del__.)
if self._closed:
return
+ self.sftp._log(DEBUG, 'close(%s)' % hexlify(self.handle))
if self.pipelined:
self.sftp._finish_responses(self)
BufferedFile.close(self)
try:
- if _async:
+ if async:
# GC'd file handle could be called from an arbitrary thread -- don't wait for a response
self.sftp._async_request(type(None), CMD_CLOSE, self.handle)
else:
@@ -70,34 +81,77 @@ class SFTPFile (BufferedFile):
except EOFError:
# may have outlived the Transport connection
pass
- except IOError:
+ except (IOError, socket.error):
# may have outlived the Transport connection
pass
+ def _data_in_prefetch_requests(self, offset, size):
+ k = [i for i in self._prefetch_reads if i[0] <= offset]
+ if len(k) == 0:
+ return False
+ k.sort(lambda x, y: cmp(x[0], y[0]))
+ buf_offset, buf_size = k[-1]
+ if buf_offset + buf_size <= offset:
+ # prefetch request ends before this one begins
+ return False
+ if buf_offset + buf_size >= offset + size:
+ # inclusive
+ return True
+ # well, we have part of the request. see if another chunk has the rest.
+ return self._data_in_prefetch_requests(buf_offset + buf_size, offset + size - buf_offset - buf_size)
+
+ def _data_in_prefetch_buffers(self, offset):
+ """
+ if a block of data is present in the prefetch buffers, at the given
+ offset, return the offset of the relevant prefetch buffer. otherwise,
+ return None. this guarantees nothing about the number of bytes
+ collected in the prefetch buffer so far.
+ """
+ k = [i for i in self._prefetch_data.keys() if i <= offset]
+ if len(k) == 0:
+ return None
+ index = max(k)
+ buf_offset = offset - index
+ if buf_offset >= len(self._prefetch_data[index]):
+ # it's not here
+ return None
+ return index
+
def _read_prefetch(self, size):
+ """
+ read data out of the prefetch buffer, if possible. if the data isn't
+ in the buffer, return None. otherwise, behaves like a normal read.
+ """
# while not closed, and haven't fetched past the current position, and haven't reached EOF...
- while (self._prefetch_so_far <= self._realpos) and \
- (self._prefetch_so_far < self._prefetch_size) and not self._closed:
+ while True:
+ offset = self._data_in_prefetch_buffers(self._realpos)
+ if offset is not None:
+ break
+ if self._prefetch_done or self._closed:
+ break
self.sftp._read_response()
- self._check_exception()
- k = self._prefetch_data.keys()
- k.sort()
- while (len(k) > 0) and (k[0] + len(self._prefetch_data[k[0]]) <= self._realpos):
- # done with that block
- del self._prefetch_data[k[0]]
- k.pop(0)
- if len(k) == 0:
+ self._check_exception()
+ if offset is None:
self._prefetching = False
- return ''
- assert k[0] <= self._realpos
- buf_offset = self._realpos - k[0]
- buf_length = len(self._prefetch_data[k[0]]) - buf_offset
- return self._prefetch_data[k[0]][buf_offset : buf_offset + buf_length]
+ return None
+ prefetch = self._prefetch_data[offset]
+ del self._prefetch_data[offset]
+
+ buf_offset = self._realpos - offset
+ if buf_offset > 0:
+ self._prefetch_data[offset] = prefetch[:buf_offset]
+ prefetch = prefetch[buf_offset:]
+ if size < len(prefetch):
+ self._prefetch_data[self._realpos + size] = prefetch[size:]
+ prefetch = prefetch[:size]
+ return prefetch
def _read(self, size):
size = min(size, self.MAX_REQUEST_SIZE)
if self._prefetching:
- return self._read_prefetch(size)
+ data = self._read_prefetch(size)
+ if data is not None:
+ return data
t, msg = self.sftp._request(CMD_READ, self.handle, long(self._realpos), int(size))
if t != CMD_DATA:
raise SFTPError('Expected data')
@@ -106,8 +160,7 @@ class SFTPFile (BufferedFile):
def _write(self, data):
# may write less than requested if it would exceed max packet size
chunk = min(len(data), self.MAX_REQUEST_SIZE)
- req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos),
- str(data[:chunk]))
+ req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk]))
if not self.pipelined or self.sftp.sock.recv_ready():
t, msg = self.sftp._read_response(req)
if t != CMD_STATUS:
@@ -173,6 +226,71 @@ class SFTPFile (BufferedFile):
if t != CMD_ATTRS:
raise SFTPError('Expected attributes')
return SFTPAttributes._from_msg(msg)
+
+ def chmod(self, mode):
+ """
+ Change the mode (permissions) of this file. The permissions are
+ unix-style and identical to those used by python's C{os.chmod}
+ function.
+
+ @param mode: new permissions
+ @type mode: int
+ """
+ self.sftp._log(DEBUG, 'chmod(%s, %r)' % (hexlify(self.handle), mode))
+ attr = SFTPAttributes()
+ attr.st_mode = mode
+ self.sftp._request(CMD_FSETSTAT, self.handle, attr)
+
+ def chown(self, uid, gid):
+ """
+ Change the owner (C{uid}) and group (C{gid}) of this file. As with
+ python's C{os.chown} function, you must pass both arguments, so if you
+ only want to change one, use L{stat} first to retrieve the current
+ owner and group.
+
+ @param uid: new owner's uid
+ @type uid: int
+ @param gid: new group id
+ @type gid: int
+ """
+ self.sftp._log(DEBUG, 'chown(%s, %r, %r)' % (hexlify(self.handle), uid, gid))
+ attr = SFTPAttributes()
+ attr.st_uid, attr.st_gid = uid, gid
+ self.sftp._request(CMD_FSETSTAT, self.handle, attr)
+
+ def utime(self, times):
+ """
+ Set the access and modified times of this file. If
+ C{times} is C{None}, then the file's access and modified times are set
+ to the current time. Otherwise, C{times} must be a 2-tuple of numbers,
+ of the form C{(atime, mtime)}, which is used to set the access and
+ modified times, respectively. This bizarre API is mimicked from python
+ for the sake of consistency -- I apologize.
+
+ @param times: C{None} or a tuple of (access time, modified time) in
+ standard internet epoch time (seconds since 01 January 1970 GMT)
+ @type times: tuple(int)
+ """
+ if times is None:
+ times = (time.time(), time.time())
+ self.sftp._log(DEBUG, 'utime(%s, %r)' % (hexlify(self.handle), times))
+ attr = SFTPAttributes()
+ attr.st_atime, attr.st_mtime = times
+ self.sftp._request(CMD_FSETSTAT, self.handle, attr)
+
+ def truncate(self, size):
+ """
+ Change the size of this file. This usually extends
+ or shrinks the size of the file, just like the C{truncate()} method on
+ python file objects.
+
+ @param size: the new size of the file
+ @type size: int or long
+ """
+ self.sftp._log(DEBUG, 'truncate(%s, %r)' % (hexlify(self.handle), size))
+ attr = SFTPAttributes()
+ attr.st_size = size
+ self.sftp._request(CMD_FSETSTAT, self.handle, attr)
def check(self, hash_algorithm, offset=0, length=0, block_size=0):
"""
@@ -255,26 +373,60 @@ class SFTPFile (BufferedFile):
dramatically improve the download speed by avoiding roundtrip latency.
The file's contents are incrementally buffered in a background thread.
+ The prefetched data is stored in a buffer until read via the L{read}
+ method. Once data has been read, it's removed from the buffer. The
+ data may be read in a random order (using L{seek}); chunks of the
+ buffer that haven't been read will continue to be buffered.
+
@since: 1.5.1
"""
size = self.stat().st_size
# queue up async reads for the rest of the file
- self._prefetching = True
- self._prefetch_so_far = self._realpos
- self._prefetch_size = size
- self._prefetch_data = {}
- t = threading.Thread(target=self._prefetch)
- t.setDaemon(True)
- t.start()
-
- def _prefetch(self):
+ chunks = []
n = self._realpos
- size = self._prefetch_size
while n < size:
chunk = min(self.MAX_REQUEST_SIZE, size - n)
- self.sftp._async_request(self, CMD_READ, self.handle, long(n), int(chunk))
+ chunks.append((n, chunk))
n += chunk
+ if len(chunks) > 0:
+ self._start_prefetch(chunks)
+
+ def readv(self, chunks):
+ """
+ Read a set of blocks from the file by (offset, length). This is more
+ efficient than doing a series of L{seek} and L{read} calls, since the
+ prefetch machinery is used to retrieve all the requested blocks at
+ once.
+
+ @param chunks: a list of (offset, length) tuples indicating which
+ sections of the file to read
+ @type chunks: list(tuple(long, int))
+ @return: a list of blocks read, in the same order as in C{chunks}
+ @rtype: list(str)
+
+ @since: 1.5.4
+ """
+ self.sftp._log(DEBUG, 'readv(%s, %r)' % (hexlify(self.handle), chunks))
+ read_chunks = []
+ for offset, size in chunks:
+ # don't fetch data that's already in the prefetch buffer
+ if self._data_in_prefetch_buffers(offset) or self._data_in_prefetch_requests(offset, size):
+ continue
+
+ # break up anything larger than the max read size
+ while size > 0:
+ chunk_size = min(size, self.MAX_REQUEST_SIZE)
+ read_chunks.append((offset, chunk_size))
+ offset += chunk_size
+ size -= chunk_size
+
+ self._start_prefetch(read_chunks)
+ # now we can just devolve to a bunch of read()s :)
+ for x in chunks:
+ self.seek(x[0])
+ yield self.read(x[1])
+
### internals...
@@ -285,6 +437,21 @@ class SFTPFile (BufferedFile):
except:
return 0
+ def _start_prefetch(self, chunks):
+ self._prefetching = True
+ self._prefetch_done = False
+ self._prefetch_reads.extend(chunks)
+
+ t = threading.Thread(target=self._prefetch_thread, args=(chunks,))
+ t.setDaemon(True)
+ t.start()
+
+ def _prefetch_thread(self, chunks):
+ # do these read requests in a temporary thread because there may be
+ # a lot of them, so it may block.
+ for offset, length in chunks:
+ self.sftp._async_request(self, CMD_READ, self.handle, long(offset), int(length))
+
def _async_response(self, t, msg):
if t == CMD_STATUS:
# save exception and re-raise it on next file operation
@@ -296,8 +463,10 @@ class SFTPFile (BufferedFile):
if t != CMD_DATA:
raise SFTPError('Expected data')
data = msg.get_string()
- self._prefetch_data[self._prefetch_so_far] = data
- self._prefetch_so_far += len(data)
+ offset, length = self._prefetch_reads.pop(0)
+ self._prefetch_data[offset] = data
+ if len(self._prefetch_reads) == 0:
+ self._prefetch_done = True
def _check_exception(self):
"if there's a saved exception, raise & clear it"
diff --git a/paramiko/sftp_handle.py b/paramiko/sftp_handle.py
index e1d93e9..e976f43 100644
--- a/paramiko/sftp_handle.py
+++ b/paramiko/sftp_handle.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -35,7 +35,16 @@ class SFTPHandle (object):
Server implementations can (and should) subclass SFTPHandle to implement
features of a file handle, like L{stat} or L{chattr}.
"""
- def __init__(self):
+ def __init__(self, flags=0):
+ """
+ Create a new file handle representing a local file being served over
+ SFTP. If C{flags} is passed in, it's used to determine if the file
+ is open in append mode.
+
+ @param flags: optional flags as passed to L{SFTPServerInterface.open}
+ @type flags: int
+ """
+ self.__flags = flags
self.__name = None
# only for handles to folders:
self.__files = { }
@@ -81,15 +90,16 @@ class SFTPHandle (object):
@return: data read from the file, or an SFTP error code.
@rtype: str
"""
- if not hasattr(self, 'readfile') or (self.readfile is None):
+ readfile = getattr(self, 'readfile', None)
+ if readfile is None:
return SFTP_OP_UNSUPPORTED
try:
if self.__tell is None:
- self.__tell = self.readfile.tell()
+ self.__tell = readfile.tell()
if offset != self.__tell:
- self.readfile.seek(offset)
+ readfile.seek(offset)
self.__tell = offset
- data = self.readfile.read(length)
+ data = readfile.read(length)
except IOError, e:
self.__tell = None
return SFTPServer.convert_errno(e.errno)
@@ -116,20 +126,24 @@ class SFTPHandle (object):
@type data: str
@return: an SFTP error code like L{SFTP_OK}.
"""
- if not hasattr(self, 'writefile') or (self.writefile is None):
+ writefile = getattr(self, 'writefile', None)
+ if writefile is None:
return SFTP_OP_UNSUPPORTED
try:
- if self.__tell is None:
- self.__tell = self.writefile.tell()
- if offset != self.__tell:
- self.writefile.seek(offset)
- self.__tell = offset
- self.writefile.write(data)
- self.writefile.flush()
+ # in append mode, don't care about seeking
+ if (self.__flags & os.O_APPEND) == 0:
+ if self.__tell is None:
+ self.__tell = writefile.tell()
+ if offset != self.__tell:
+ writefile.seek(offset)
+ self.__tell = offset
+ writefile.write(data)
+ writefile.flush()
except IOError, e:
self.__tell = None
return SFTPServer.convert_errno(e.errno)
- self.__tell += len(data)
+ if self.__tell is not None:
+ self.__tell += len(data)
return SFTP_OK
def stat(self):
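
On the server side, the base class's read/write paths only need readfile/writefile attributes on the handle. A minimal sketch of a subclass, passing the open flags through so the new append-mode path skips seeking (the local file handling is illustrative):

    import os
    from paramiko.sftp_handle import SFTPHandle

    class LocalFileHandle(SFTPHandle):
        # hypothetical handle backed by a local file
        def __init__(self, path, flags):
            SFTPHandle.__init__(self, flags)    # flags enable the O_APPEND write path
            if flags & os.O_APPEND:
                fileobj = open(path, 'ab')
            else:
                fileobj = open(path, 'r+b')
            # the inherited read()/write() implementations use these attributes
            self.readfile = fileobj
            self.writefile = fileobj
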
diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py
index 5905843..099ac12 100644
--- a/paramiko/sftp_server.py
+++ b/paramiko/sftp_server.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -66,15 +66,21 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
BaseSFTP.__init__(self)
SubsystemHandler.__init__(self, channel, name, server)
transport = channel.get_transport()
- self.logger = util.get_logger(transport.get_log_channel() + '.' +
- channel.get_name() + '.sftp')
+ self.logger = util.get_logger(transport.get_log_channel() + '.sftp')
self.ultra_debug = transport.get_hexdump()
self.next_handle = 1
# map of handle-string to SFTPHandle for files & folders:
self.file_table = { }
self.folder_table = { }
self.server = sftp_si(server, *largs, **kwargs)
-
+
+ def _log(self, level, msg):
+ if issubclass(type(msg), list):
+ for m in msg:
+ super(SFTPServer, self)._log(level, "[chan " + self.sock.get_name() + "] " + m)
+ else:
+ super(SFTPServer, self)._log(level, "[chan " + self.sock.get_name() + "] " + msg)
+
def start_subsystem(self, name, transport, channel):
self.sock = channel
self._log(DEBUG, 'Started sftp server on channel %s' % repr(channel))
@@ -92,10 +98,20 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
return
msg = Message(data)
request_number = msg.get_int()
- self._process(t, request_number, msg)
+ try:
+ self._process(t, request_number, msg)
+ except Exception, e:
+ self._log(DEBUG, 'Exception in server processing: ' + str(e))
+ self._log(DEBUG, util.tb_strings())
+ # send some kind of failure message, at least
+ try:
+ self._send_status(request_number, SFTP_FAILURE)
+ except:
+ pass
def finish_subsystem(self):
self.server.session_ended()
+ super(SFTPServer, self).finish_subsystem()
# close any file handles that were left open (so we can return them to the OS quickly)
for f in self.file_table.itervalues():
f.close()
@@ -118,7 +134,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
if e == errno.EACCES:
# permission denied
return SFTP_PERMISSION_DENIED
- elif e == errno.ENOENT:
+ elif (e == errno.ENOENT) or (e == errno.ENOTDIR):
# no such file
return SFTP_NO_SUCH_FILE
else:
@@ -141,12 +157,16 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
@param attr: attributes to change.
@type attr: L{SFTPAttributes}
"""
- if attr._flags & attr.FLAG_PERMISSIONS:
- os.chmod(filename, attr.st_mode)
- if attr._flags & attr.FLAG_UIDGID:
- os.chown(filename, attr.st_uid, attr.st_gid)
+ if sys.platform != 'win32':
+ # mode operations are meaningless on win32
+ if attr._flags & attr.FLAG_PERMISSIONS:
+ os.chmod(filename, attr.st_mode)
+ if attr._flags & attr.FLAG_UIDGID:
+ os.chown(filename, attr.st_uid, attr.st_gid)
if attr._flags & attr.FLAG_AMTIME:
os.utime(filename, (attr.st_atime, attr.st_mtime))
+ if attr._flags & attr.FLAG_SIZE:
+ open(filename, 'w+').truncate(attr.st_size)
set_file_attr = staticmethod(set_file_attr)
@@ -184,8 +204,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
def _send_status(self, request_number, code, desc=None):
if desc is None:
- desc = SFTP_DESC[code]
- self._response(request_number, CMD_STATUS, code, desc)
+ try:
+ desc = SFTP_DESC[code]
+ except IndexError:
+ desc = 'Unknown'
+ # some clients expect a "langauge" tag at the end (but don't mind it being blank)
+ self._response(request_number, CMD_STATUS, code, desc, '')
def _open_folder(self, request_number, path):
resp = self.server.list_folder(path)
@@ -222,7 +246,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
start = msg.get_int64()
length = msg.get_int64()
block_size = msg.get_int()
- if not self.file_table.has_key(handle):
+ if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
f = self.file_table[handle]
@@ -246,29 +270,29 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, SFTP_FAILURE, 'Block size too small')
return
- sum = ''
+ sum_out = ''
offset = start
while offset < start + length:
blocklen = min(block_size, start + length - offset)
# don't try to read more than about 64KB at a time
chunklen = min(blocklen, 65536)
count = 0
- hash = alg.new()
+ hash_obj = alg.new()
while count < blocklen:
data = f.read(offset, chunklen)
if not type(data) is str:
self._send_status(request_number, data, 'Unable to hash file')
return
- hash.update(data)
+ hash_obj.update(data)
count += len(data)
offset += count
- sum += hash.digest()
+ sum_out += hash_obj.digest()
msg = Message()
msg.add_int(request_number)
msg.add_string('check-file')
msg.add_string(algname)
- msg.add_bytes(sum)
+ msg.add_bytes(sum_out)
self._send_packet(CMD_EXTENDED_REPLY, str(msg))
def _convert_pflags(self, pflags):
@@ -298,11 +322,11 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_handle_response(request_number, self.server.open(path, flags, attr))
elif t == CMD_CLOSE:
handle = msg.get_string()
- if self.folder_table.has_key(handle):
+ if handle in self.folder_table:
del self.folder_table[handle]
self._send_status(request_number, SFTP_OK)
return
- if self.file_table.has_key(handle):
+ if handle in self.file_table:
self.file_table[handle].close()
del self.file_table[handle]
self._send_status(request_number, SFTP_OK)
@@ -312,7 +336,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
handle = msg.get_string()
offset = msg.get_int64()
length = msg.get_int()
- if not self.file_table.has_key(handle):
+ if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
data = self.file_table[handle].read(offset, length)
@@ -327,7 +351,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
handle = msg.get_string()
offset = msg.get_int64()
data = msg.get_string()
- if not self.file_table.has_key(handle):
+ if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
self._send_status(request_number, self.file_table[handle].write(offset, data))
@@ -351,7 +375,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
return
elif t == CMD_READDIR:
handle = msg.get_string()
- if not self.folder_table.has_key(handle):
+ if handle not in self.folder_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
folder = self.folder_table[handle]
@@ -372,7 +396,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, resp)
elif t == CMD_FSTAT:
handle = msg.get_string()
- if not self.file_table.has_key(handle):
+ if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
resp = self.file_table[handle].stat()
@@ -387,7 +411,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
elif t == CMD_FSETSTAT:
handle = msg.get_string()
attr = SFTPAttributes._from_msg(msg)
- if not self.file_table.has_key(handle):
+ if handle not in self.file_table:
self._response(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
self._send_status(request_number, self.file_table[handle].chattr(attr))
@@ -412,7 +436,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
if tag == 'check-file':
self._check_file(request_number, msg)
else:
- send._send_status(request_number, SFTP_OP_UNSUPPORTED)
+ self._send_status(request_number, SFTP_OP_UNSUPPORTED)
else:
self._send_status(request_number, SFTP_OP_UNSUPPORTED)
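For reference, C{set_file_attr} is normally invoked from a server implementation's C{chattr()}/fsetstat path; a hedged sketch follows (C{self.local_path()} is a hypothetical path-mapping helper, not paramiko API):

    from paramiko import SFTPServer, SFTP_OK

    def chattr(self, path, attr):
        # map the SFTP attribute-change request onto the static helper above
        try:
            SFTPServer.set_file_attr(self.local_path(path), attr)
        except OSError, e:
            return SFTPServer.convert_errno(e.errno)
        return SFTP_OK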
diff --git a/paramiko/sftp_si.py b/paramiko/sftp_si.py
index 16005d4..47dd25d 100644
--- a/paramiko/sftp_si.py
+++ b/paramiko/sftp_si.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -36,6 +36,9 @@ class SFTPServerInterface (object):
SFTP sessions). However, raising an exception will usually cause the SFTP
session to abruptly end, so you will usually want to catch exceptions and
return an appropriate error code.
+
+ All paths are in string form instead of unicode because not all SFTP
+ clients & servers obey the requirement that paths be encoded in UTF-8.
"""
def __init__ (self, server, *largs, **kwargs):
@@ -268,9 +271,13 @@ class SFTPServerInterface (object):
The default implementation returns C{os.path.normpath('/' + path)}.
"""
if os.path.isabs(path):
- return os.path.normpath(path)
+ out = os.path.normpath(path)
else:
- return os.path.normpath('/' + path)
+ out = os.path.normpath('/' + path)
+ if sys.platform == 'win32':
+ # on windows, normalize backslashes to sftp/posix format
+ out = out.replace('\\', '/')
+ return out
def readlink(self, path):
"""
diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py
index 900d4a0..e3120bb 100644
--- a/paramiko/ssh_exception.py
+++ b/paramiko/ssh_exception.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -28,14 +28,25 @@ class SSHException (Exception):
pass
-class PasswordRequiredException (SSHException):
+class AuthenticationException (SSHException):
+ """
+ Exception raised when authentication failed for some reason. It may be
+ possible to retry with different credentials. (Other classes specify more
+ specific reasons.)
+
+ @since: 1.6
+ """
+ pass
+
+
+class PasswordRequiredException (AuthenticationException):
"""
Exception raised when a password is needed to unlock a private key file.
"""
pass
-class BadAuthenticationType (SSHException):
+class BadAuthenticationType (AuthenticationException):
"""
Exception raised when an authentication type (like password) is used, but
the server isn't allowing that type. (It may only allow public-key, for
@@ -51,19 +62,54 @@ class BadAuthenticationType (SSHException):
allowed_types = []
def __init__(self, explanation, types):
- SSHException.__init__(self, explanation)
+ AuthenticationException.__init__(self, explanation)
self.allowed_types = types
def __str__(self):
return SSHException.__str__(self) + ' (allowed_types=%r)' % self.allowed_types
-class PartialAuthentication (SSHException):
+class PartialAuthentication (AuthenticationException):
"""
An internal exception thrown in the case of partial authentication.
"""
allowed_types = []
def __init__(self, types):
- SSHException.__init__(self, 'partial authentication')
+ AuthenticationException.__init__(self, 'partial authentication')
self.allowed_types = types
+
+
+class ChannelException (SSHException):
+ """
+ Exception raised when an attempt to open a new L{Channel} fails.
+
+ @ivar code: the error code returned by the server
+ @type code: int
+
+ @since: 1.6
+ """
+ def __init__(self, code, text):
+ SSHException.__init__(self, text)
+ self.code = code
+
+
+class BadHostKeyException (SSHException):
+ """
+ The host key given by the SSH server did not match what we were expecting.
+
+ @ivar hostname: the hostname of the SSH server
+ @type hostname: str
+ @ivar key: the host key presented by the server
+ @type key: L{PKey}
+ @ivar expected_key: the host key expected
+ @type expected_key: L{PKey}
+
+ @since: 1.6
+ """
+ def __init__(self, hostname, got_key, expected_key):
+ SSHException.__init__(self, 'Host key for server %s does not match!' % hostname)
+ self.hostname = hostname
+ self.key = got_key
+ self.expected_key = expected_key
+
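Client code can take advantage of the finer-grained hierarchy; a sketch (the socket and credentials are placeholders, and only methods shown elsewhere in this patch are used):

    from paramiko import Transport, SSHException, AuthenticationException, BadAuthenticationType

    def password_login(sock, username, password):
        t = Transport(sock)
        t.start_client()
        try:
            t.auth_password(username, password)
        except BadAuthenticationType, e:
            print 'password auth not allowed; server offers:', e.allowed_types
        except AuthenticationException:
            print 'bad credentials (the transport itself is still usable)'
        except SSHException, e:
            print 'protocol or network error:', e
        return t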
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 8714a96..a18e05b 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -30,19 +30,20 @@ import time
import weakref
from paramiko import util
+from paramiko.auth_handler import AuthHandler
+from paramiko.channel import Channel
from paramiko.common import *
from paramiko.compress import ZlibCompressor, ZlibDecompressor
-from paramiko.ssh_exception import SSHException, BadAuthenticationType
-from paramiko.message import Message
-from paramiko.channel import Channel
-from paramiko.sftp_client import SFTPClient
-from paramiko.packet import Packetizer, NeedRekeyException
-from paramiko.rsakey import RSAKey
from paramiko.dsskey import DSSKey
-from paramiko.kex_group1 import KexGroup1
from paramiko.kex_gex import KexGex
+from paramiko.kex_group1 import KexGroup1
+from paramiko.message import Message
+from paramiko.packet import Packetizer, NeedRekeyException
from paramiko.primes import ModulusPack
-from paramiko.auth_handler import AuthHandler
+from paramiko.rsakey import RSAKey
+from paramiko.server import ServerInterface
+from paramiko.sftp_client import SFTPClient
+from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException
# these come from PyCrypt
# http://www.amk.ca/python/writing/pycrypt/
@@ -50,7 +51,7 @@ from paramiko.auth_handler import AuthHandler
# PyCrypt compiled for Win32 can be downloaded from the HashTar homepage:
# http://nitace.bsd.uchicago.edu:8080/hashtar
from Crypto.Cipher import Blowfish, AES, DES3
-from Crypto.Hash import SHA, MD5, HMAC
+from Crypto.Hash import SHA, MD5
# for thread cleanup
@@ -73,8 +74,6 @@ class SecurityOptions (object):
If you try to add an algorithm that paramiko doesn't recognize,
C{ValueError} will be raised. If you try to assign something besides a
tuple to one of the fields, C{TypeError} will be raised.
-
- @since: ivysaur
"""
__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ]
@@ -110,7 +109,8 @@ class SecurityOptions (object):
if type(x) is not tuple:
raise TypeError('expected tuple or list')
possible = getattr(self._transport, orig).keys()
- if len(filter(lambda n: n not in possible, x)) > 0:
+ forbidden = filter(lambda n: n not in possible, x)
+ if len(forbidden) > 0:
raise ValueError('unknown cipher')
setattr(self._transport, name, x)
@@ -140,6 +140,51 @@ class SecurityOptions (object):
"Compression algorithms")
+class ChannelMap (object):
+ def __init__(self):
+ # (id -> Channel)
+ self._map = weakref.WeakValueDictionary()
+ self._lock = threading.Lock()
+
+ def put(self, chanid, chan):
+ self._lock.acquire()
+ try:
+ self._map[chanid] = chan
+ finally:
+ self._lock.release()
+
+ def get(self, chanid):
+ self._lock.acquire()
+ try:
+ return self._map.get(chanid, None)
+ finally:
+ self._lock.release()
+
+ def delete(self, chanid):
+ self._lock.acquire()
+ try:
+ try:
+ del self._map[chanid]
+ except KeyError:
+ pass
+ finally:
+ self._lock.release()
+
+ def values(self):
+ self._lock.acquire()
+ try:
+ return self._map.values()
+ finally:
+ self._lock.release()
+
+ def __len__(self):
+ self._lock.acquire()
+ try:
+ return len(self._map)
+ finally:
+ self._lock.release()
+
+
class Transport (threading.Thread):
"""
An SSH Transport attaches to a stream (usually a socket), negotiates an
@@ -149,7 +194,7 @@ class Transport (threading.Thread):
"""
_PROTO_ID = '2.0'
- _CLIENT_ID = 'paramiko_1.5.2'
+ _CLIENT_ID = 'paramiko_1.7.4'
_preferred_ciphers = ( 'aes128-cbc', 'blowfish-cbc', 'aes256-cbc', '3des-cbc' )
_preferred_macs = ( 'hmac-sha1', 'hmac-md5', 'hmac-sha1-96', 'hmac-md5-96' )
@@ -245,25 +290,41 @@ class Transport (threading.Thread):
self.sock.settimeout(0.1)
except AttributeError:
pass
+
# negotiated crypto parameters
self.packetizer = Packetizer(sock)
self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID
self.remote_version = ''
self.local_cipher = self.remote_cipher = ''
self.local_kex_init = self.remote_kex_init = None
+ self.local_mac = self.remote_mac = None
+ self.local_compression = self.remote_compression = None
self.session_id = None
- # /negotiated crypto parameters
- self.expected_packet = 0
+ self.host_key_type = None
+ self.host_key = None
+
+ # state used during negotiation
+ self.kex_engine = None
+ self.H = None
+ self.K = None
+
self.active = False
self.initial_kex_done = False
self.in_kex = False
+ self.authenticated = False
+ self._expected_packet = tuple()
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
- self.channels = weakref.WeakValueDictionary() # (id -> Channel)
+
+ # tracking open channels
+ self._channels = ChannelMap()
self.channel_events = { } # (id -> Event)
self.channels_seen = { } # (id -> True)
- self.channel_counter = 1
+ self._channel_counter = 1
self.window_size = 65536
self.max_packet_size = 34816
+ self._x11_handler = None
+ self._tcp_handler = None
+
self.saved_exception = None
self.clear_to_send = threading.Event()
self.clear_to_send_lock = threading.Lock()
@@ -271,9 +332,10 @@ class Transport (threading.Thread):
self.logger = util.get_logger(self.log_name)
self.packetizer.set_log(self.logger)
self.auth_handler = None
- self.authenticated = False
- # user-defined event callbacks:
- self.completion_event = None
+ self.global_response = None # response Message from an arbitrary global request
+ self.completion_event = None # user-defined event callbacks
+ self.banner_timeout = 15 # how long (seconds) to wait for the SSH banner
+
# server mode:
self.server_mode = False
self.server_object = None
@@ -282,9 +344,6 @@ class Transport (threading.Thread):
self.server_accept_cv = threading.Condition(self.lock)
self.subsystem_table = { }
- def __del__(self):
- self.close()
-
def __repr__(self):
"""
Returns a string representation of this object, for debugging.
@@ -299,16 +358,26 @@ class Transport (threading.Thread):
out += ' (cipher %s, %d bits)' % (self.local_cipher,
self._cipher_info[self.local_cipher]['key-size'] * 8)
if self.is_authenticated():
- if len(self.channels) == 1:
- out += ' (active; 1 open channel)'
- else:
- out += ' (active; %d open channels)' % len(self.channels)
+ out += ' (active; %d open channel(s))' % len(self._channels)
elif self.initial_kex_done:
out += ' (connected; awaiting auth)'
else:
out += ' (connecting)'
out += '>'
return out
+
+ def atfork(self):
+ """
+ Terminate this Transport without closing the session. On posix
+ systems, if a Transport is open during process forking, both parent
+ and child will share the underlying socket, but only one process can
+ use the connection (without corrupting the session). Use this method
+ to clean up a Transport object without disrupting the other process.
+
+ @since: 1.5.3
+ """
+ self.sock.close()
+ self.close()
def get_security_options(self):
"""
@@ -319,8 +388,6 @@ class Transport (threading.Thread):
@return: an object that can be used to change the preferred algorithms
for encryption, digest (hash), public key, and key exchange.
@rtype: L{SecurityOptions}
-
- @since: ivysaur
"""
return SecurityOptions(self)
@@ -471,7 +538,8 @@ class Transport (threading.Thread):
try:
return self.server_key_dict[self.host_key_type]
except KeyError:
- return None
+ pass
+ return None
def load_server_moduli(filename=None):
"""
@@ -496,8 +564,6 @@ class Transport (threading.Thread):
@return: True if a moduli file was successfully loaded; False
otherwise.
@rtype: bool
-
- @since: doduo
@note: This has no effect when used in client mode.
"""
@@ -521,14 +587,13 @@ class Transport (threading.Thread):
"""
Close this session, and any open channels that are tied to it.
"""
+ if not self.active:
+ return
self.active = False
- # since this may be called from __del__, can't assume any attributes exist
- try:
- self.packetizer.close()
- for chan in self.channels.values():
- chan._unlink()
- except AttributeError:
- pass
+ self.packetizer.close()
+ self.join()
+ for chan in self._channels.values():
+ chan._unlink()
def get_remote_server_key(self):
"""
@@ -541,7 +606,7 @@ class Transport (threading.Thread):
@raise SSHException: if no session is currently active.
- @return: public key of the remote server.
+ @return: public key of the remote server
@rtype: L{PKey <pkey.PKey>}
"""
if (not self.active) or (not self.initial_kex_done):
@@ -553,7 +618,7 @@ class Transport (threading.Thread):
Return true if this session is active (open).
@return: True if the session is still active (open); False if the
- session is closed.
+ session is closed
@rtype: bool
"""
return self.active
@@ -563,12 +628,43 @@ class Transport (threading.Thread):
Request a new channel to the server, of type C{"session"}. This
is just an alias for C{open_channel('session')}.
- @return: a new L{Channel} on success, or C{None} if the request is
- rejected or the session ends prematurely.
+ @return: a new L{Channel}
@rtype: L{Channel}
+
+ @raise SSHException: if the request is rejected or the session ends
+ prematurely
"""
return self.open_channel('session')
+ def open_x11_channel(self, src_addr=None):
+ """
+ Request a new channel to the client, of type C{"x11"}. This
+ is just an alias for C{open_channel('x11', src_addr=src_addr)}.
+
+ @param src_addr: the source address of the x11 server (port is the
+ x11 port, ie. 6010)
+ @type src_addr: (str, int)
+ @return: a new L{Channel}
+ @rtype: L{Channel}
+
+ @raise SSHException: if the request is rejected or the session ends
+ prematurely
+ """
+ return self.open_channel('x11', src_addr=src_addr)
+
+ def open_forwarded_tcpip_channel(self, (src_addr, src_port), (dest_addr, dest_port)):
+ """
+ Request a new channel back to the client, of type C{"forwarded-tcpip"}.
+ This is used after a client has requested port forwarding, for sending
+ incoming connections back to the client.
+
+ @param src_addr: originator's address
+ @param src_port: originator's port
+ @param dest_addr: local (server) connected address
+ @param dest_port: local (server) connected port
+ """
+ return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port))
+
def open_channel(self, kind, dest_addr=None, src_addr=None):
"""
Request a new channel to the server. L{Channel}s are socket-like
@@ -577,18 +673,20 @@ class Transport (threading.Thread):
L{connect} or L{start_client}) and authenticating.
@param kind: the kind of channel requested (usually C{"session"},
- C{"forwarded-tcpip"} or C{"direct-tcpip"}).
+ C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"})
@type kind: str
@param dest_addr: the destination address of this port forwarding,
if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
- for other channel types).
+ for other channel types)
@type dest_addr: (str, int)
@param src_addr: the source address of this port forwarding, if
- C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}.
+ C{kind} is C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"}
@type src_addr: (str, int)
- @return: a new L{Channel} on success, or C{None} if the request is
- rejected or the session ends prematurely.
+ @return: a new L{Channel} on success
@rtype: L{Channel}
+
+ @raise SSHException: if the request is rejected or the session ends
+ prematurely
"""
chan = None
if not self.active:
@@ -596,11 +694,7 @@ class Transport (threading.Thread):
return None
self.lock.acquire()
try:
- chanid = self.channel_counter
- while self.channels.has_key(chanid):
- self.channel_counter = (self.channel_counter + 1) & 0xffffff
- chanid = self.channel_counter
- self.channel_counter = (self.channel_counter + 1) & 0xffffff
+ chanid = self._next_channel()
m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN))
m.add_string(kind)
@@ -612,7 +706,11 @@ class Transport (threading.Thread):
m.add_int(dest_addr[1])
m.add_string(src_addr[0])
m.add_int(src_addr[1])
- self.channels[chanid] = chan = Channel(chanid)
+ elif kind == 'x11':
+ m.add_string(src_addr[0])
+ m.add_int(src_addr[1])
+ chan = Channel(chanid)
+ self._channels.put(chanid, chan)
self.channel_events[chanid] = event = threading.Event()
self.channels_seen[chanid] = True
chan._set_transport(self)
@@ -620,20 +718,84 @@ class Transport (threading.Thread):
finally:
self.lock.release()
self._send_user_message(m)
- while 1:
+ while True:
event.wait(0.1);
if not self.active:
- return None
+ e = self.get_exception()
+ if e is None:
+ e = SSHException('Unable to open channel.')
+ raise e
if event.isSet():
break
- try:
- self.lock.acquire()
- if not self.channels.has_key(chanid):
- chan = None
- finally:
- self.lock.release()
- return chan
-
+ chan = self._channels.get(chanid)
+ if chan is not None:
+ return chan
+ e = self.get_exception()
+ if e is None:
+ e = SSHException('Unable to open channel.')
+ raise e
+
+ def request_port_forward(self, address, port, handler=None):
+ """
+ Ask the server to forward TCP connections from a listening port on
+ the server, across this SSH session.
+
+ If a handler is given, that handler is called from a different thread
+ whenever a forwarded connection arrives. The handler parameters are::
+
+ handler(channel, (origin_addr, origin_port), (server_addr, server_port))
+
+ where C{server_addr} and C{server_port} are the address and port that
+ the server was listening on.
+
+ If no handler is set, the default behavior is to send new incoming
+ forwarded connections into the accept queue, to be picked up via
+ L{accept}.
+
+ @param address: the address to bind when forwarding
+ @type address: str
+ @param port: the port to forward, or 0 to ask the server to allocate
+ any port
+ @type port: int
+ @param handler: optional handler for incoming forwarded connections
+ @type handler: function(Channel, (str, int), (str, int))
+ @return: the port # allocated by the server
+ @rtype: int
+
+ @raise SSHException: if the server refused the TCP forward request
+ """
+ if not self.active:
+ raise SSHException('SSH session not active')
+ address = str(address)
+ port = int(port)
+ response = self.global_request('tcpip-forward', (address, port), wait=True)
+ if response is None:
+ raise SSHException('TCP forwarding request denied')
+ if port == 0:
+ port = response.get_int()
+ if handler is None:
+ def default_handler(channel, (src_addr, src_port), (dest_addr, dest_port)):
+ self._queue_incoming_channel(channel)
+ handler = default_handler
+ self._tcp_handler = handler
+ return port
+
+ def cancel_port_forward(self, address, port):
+ """
+ Ask the server to cancel a previous port-forwarding request. No more
+ connections to the given address & port will be forwarded across this
+ ssh connection.
+
+ @param address: the address to stop forwarding
+ @type address: str
+ @param port: the port to stop forwarding
+ @type port: int
+ """
+ if not self.active:
+ return
+ self._tcp_handler = None
+ self.global_request('cancel-tcpip-forward', (address, port), wait=True)
+
def open_sftp_client(self):
"""
Create an SFTP client channel from an open transport. On success,
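A sketch of a handler with the signature documented above (the relay loop is purely illustrative, and C{transport} is assumed to be an already-authenticated Transport):

    def relay(channel):
        # trivial echo loop, for illustration only
        while True:
            data = channel.recv(1024)
            if not data:
                break
            channel.send(data)
        channel.close()

    def handler(channel, (origin_addr, origin_port), (server_addr, server_port)):
        # called from a separate thread for every forwarded connection
        print 'forwarded connection from %s:%d' % (origin_addr, origin_port)
        relay(channel)

    def start_forwarding(transport):
        port = transport.request_port_forward('', 0, handler)   # 0: let the server pick the port
        print 'server is listening on port', port
        return port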
@@ -656,8 +818,6 @@ class Transport (threading.Thread):
@param bytes: the number of random bytes to send in the payload of the
ignored packet -- defaults to a random number from 10 to 41.
@type bytes: int
-
- @since: fearow
"""
m = Message()
m.add_byte(chr(MSG_IGNORE))
@@ -674,22 +834,23 @@ class Transport (threading.Thread):
bytes sent or received, but this method gives you the option of forcing
new keys whenever you want. Negotiating new keys causes a pause in
traffic both ways as the two sides swap keys and do computations. This
- method returns when the session has switched to new keys, or the
- session has died mid-negotiation.
+ method returns when the session has switched to new keys.
- @return: True if the renegotiation was successful, and the link is
- using new keys; False if the session dropped during renegotiation.
- @rtype: bool
+ @raise SSHException: if the key renegotiation failed (which causes the
+ session to end)
"""
self.completion_event = threading.Event()
self._send_kex_init()
- while 1:
- self.completion_event.wait(0.1);
+ while True:
+ self.completion_event.wait(0.1)
if not self.active:
- return False
+ e = self.get_exception()
+ if e is not None:
+ raise e
+ raise SSHException('Negotiation failed.')
if self.completion_event.isSet():
break
- return True
+ return
def set_keepalive(self, interval):
"""
@@ -701,11 +862,9 @@ class Transport (threading.Thread):
@param interval: seconds to wait before sending a keepalive packet (or
0 to disable keepalives).
@type interval: int
-
- @since: fearow
"""
self.packetizer.set_keepalive(interval,
- lambda x=self: x.global_request('keepalive@lag.net', wait=False))
+ lambda x=weakref.proxy(self): x.global_request('keepalive@lag.net', wait=False))
def global_request(self, kind, data=None, wait=True):
"""
@@ -724,8 +883,6 @@ class Transport (threading.Thread):
request was successful (or an empty L{Message} if C{wait} was
C{False}); C{None} if the request was denied.
@rtype: L{Message}
-
- @since: fearow
"""
if wait:
self.completion_event = threading.Event()
@@ -807,8 +964,6 @@ class Transport (threading.Thread):
@raise SSHException: if the SSH2 negotiation fails, the host key
supplied by the server is incorrect, or authentication fails.
-
- @since: doduo
"""
if hostkey is not None:
self._preferred_keys = [ hostkey.get_name() ]
@@ -896,8 +1051,6 @@ class Transport (threading.Thread):
@return: username that was authenticated, or C{None}.
@rtype: string
-
- @since: fearow
"""
if not self.active or (self.auth_handler is None):
return None
@@ -958,9 +1111,9 @@ class Transport (threading.Thread):
step. Otherwise, in the normal case, an empty list is returned.
@param username: the username to authenticate as
- @type username: string
+ @type username: str
@param password: the password to authenticate with
- @type password: string
+ @type password: str or unicode
@param event: an event to trigger when the authentication attempt is
complete (whether it was successful or not)
@type event: threading.Event
@@ -974,8 +1127,9 @@ class Transport (threading.Thread):
@raise BadAuthenticationType: if password authentication isn't
allowed by the server for this user (and no event was passed in)
- @raise SSHException: if the authentication failed (and no event was
- passed in)
+ @raise AuthenticationException: if the authentication failed (and no
+ event was passed in)
+ @raise SSHException: if there was a network error
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to send the password unless we're on a secure link
@@ -993,7 +1147,7 @@ class Transport (threading.Thread):
return self.auth_handler.wait_for_response(my_event)
except BadAuthenticationType, x:
# if password auth isn't allowed, but keyboard-interactive *is*, try to fudge it
- if not fallback or not 'keyboard-interactive' in x.allowed_types:
+ if not fallback or ('keyboard-interactive' not in x.allowed_types):
raise
try:
def handler(title, instructions, fields):
@@ -1010,6 +1164,7 @@ class Transport (threading.Thread):
except SSHException, ignored:
# attempt failed; just raise the original exception
raise x
+ return None
def auth_publickey(self, username, key, event=None):
"""
@@ -1037,13 +1192,14 @@ class Transport (threading.Thread):
complete (whether it was successful or not)
@type event: threading.Event
@return: list of auth types permissible for the next stage of
- authentication (normally empty).
+ authentication (normally empty)
@rtype: list
@raise BadAuthenticationType: if public-key authentication isn't
- allowed by the server for this user (and no event was passed in).
- @raise SSHException: if the authentication failed (and no event was
- passed in).
+ allowed by the server for this user (and no event was passed in)
+ @raise AuthenticationException: if the authentication failed (and no
+ event was passed in)
+ @raise SSHException: if there was a network error
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
@@ -1100,7 +1256,8 @@ class Transport (threading.Thread):
@raise BadAuthenticationType: if public-key authentication isn't
allowed by the server for this user
- @raise SSHException: if the authentication failed
+ @raise AuthenticationException: if the authentication failed
+ @raise SSHException: if there was a network error
@since: 1.5
"""
@@ -1119,13 +1276,14 @@ class Transport (threading.Thread):
(See the C{logging} module for more info.) SSH Channels will log
to a sub-channel of the one specified.
- @param name: new channel name for logging.
+ @param name: new channel name for logging
@type name: str
@since: 1.1
"""
self.log_name = name
self.logger = util.get_logger(name)
+ self.packetizer.set_log(self.logger)
def get_log_channel(self):
"""
@@ -1166,8 +1324,7 @@ class Transport (threading.Thread):
"""
Turn on/off compression. This will only have an affect before starting
the transport (ie before calling L{connect}, etc). By default,
- compression is off since it negatively affects interactive sessions
- and is not fully tested.
+ compression is off since it negatively affects interactive sessions.
@param compress: C{True} to ask the remote client/server to compress
traffic; C{False} to refuse compression
@@ -1179,6 +1336,21 @@ class Transport (threading.Thread):
self._preferred_compression = ( 'zlib@openssh.com', 'zlib', 'none' )
else:
self._preferred_compression = ( 'none', )
+
+ def getpeername(self):
+ """
+ Return the address of the remote side of this Transport, if possible.
+ This is effectively a wrapper around C{'getpeername'} on the underlying
+ socket. If the socket-like object has no C{'getpeername'} method,
+ then C{("unknown", 0)} is returned.
+
+ @return: the address of the remote host, if known
+ @rtype: tuple(str, int)
+ """
+ gp = getattr(self.sock, 'getpeername', None)
+ if gp is None:
+ return ('unknown', 0)
+ return gp()
def stop_thread(self):
self.active = False
@@ -1188,25 +1360,29 @@ class Transport (threading.Thread):
### internals...
- def _log(self, level, msg):
+ def _log(self, level, msg, *args):
if issubclass(type(msg), list):
for m in msg:
self.logger.log(level, m)
else:
- self.logger.log(level, msg)
+ self.logger.log(level, msg, *args)
def _get_modulus_pack(self):
"used by KexGex to find primes for group exchange"
return self._modulus_pack
+ def _next_channel(self):
+ "you are holding the lock"
+ chanid = self._channel_counter
+ while self._channels.get(chanid) is not None:
+ self._channel_counter = (self._channel_counter + 1) & 0xffffff
+ chanid = self._channel_counter
+ self._channel_counter = (self._channel_counter + 1) & 0xffffff
+ return chanid
+
def _unlink_channel(self, chanid):
"used by a Channel to remove itself from the active channel list"
- try:
- self.lock.acquire()
- if self.channels.has_key(chanid):
- del self.channels[chanid]
- finally:
- self.lock.release()
+ self._channels.delete(chanid)
def _send_message(self, data):
self.packetizer.send_message(data)
@@ -1237,9 +1413,9 @@ class Transport (threading.Thread):
if self.session_id == None:
self.session_id = h
- def _expect_packet(self, type):
+ def _expect_packet(self, *ptypes):
"used by a kex object to register the next packet type it expects to see"
- self.expected_packet = type
+ self._expected_packet = tuple(ptypes)
def _verify_key(self, host_key, sig):
key = self._key_info[self.host_key_type](Message(host_key))
@@ -1262,16 +1438,34 @@ class Transport (threading.Thread):
m.add_mpint(self.K)
m.add_bytes(self.H)
m.add_bytes(sofar)
- hash = SHA.new(str(m)).digest()
- out += hash
- sofar += hash
+ digest = SHA.new(str(m)).digest()
+ out += digest
+ sofar += digest
return out[:nbytes]
def _get_cipher(self, name, key, iv):
- if not self._cipher_info.has_key(name):
+ if name not in self._cipher_info:
raise SSHException('Unknown client cipher ' + name)
return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv)
+ def _set_x11_handler(self, handler):
+ # only called if a channel has turned on x11 forwarding
+ if handler is None:
+ # by default, use the same mechanism as accept()
+ def default_handler(channel, (src_addr, src_port)):
+ self._queue_incoming_channel(channel)
+ self._x11_handler = default_handler
+ else:
+ self._x11_handler = handler
+
+ def _queue_incoming_channel(self, channel):
+ self.lock.acquire()
+ try:
+ self.server_accepts.append(channel)
+ self.server_accept_cv.notify()
+ finally:
+ self.lock.release()
+
def run(self):
# (use the exposed "run" method, because if we specify a thread target
# of a private method, threading.Thread will keep a reference to it
@@ -1288,7 +1482,7 @@ class Transport (threading.Thread):
self.packetizer.write_all(self.local_version + '\r\n')
self._check_banner()
self._send_kex_init()
- self.expected_packet = MSG_KEXINIT
+ self._expect_packet(MSG_KEXINIT)
while self.active:
if self.packetizer.need_rekey() and not self.in_kex:
@@ -1307,27 +1501,28 @@ class Transport (threading.Thread):
elif ptype == MSG_DEBUG:
self._parse_debug(m)
continue
- if self.expected_packet != 0:
- if ptype != self.expected_packet:
- raise SSHException('Expecting packet %d, got %d' % (self.expected_packet, ptype))
- self.expected_packet = 0
+ if len(self._expected_packet) > 0:
+ if ptype not in self._expected_packet:
+ raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype))
+ self._expected_packet = tuple()
if (ptype >= 30) and (ptype <= 39):
self.kex_engine.parse_next(ptype, m)
continue
- if self._handler_table.has_key(ptype):
+ if ptype in self._handler_table:
self._handler_table[ptype](self, m)
- elif self._channel_handler_table.has_key(ptype):
+ elif ptype in self._channel_handler_table:
chanid = m.get_int()
- if self.channels.has_key(chanid):
- self._channel_handler_table[ptype](self.channels[chanid], m)
- elif self.channels_seen.has_key(chanid):
+ chan = self._channels.get(chanid)
+ if chan is not None:
+ self._channel_handler_table[ptype](chan, m)
+ elif chanid in self.channels_seen:
self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid)
else:
self._log(ERROR, 'Channel request for unknown channel %d' % chanid)
self.active = False
self.packetizer.close()
- elif (self.auth_handler is not None) and self.auth_handler._handler_table.has_key(ptype):
+ elif (self.auth_handler is not None) and (ptype in self.auth_handler._handler_table):
self.auth_handler._handler_table[ptype](self.auth_handler, m)
else:
self._log(WARNING, 'Oops, unhandled type %d' % ptype)
@@ -1355,7 +1550,7 @@ class Transport (threading.Thread):
self._log(ERROR, util.tb_strings())
self.saved_exception = e
_active_threads.remove(self)
- for chan in self.channels.values():
+ for chan in self._channels.values():
chan._unlink()
if self.active:
self.active = False
@@ -1366,6 +1561,11 @@ class Transport (threading.Thread):
self.auth_handler.abort()
for event in self.channel_events.values():
event.set()
+ try:
+ self.lock.acquire()
+ self.server_accept_cv.notify()
+ finally:
+ self.lock.release()
self.sock.close()
@@ -1388,30 +1588,31 @@ class Transport (threading.Thread):
def _check_banner(self):
# this is slow, but we only have to do it once
for i in range(5):
- # give them 5 seconds for the first line, then just 2 seconds each additional line
+ # give them 15 seconds for the first line, then just 2 seconds
+ # each additional line. (some sites have very high latency.)
if i == 0:
- timeout = 5
+ timeout = self.banner_timeout
else:
timeout = 2
try:
- buffer = self.packetizer.readline(timeout)
+ buf = self.packetizer.readline(timeout)
except Exception, x:
raise SSHException('Error reading SSH protocol banner' + str(x))
- if buffer[:4] == 'SSH-':
+ if buf[:4] == 'SSH-':
break
- self._log(DEBUG, 'Banner: ' + buffer)
- if buffer[:4] != 'SSH-':
- raise SSHException('Indecipherable protocol version "' + buffer + '"')
+ self._log(DEBUG, 'Banner: ' + buf)
+ if buf[:4] != 'SSH-':
+ raise SSHException('Indecipherable protocol version "' + buf + '"')
# save this server version string for later
- self.remote_version = buffer
+ self.remote_version = buf
# pull off any attached comment
comment = ''
- i = string.find(buffer, ' ')
+ i = string.find(buf, ' ')
if i >= 0:
- comment = buffer[i+1:]
- buffer = buffer[:i]
+ comment = buf[i+1:]
+ buf = buf[:i]
# parse out version string and make sure it matches
- segs = buffer.split('-', 2)
+ segs = buf.split('-', 2)
if len(segs) < 3:
raise SSHException('Invalid SSH banner')
version = segs[1]
@@ -1612,7 +1813,7 @@ class Transport (threading.Thread):
if not self.packetizer.need_rekey():
self.in_kex = False
# we always expect to receive NEWKEYS now
- self.expected_packet = MSG_NEWKEYS
+ self._expect_packet(MSG_NEWKEYS)
def _auth_trigger(self):
self.authenticated = True
@@ -1661,7 +1862,22 @@ class Transport (threading.Thread):
kind = m.get_string()
self._log(DEBUG, 'Received global request "%s"' % kind)
want_reply = m.get_boolean()
- ok = self.server_object.check_global_request(kind, m)
+ if not self.server_mode:
+ self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind)
+ ok = False
+ elif kind == 'tcpip-forward':
+ address = m.get_string()
+ port = m.get_int()
+ ok = self.server_object.check_port_forward_request(address, port)
+ if ok != False:
+ ok = (ok,)
+ elif kind == 'cancel-tcpip-forward':
+ address = m.get_string()
+ port = m.get_int()
+ self.server_object.cancel_port_forward_request(address, port)
+ ok = True
+ else:
+ ok = self.server_object.check_global_request(kind, m)
extra = ()
if type(ok) is tuple:
extra = ok
@@ -1692,15 +1908,15 @@ class Transport (threading.Thread):
server_chanid = m.get_int()
server_window_size = m.get_int()
server_max_packet_size = m.get_int()
- if not self.channels.has_key(chanid):
+ chan = self._channels.get(chanid)
+ if chan is None:
self._log(WARNING, 'Success for unrequested channel! [??]')
return
self.lock.acquire()
try:
- chan = self.channels[chanid]
chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
self._log(INFO, 'Secsh channel %d opened.' % chanid)
- if self.channel_events.has_key(chanid):
+ if chanid in self.channel_events:
self.channel_events[chanid].set()
del self.channel_events[chanid]
finally:
@@ -1712,16 +1928,14 @@ class Transport (threading.Thread):
reason = m.get_int()
reason_str = m.get_string()
lang = m.get_string()
- if CONNECTION_FAILED_CODE.has_key(reason):
- reason_text = CONNECTION_FAILED_CODE[reason]
- else:
- reason_text = '(unknown code)'
+ reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
+ self.lock.acquire()
try:
- self.lock.aquire()
- if self.channels.has_key(chanid):
- del self.channels[chanid]
- if self.channel_events.has_key(chanid):
+ self.saved_exception = ChannelException(reason, reason_text)
+ if chanid in self.channel_events:
+ self._channels.delete(chanid)
+ if chanid in self.channel_events:
self.channel_events[chanid].set()
del self.channel_events[chanid]
finally:
@@ -1734,21 +1948,47 @@ class Transport (threading.Thread):
initial_window_size = m.get_int()
max_packet_size = m.get_int()
reject = False
- if not self.server_mode:
+ if (kind == 'x11') and (self._x11_handler is not None):
+ origin_addr = m.get_string()
+ origin_port = m.get_int()
+ self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port))
+ self.lock.acquire()
+ try:
+ my_chanid = self._next_channel()
+ finally:
+ self.lock.release()
+ elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None):
+ server_addr = m.get_string()
+ server_port = m.get_int()
+ origin_addr = m.get_string()
+ origin_port = m.get_int()
+ self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port))
+ self.lock.acquire()
+ try:
+ my_chanid = self._next_channel()
+ finally:
+ self.lock.release()
+ elif not self.server_mode:
self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
reject = True
reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
else:
self.lock.acquire()
try:
- my_chanid = self.channel_counter
- while self.channels.has_key(my_chanid):
- self.channel_counter = (self.channel_counter + 1) & 0xffffff
- my_chanid = self.channel_counter
- self.channel_counter = (self.channel_counter + 1) & 0xffffff
+ my_chanid = self._next_channel()
finally:
self.lock.release()
- reason = self.server_object.check_channel_request(kind, my_chanid)
+ if kind == 'direct-tcpip':
+ # handle direct-tcpip requests coming from the client
+ dest_addr = m.get_string()
+ dest_port = m.get_int()
+ origin_addr = m.get_string()
+ origin_port = m.get_int()
+ reason = self.server_object.check_channel_direct_tcpip_request(
+ my_chanid, (origin_addr, origin_port),
+ (dest_addr, dest_port))
+ else:
+ reason = self.server_object.check_channel_request(kind, my_chanid)
if reason != OPEN_SUCCEEDED:
self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind)
reject = True
@@ -1761,10 +2001,11 @@ class Transport (threading.Thread):
msg.add_string('en')
self._send_message(msg)
return
+
chan = Channel(my_chanid)
+ self.lock.acquire()
try:
- self.lock.acquire()
- self.channels[my_chanid] = chan
+ self._channels.put(my_chanid, chan)
self.channels_seen[my_chanid] = True
chan._set_transport(self)
chan._set_window(self.window_size, self.max_packet_size)
@@ -1778,13 +2019,14 @@ class Transport (threading.Thread):
m.add_int(self.window_size)
m.add_int(self.max_packet_size)
self._send_message(m)
- self._log(INFO, 'Secsh channel %d opened.' % my_chanid)
- try:
- self.lock.acquire()
- self.server_accepts.append(chan)
- self.server_accept_cv.notify()
- finally:
- self.lock.release()
+ self._log(INFO, 'Secsh channel %d (%s) opened.', my_chanid, kind)
+ if kind == 'x11':
+ self._x11_handler(chan, (origin_addr, origin_port))
+ elif kind == 'forwarded-tcpip':
+ chan.origin_addr = (origin_addr, origin_port)
+ self._tcp_handler(chan, (origin_addr, origin_port), (server_addr, server_port))
+ else:
+ self._queue_incoming_channel(chan)
def _parse_debug(self, m):
always_display = m.get_boolean()
@@ -1795,7 +2037,7 @@ class Transport (threading.Thread):
def _get_subsystem_handler(self, name):
try:
self.lock.acquire()
- if not self.subsystem_table.has_key(name):
+ if name not in self.subsystem_table:
return (None, [], {})
return self.subsystem_table[name]
finally:
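Since C{open_channel()}/C{open_session()} now raise on failure instead of returning C{None}, calling code changes accordingly; a sketch (the transport and command are placeholders):

    from paramiko import SSHException, ChannelException

    def run_command(transport, command):
        try:
            chan = transport.open_session()
        except ChannelException, e:
            print 'server refused the channel (code %d): %s' % (e.code, e)
            return None
        except SSHException, e:
            print 'session ended while opening the channel:', e
            return None
        chan.exec_command(command)
        return chan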
diff --git a/paramiko/util.py b/paramiko/util.py
index abab825..8abdc0c 100644
--- a/paramiko/util.py
+++ b/paramiko/util.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -22,13 +22,14 @@ Useful functions used by the rest of paramiko.
from __future__ import generators
-import fnmatch
+from binascii import hexlify, unhexlify
import sys
import struct
import traceback
import threading
from paramiko.common import *
+from paramiko.config import SSHConfig
# Change by RogerB - python < 2.3 doesn't have enumerate so we implement it
@@ -115,12 +116,10 @@ def format_binary_line(data):
return '%-50s %s' % (left, right)
def hexify(s):
- "turn a string into a hex sequence"
- return ''.join(['%02X' % ord(c) for c in s])
+ return hexlify(s).upper()
def unhexify(s):
- "turn a hex sequence back into a string"
- return ''.join([chr(int(s[i:i+2], 16)) for i in range(0, len(s), 2)])
+ return unhexlify(s)
def safe_string(s):
out = ''
@@ -168,12 +167,12 @@ def generate_key_bytes(hashclass, salt, key, nbytes):
if len(salt) > 8:
salt = salt[:8]
while nbytes > 0:
- hash = hashclass.new()
+ hash_obj = hashclass.new()
if len(digest) > 0:
- hash.update(digest)
- hash.update(key)
- hash.update(salt)
- digest = hash.digest()
+ hash_obj.update(digest)
+ hash_obj.update(key)
+ hash_obj.update(salt)
+ digest = hash_obj.digest()
size = min(nbytes, len(digest))
keydata += digest[:size]
nbytes -= size
@@ -189,117 +188,29 @@ def load_host_keys(filename):
This type of file unfortunately doesn't exist on Windows, but on posix,
it will usually be stored in C{os.path.expanduser("~/.ssh/known_hosts")}.
+ Since 1.5.3, this is just a wrapper around L{HostKeys}.
+
@param filename: name of the file to read host keys from
@type filename: str
@return: dict of host keys, indexed by hostname and then keytype
@rtype: dict(hostname, dict(keytype, L{PKey <paramiko.pkey.PKey>}))
"""
- import base64
- from rsakey import RSAKey
- from dsskey import DSSKey
-
- keys = {}
- f = file(filename, 'r')
- for line in f:
- line = line.strip()
- if (len(line) == 0) or (line[0] == '#'):
- continue
- keylist = line.split(' ')
- if len(keylist) != 3:
- continue
- hostlist, keytype, key = keylist
- hosts = hostlist.split(',')
- for host in hosts:
- if not keys.has_key(host):
- keys[host] = {}
- if keytype == 'ssh-rsa':
- keys[host][keytype] = RSAKey(data=base64.decodestring(key))
- elif keytype == 'ssh-dss':
- keys[host][keytype] = DSSKey(data=base64.decodestring(key))
- f.close()
- return keys
+ from paramiko.hostkeys import HostKeys
+ return HostKeys(filename)
def parse_ssh_config(file_obj):
"""
- Parse a config file of the format used by OpenSSH, and return an object
- that can be used to make queries to L{lookup_ssh_host_config}. The
- format is described in OpenSSH's C{ssh_config} man page. This method is
- provided primarily as a convenience to posix users (since the OpenSSH
- format is a de-facto standard on posix) but should work fine on Windows
- too.
-
- The return value is currently a list of dictionaries, each containing
- host-specific configuration, but this is considered an implementation
- detail and may be subject to change in later versions.
-
- @param file_obj: a file-like object to read the config file from
- @type file_obj: file
- @return: opaque configuration object
- @rtype: object
+ Provided only as a backward-compatible wrapper around L{SSHConfig}.
"""
- ret = []
- config = { 'host': '*' }
- ret.append(config)
-
- for line in file_obj:
- line = line.rstrip('\n').lstrip()
- if (line == '') or (line[0] == '#'):
- continue
- if '=' in line:
- key, value = line.split('=', 1)
- key = key.strip().lower()
- else:
- # find first whitespace, and split there
- i = 0
- while (i < len(line)) and not line[i].isspace():
- i += 1
- if i == len(line):
- raise Exception('Unparsable line: %r' % line)
- key = line[:i].lower()
- value = line[i:].lstrip()
-
- if key == 'host':
- # do we have a pre-existing host config to append to?
- matches = [c for c in ret if c['host'] == value]
- if len(matches) > 0:
- config = matches[0]
- else:
- config = { 'host': value }
- ret.append(config)
- else:
- config[key] = value
-
- return ret
+ config = SSHConfig()
+ config.parse(file_obj)
+ return config
def lookup_ssh_host_config(hostname, config):
"""
- Return a dict of config options for a given hostname. The C{config} object
- must come from L{parse_ssh_config}.
-
- The host-matching rules of OpenSSH's C{ssh_config} man page are used, which
- means that all configuration options from matching host specifications are
- merged, with more specific hostmasks taking precedence. In other words, if
- C{"Port"} is set under C{"Host *"} and also C{"Host *.example.com"}, and
- the lookup is for C{"ssh.example.com"}, then the port entry for
- C{"Host *.example.com"} will win out.
-
- The keys in the returned dict are all normalized to lowercase (look for
- C{"port"}, not C{"Port"}. No other processing is done to the keys or
- values.
-
- @param hostname: the hostname to lookup
- @type hostname: str
- @param config: the config object to search
- @type config: object
+ Provided only as a backward-compatible wrapper around L{SSHConfig}.
"""
- matches = [x for x in config if fnmatch.fnmatch(hostname, x['host'])]
- # sort in order of shortest match (usually '*') to longest
- matches.sort(key=lambda x: len(x['host']))
- ret = {}
- for m in matches:
- ret.update(m)
- del ret['host']
- return ret
+ return config.lookup(hostname)
def mod_inverse(x, m):
# it's crazy how small python can make this function.
@@ -355,3 +266,5 @@ def get_logger(name):
l = logging.getLogger(name)
l.addFilter(_pfilter)
return l
+
+
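The wrappers keep the old util-level API working on top of the new SSHConfig class; a small usage sketch (the config contents are made up):

    from StringIO import StringIO
    from paramiko.util import parse_ssh_config, lookup_ssh_host_config

    sample = StringIO("""
    Host *
        Port 22
    Host *.example.com
        Port 2222
        User robey
    """)
    config = parse_ssh_config(sample)
    opts = lookup_ssh_host_config('ssh.example.com', config)
    # expected to include at least 'port': '2222' and 'user': 'robey'
    print opts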
diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py
new file mode 100644
index 0000000..787032b
--- /dev/null
+++ b/paramiko/win_pageant.py
@@ -0,0 +1,148 @@
+# Copyright (C) 2005 John Arbash-Meinel <john@arbash-meinel.com>
+# Modified up by: Todd Whiteman <ToddW@ActiveState.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
+
+"""
+Functions for communicating with Pageant, the basic windows ssh agent program.
+"""
+
+import os
+import struct
+import tempfile
+import mmap
+import array
+
+# if you're on windows, you should have one of these, i guess?
+# ctypes is part of standard library since Python 2.5
+_has_win32all = False
+_has_ctypes = False
+try:
+ # win32gui is preferred over win32ui to avoid MFC dependencies
+ import win32gui
+ _has_win32all = True
+except ImportError:
+ try:
+ import ctypes
+ _has_ctypes = True
+ except ImportError:
+ pass
+
+
+_AGENT_COPYDATA_ID = 0x804e50ba
+_AGENT_MAX_MSGLEN = 8192
+# Note: The WM_COPYDATA value is pulled from win32con, as a workaround
+# so we do not have to import this huge library just for this one variable.
+win32con_WM_COPYDATA = 74
+
+
+def _get_pageant_window_object():
+ if _has_win32all:
+ try:
+ hwnd = win32gui.FindWindow('Pageant', 'Pageant')
+ return hwnd
+ except win32gui.error:
+ pass
+ elif _has_ctypes:
+ # Return 0 if there is no Pageant window.
+ return ctypes.windll.user32.FindWindowA('Pageant', 'Pageant')
+ return None
+
+
+def can_talk_to_agent():
+ """
+ Check to see if there is a "Pageant" agent we can talk to.
+
+ This checks both if we have the required libraries (win32all or ctypes)
+ and if there is a Pageant currently running.
+ """
+ if (_has_win32all or _has_ctypes) and _get_pageant_window_object():
+ return True
+ return False
+
+
+def _query_pageant(msg):
+ hwnd = _get_pageant_window_object()
+ if not hwnd:
+ # Pageant isn't running (or its window can't be found), so there's nothing to talk to
+ return None
+
+ # Write our pageant request string into the file (pageant will read this to determine what to do)
+ filename = tempfile.mktemp('.pag')
+ map_filename = os.path.basename(filename)
+
+ f = open(filename, 'w+b')
+ f.write(msg)
+ # Ensure the rest of the file is empty, otherwise pageant will read this
+ f.write('\0' * (_AGENT_MAX_MSGLEN - len(msg)))
+ # Create the shared file map that pageant will use to read from
+ pymap = mmap.mmap(f.fileno(), _AGENT_MAX_MSGLEN, tagname=map_filename, access=mmap.ACCESS_WRITE)
+ try:
+ # Create an array buffer containing the mapped filename
+ char_buffer = array.array("c", map_filename + '\0')
+ char_buffer_address, char_buffer_size = char_buffer.buffer_info()
+ # Create a string to use for the SendMessage function call
+ cds = struct.pack("LLP", _AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address)
+
+ if _has_win32all:
+ # win32gui.SendMessage should also allow the same pattern as
+ # ctypes, but let's keep it like this for now...
+ response = win32gui.SendMessage(hwnd, win32con_WM_COPYDATA, len(cds), cds)
+ elif _has_ctypes:
+ _buf = array.array('B', cds)
+ _addr, _size = _buf.buffer_info()
+ response = ctypes.windll.user32.SendMessageA(hwnd, win32con_WM_COPYDATA, _size, _addr)
+ else:
+ response = 0
+
+ if response > 0:
+ datalen = pymap.read(4)
+ retlen = struct.unpack('>I', datalen)[0]
+ return datalen + pymap.read(retlen)
+ return None
+ finally:
+ pymap.close()
+ f.close()
+ # Remove the file, it was temporary only
+ os.unlink(filename)
+
+
+class PageantConnection (object):
+ """
+ Mock "connection" to an agent which roughly approximates the behavior of
+ a unix local-domain socket (as used by Agent). Requests are sent to the
+ pageant daemon via special Windows magick, and responses are buffered back
+ for subsequent reads.
+ """
+
+ def __init__(self):
+ self._response = None
+
+ def send(self, data):
+ self._response = _query_pageant(data)
+
+ def recv(self, n):
+ if self._response is None:
+ return ''
+ ret = self._response[:n]
+ self._response = self._response[n:]
+ if self._response == '':
+ self._response = None
+ return ret
+
+ def close(self):
+ pass
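Normally paramiko's Agent class is expected to use this module transparently on Windows, but the connection can also be driven by hand; a rough sketch (the opcode constant comes from the standard ssh-agent protocol, not from this module):

    import struct
    from paramiko import win_pageant

    SSH2_AGENTC_REQUEST_IDENTITIES = 11   # standard ssh-agent protocol request

    if win_pageant.can_talk_to_agent():
        conn = win_pageant.PageantConnection()
        # agent messages are length-prefixed: 4-byte big-endian length, then payload
        msg = chr(SSH2_AGENTC_REQUEST_IDENTITIES)
        conn.send(struct.pack('>I', len(msg)) + msg)
        length = struct.unpack('>I', conn.recv(4))[0]
        reply = conn.recv(length)
        conn.close()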