diff options
author | Jeremy T. Bouse <jbouse@debian.org> | 2013-05-25 00:04:32 -0400 |
---|---|---|
committer | Jeremy T. Bouse <jbouse@debian.org> | 2013-05-25 00:04:32 -0400 |
commit | 1a716ed46d1d556d4ba6798608ab498320acd886 (patch) | |
tree | dbcb23de26387e312f7ea09085330eca90e15853 /paramiko | |
parent | a88b8c8c0f591a3bfa8d7984343a27815184f495 (diff) | |
download | python-paramiko-1a716ed46d1d556d4ba6798608ab498320acd886.tar python-paramiko-1a716ed46d1d556d4ba6798608ab498320acd886.tar.gz |
Imported Upstream version 1.10.1upstream/1.10.1
Diffstat (limited to 'paramiko')
-rw-r--r-- | paramiko/__init__.py | 29 | ||||
-rw-r--r-- | paramiko/agent.py | 333 | ||||
-rw-r--r-- | paramiko/channel.py | 54 | ||||
-rw-r--r-- | paramiko/client.py | 165 | ||||
-rw-r--r-- | paramiko/common.py | 3 | ||||
-rw-r--r-- | paramiko/config.py | 198 | ||||
-rw-r--r-- | paramiko/file.py | 4 | ||||
-rw-r--r-- | paramiko/hostkeys.py | 24 | ||||
-rw-r--r-- | paramiko/message.py | 3 | ||||
-rw-r--r-- | paramiko/packet.py | 52 | ||||
-rw-r--r-- | paramiko/proxy.py | 91 | ||||
-rw-r--r-- | paramiko/server.py | 18 | ||||
-rw-r--r-- | paramiko/sftp_client.py | 138 | ||||
-rw-r--r-- | paramiko/sftp_file.py | 25 | ||||
-rw-r--r-- | paramiko/ssh_exception.py | 17 | ||||
-rw-r--r-- | paramiko/transport.py | 238 | ||||
-rw-r--r-- | paramiko/util.py | 9 | ||||
-rw-r--r-- | paramiko/win_pageant.py | 28 |
18 files changed, 1091 insertions, 338 deletions
diff --git a/paramiko/__init__.py b/paramiko/__init__.py index 96b5943..099314e 100644 --- a/paramiko/__init__.py +++ b/paramiko/__init__.py @@ -18,7 +18,7 @@ """ I{Paramiko} (a combination of the esperanto words for "paranoid" and "friend") -is a module for python 2.3 or greater that implements the SSH2 protocol for +is a module for python 2.5 or greater that implements the SSH2 protocol for secure (encrypted and authenticated) connections to remote machines. Unlike SSL (aka TLS), the SSH2 protocol does not require hierarchical certificates signed by a powerful central authority. You may know SSH2 as the protocol that @@ -45,24 +45,17 @@ receive data over the encrypted session. Paramiko is written entirely in python (no C or platform-dependent code) and is released under the GNU Lesser General Public License (LGPL). -Website: U{http://www.lag.net/paramiko/} - -@version: 1.7.7.1 (George) -@author: Robey Pointer -@contact: robeypointer@gmail.com -@license: GNU Lesser General Public License (LGPL) +Website: U{https://github.com/paramiko/paramiko/} """ import sys -if sys.version_info < (2, 2): - raise RuntimeError('You need python 2.2 for this module.') +if sys.version_info < (2, 5): + raise RuntimeError('You need python 2.5+ for this module.') -__author__ = "Robey Pointer <robeypointer@gmail.com>" -__date__ = "21 May 2011" -__version__ = "1.7.7.1 (George)" -__version_info__ = (1, 7, 7, 1) +__author__ = "Jeff Forcier <jeff@bitprophet.org>" +__version__ = "1.10.1" __license__ = "GNU Lesser General Public License (LGPL)" @@ -72,7 +65,7 @@ from auth_handler import AuthHandler from channel import Channel, ChannelFile from ssh_exception import SSHException, PasswordRequiredException, \ BadAuthenticationType, ChannelException, BadHostKeyException, \ - AuthenticationException + AuthenticationException, ProxyCommandFailure from server import ServerInterface, SubsystemHandler, InteractiveQuery from rsakey import RSAKey from dsskey import DSSKey @@ -90,6 +83,7 @@ from agent import Agent, AgentKey from pkey import PKey from hostkeys import HostKeys from config import SSHConfig +from proxy import ProxyCommand # fix module names for epydoc for c in locals().values(): @@ -105,6 +99,8 @@ from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \ from sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \ SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED +from common import io_sleep + __all__ = [ 'Transport', 'SSHClient', 'MissingHostKeyPolicy', @@ -124,6 +120,8 @@ __all__ = [ 'Transport', 'BadAuthenticationType', 'ChannelException', 'BadHostKeyException', + 'ProxyCommand', + 'ProxyCommandFailure', 'SFTP', 'SFTPFile', 'SFTPHandle', @@ -138,4 +136,5 @@ __all__ = [ 'Transport', 'AgentKey', 'HostKeys', 'SSHConfig', - 'util' ] + 'util', + 'io_sleep' ] diff --git a/paramiko/agent.py b/paramiko/agent.py index 3bb9426..1dd3063 100644 --- a/paramiko/agent.py +++ b/paramiko/agent.py @@ -24,39 +24,308 @@ import os import socket import struct import sys +import threading +import time +import tempfile +import stat +from select import select from paramiko.ssh_exception import SSHException from paramiko.message import Message from paramiko.pkey import PKey - +from paramiko.channel import Channel +from paramiko.common import io_sleep +from paramiko.util import retry_on_signal SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \ SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15) +class AgentSSH(object): + """ + Client interface for using private keys from an SSH agent running on the + local machine. If an SSH agent is running, this class can be used to + connect to it and retreive L{PKey} objects which can be used when + attempting to authenticate to remote SSH servers. + + Because the SSH agent protocol uses environment variables and unix-domain + sockets, this probably doesn't work on Windows. It does work on most + posix platforms though (Linux and MacOS X, for example). + """ + def __init__(self): + self._conn = None + self._keys = () + + def get_keys(self): + """ + Return the list of keys available through the SSH agent, if any. If + no SSH agent was running (or it couldn't be contacted), an empty list + will be returned. + + @return: a list of keys available on the SSH agent + @rtype: tuple of L{AgentKey} + """ + return self._keys + + def _connect(self, conn): + self._conn = conn + ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES)) + if ptype != SSH2_AGENT_IDENTITIES_ANSWER: + raise SSHException('could not get keys from ssh-agent') + keys = [] + for i in range(result.get_int()): + keys.append(AgentKey(self, result.get_string())) + result.get_string() + self._keys = tuple(keys) + + def _close(self): + #self._conn.close() + self._conn = None + self._keys = () + + def _send_message(self, msg): + msg = str(msg) + self._conn.send(struct.pack('>I', len(msg)) + msg) + l = self._read_all(4) + msg = Message(self._read_all(struct.unpack('>I', l)[0])) + return ord(msg.get_byte()), msg + + def _read_all(self, wanted): + result = self._conn.recv(wanted) + while len(result) < wanted: + if len(result) == 0: + raise SSHException('lost ssh-agent') + extra = self._conn.recv(wanted - len(result)) + if len(extra) == 0: + raise SSHException('lost ssh-agent') + result += extra + return result + +class AgentProxyThread(threading.Thread): + """ Class in charge of communication between two chan """ + def __init__(self, agent): + threading.Thread.__init__(self, target=self.run) + self._agent = agent + self._exit = False + + def run(self): + try: + (r,addr) = self.get_connection() + self.__inr = r + self.__addr = addr + self._agent.connect() + self._communicate() + except: + #XXX Not sure what to do here ... raise or pass ? + raise + + def _communicate(self): + import fcntl + oldflags = fcntl.fcntl(self.__inr, fcntl.F_GETFL) + fcntl.fcntl(self.__inr, fcntl.F_SETFL, oldflags | os.O_NONBLOCK) + while not self._exit: + events = select([self._agent._conn, self.__inr], [], [], 0.5) + for fd in events[0]: + if self._agent._conn == fd: + data = self._agent._conn.recv(512) + if len(data) != 0: + self.__inr.send(data) + else: + self._close() + break + elif self.__inr == fd: + data = self.__inr.recv(512) + if len(data) != 0: + self._agent._conn.send(data) + else: + self._close() + break + time.sleep(io_sleep) + + def _close(self): + self._exit = True + self.__inr.close() + self._agent._conn.close() + +class AgentLocalProxy(AgentProxyThread): + """ + Class to be used when wanting to ask a local SSH Agent being + asked from a remote fake agent (so use a unix socket for ex.) + """ + def __init__(self, agent): + AgentProxyThread.__init__(self, agent) + + def get_connection(self): + """ Return a pair of socket object and string address + May Block ! + """ + conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + conn.bind(self._agent._get_filename()) + conn.listen(1) + (r,addr) = conn.accept() + return (r, addr) + except: + raise + return None -class Agent: +class AgentRemoteProxy(AgentProxyThread): + """ + Class to be used when wanting to ask a remote SSH Agent + """ + def __init__(self, agent, chan): + AgentProxyThread.__init__(self, agent) + self.__chan = chan + + def get_connection(self): + """ + Class to be used when wanting to ask a local SSH Agent being + asked from a remote fake agent (so use a unix socket for ex.) + """ + return (self.__chan, None) + +class AgentClientProxy(object): + """ + Class proxying request as a client: + -> client ask for a request_forward_agent() + -> server creates a proxy and a fake SSH Agent + -> server ask for establishing a connection when needed, + calling the forward_agent_handler at client side. + -> the forward_agent_handler launch a thread for connecting + the remote fake agent and the local agent + -> Communication occurs ... + """ + def __init__(self, chanRemote): + self._conn = None + self.__chanR = chanRemote + self.thread = AgentRemoteProxy(self, chanRemote) + self.thread.start() + + def __del__(self): + self.close() + + def connect(self): + """ + Method automatically called by the run() method of the AgentProxyThread + """ + if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): + conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + retry_on_signal(lambda: conn.connect(os.environ['SSH_AUTH_SOCK'])) + except: + # probably a dangling env var: the ssh agent is gone + return + elif sys.platform == 'win32': + import win_pageant + if win_pageant.can_talk_to_agent(): + conn = win_pageant.PageantConnection() + else: + return + else: + # no agent support + return + self._conn = conn + + def close(self): + """ + Close the current connection and terminate the agent + Should be called manually + """ + if hasattr(self, "thread"): + self.thread._exit = True + self.thread.join(1000) + if self._conn is not None: + self._conn.close() + +class AgentServerProxy(AgentSSH): + """ + @param t : transport used for the Forward for SSH Agent communication + + @raise SSHException: mostly if we lost the agent + """ + def __init__(self, t): + AgentSSH.__init__(self) + self.__t = t + self._dir = tempfile.mkdtemp('sshproxy') + os.chmod(self._dir, stat.S_IRWXU) + self._file = self._dir + '/sshproxy.ssh' + self.thread = AgentLocalProxy(self) + self.thread.start() + + def __del__(self): + self.close() + + def connect(self): + conn_sock = self.__t.open_forward_agent_channel() + if conn_sock is None: + raise SSHException('lost ssh-agent') + conn_sock.set_name('auth-agent') + self._connect(conn_sock) + + def close(self): + """ + Terminate the agent, clean the files, close connections + Should be called manually + """ + os.remove(self._file) + os.rmdir(self._dir) + self.thread._exit = True + self.thread.join(1000) + self._close() + + def get_env(self): + """ + Helper for the environnement under unix + + @return: the SSH_AUTH_SOCK Environnement variables + @rtype: dict + """ + env = {} + env['SSH_AUTH_SOCK'] = self._get_filename() + return env + + def _get_filename(self): + return self._file + +class AgentRequestHandler(object): + def __init__(self, chanClient): + self._conn = None + self.__chanC = chanClient + chanClient.request_forward_agent(self._forward_agent_handler) + self.__clientProxys = [] + + def _forward_agent_handler(self, chanRemote): + self.__clientProxys.append(AgentClientProxy(chanRemote)) + + def __del__(self): + self.close() + + def close(self): + for p in self.__clientProxys: + p.close() + +class Agent(AgentSSH): """ Client interface for using private keys from an SSH agent running on the local machine. If an SSH agent is running, this class can be used to connect to it and retreive L{PKey} objects which can be used when attempting to authenticate to remote SSH servers. - + Because the SSH agent protocol uses environment variables and unix-domain sockets, this probably doesn't work on Windows. It does work on most posix platforms though (Linux and MacOS X, for example). """ - + def __init__(self): """ Open a session with the local machine's SSH agent, if one is running. If no agent is running, initialization will succeed, but L{get_keys} will return an empty tuple. - + @raise SSHException: if an SSH agent is found, but speaks an incompatible protocol """ - self.conn = None - self.keys = () + AgentSSH.__init__(self) + if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: @@ -64,64 +333,22 @@ class Agent: except: # probably a dangling env var: the ssh agent is gone return - self.conn = conn elif sys.platform == 'win32': import win_pageant if win_pageant.can_talk_to_agent(): - self.conn = win_pageant.PageantConnection() + conn = win_pageant.PageantConnection() else: return else: # no agent support return - - ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES)) - if ptype != SSH2_AGENT_IDENTITIES_ANSWER: - raise SSHException('could not get keys from ssh-agent') - keys = [] - for i in range(result.get_int()): - keys.append(AgentKey(self, result.get_string())) - result.get_string() - self.keys = tuple(keys) + self._connect(conn) def close(self): """ Close the SSH agent connection. """ - if self.conn is not None: - self.conn.close() - self.conn = None - self.keys = () - - def get_keys(self): - """ - Return the list of keys available through the SSH agent, if any. If - no SSH agent was running (or it couldn't be contacted), an empty list - will be returned. - - @return: a list of keys available on the SSH agent - @rtype: tuple of L{AgentKey} - """ - return self.keys - - def _send_message(self, msg): - msg = str(msg) - self.conn.send(struct.pack('>I', len(msg)) + msg) - l = self._read_all(4) - msg = Message(self._read_all(struct.unpack('>I', l)[0])) - return ord(msg.get_byte()), msg - - def _read_all(self, wanted): - result = self.conn.recv(wanted) - while len(result) < wanted: - if len(result) == 0: - raise SSHException('lost ssh-agent') - extra = self.conn.recv(wanted - len(result)) - if len(extra) == 0: - raise SSHException('lost ssh-agent') - result += extra - return result - + self._close() class AgentKey(PKey): """ @@ -129,7 +356,7 @@ class AgentKey(PKey): authenticating to a remote server (signing). Most other key operations work as expected. """ - + def __init__(self, agent, blob): self.agent = agent self.blob = blob diff --git a/paramiko/channel.py b/paramiko/channel.py index 6d895fe..0c603c6 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -122,7 +122,8 @@ class Channel (object): out += '>' return out - def get_pty(self, term='vt100', width=80, height=24): + def get_pty(self, term='vt100', width=80, height=24, width_pixels=0, + height_pixels=0): """ Request a pseudo-terminal from the server. This is usually used right after creating a client channel, to ask the server to provide some @@ -136,6 +137,10 @@ class Channel (object): @type width: int @param height: height (in characters) of the terminal screen @type height: int + @param width_pixels: width (in pixels) of the terminal screen + @type width_pixels: int + @param height_pixels: height (in pixels) of the terminal screen + @type height_pixels: int @raise SSHException: if the request was rejected or the channel was closed @@ -150,8 +155,8 @@ class Channel (object): m.add_string(term) m.add_int(width) m.add_int(height) - # pixel height, width (usually useless) - m.add_int(0).add_int(0) + m.add_int(width_pixels) + m.add_int(height_pixels) m.add_string('') self._event_pending() self.transport._send_user_message(m) @@ -239,7 +244,7 @@ class Channel (object): self.transport._send_user_message(m) self._wait_for_event() - def resize_pty(self, width=80, height=24): + def resize_pty(self, width=80, height=24, width_pixels=0, height_pixels=0): """ Resize the pseudo-terminal. This can be used to change the width and height of the terminal emulation created in a previous L{get_pty} call. @@ -248,6 +253,10 @@ class Channel (object): @type width: int @param height: new height (in characters) of the terminal screen @type height: int + @param width_pixels: new width (in pixels) of the terminal screen + @type width_pixels: int + @param height_pixels: new height (in pixels) of the terminal screen + @type height_pixels: int @raise SSHException: if the request was rejected or the channel was closed @@ -258,13 +267,12 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_int(self.remote_chanid) m.add_string('window-change') - m.add_boolean(True) + m.add_boolean(False) m.add_int(width) m.add_int(height) - m.add_int(0).add_int(0) - self._event_pending() + m.add_int(width_pixels) + m.add_int(height_pixels) self.transport._send_user_message(m) - self._wait_for_event() def exit_status_ready(self): """ @@ -381,6 +389,31 @@ class Channel (object): self.transport._set_x11_handler(handler) return auth_cookie + def request_forward_agent(self, handler): + """ + Request for a forward SSH Agent on this channel. + This is only valid for an ssh-agent from openssh !!! + + @param handler: a required handler to use for incoming SSH Agent connections + @type handler: function + + @return: if we are ok or not (at that time we always return ok) + @rtype: boolean + + @raise: SSHException in case of channel problem. + """ + if self.closed or self.eof_received or self.eof_sent or not self.active: + raise SSHException('Channel is not open') + + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('auth-agent-req@openssh.com') + m.add_boolean(False) + self.transport._send_user_message(m) + self.transport._set_forward_agent_handler(handler) + return True + def get_transport(self): """ Return the L{Transport} associated with this channel. @@ -1026,6 +1059,11 @@ class Channel (object): else: ok = server.check_channel_x11_request(self, single_connection, auth_proto, auth_cookie, screen_number) + elif key == 'auth-agent-req@openssh.com': + if server is None: + ok = False + else: + ok = server.check_channel_forward_agent_request(self) else: self._log(DEBUG, 'Unhandled channel request "%s"' % key) ok = False diff --git a/paramiko/client.py b/paramiko/client.py index 4a65477..5b71958 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -28,16 +28,16 @@ import warnings from paramiko.agent import Agent from paramiko.common import * +from paramiko.config import SSH_PORT from paramiko.dsskey import DSSKey from paramiko.hostkeys import HostKeys from paramiko.resource import ResourceManager from paramiko.rsakey import RSAKey from paramiko.ssh_exception import SSHException, BadHostKeyException from paramiko.transport import Transport +from paramiko.util import retry_on_signal -SSH_PORT = 22 - class MissingHostKeyPolicy (object): """ Interface for defining the policy that L{SSHClient} should use when the @@ -82,7 +82,7 @@ class RejectPolicy (MissingHostKeyPolicy): def missing_host_key(self, client, hostname, key): client._log(DEBUG, 'Rejecting %s host key for %s: %s' % (key.get_name(), hostname, hexlify(key.get_fingerprint()))) - raise SSHException('Unknown server %s' % hostname) + raise SSHException('Server %r not found in known_hosts' % hostname) class WarningPolicy (MissingHostKeyPolicy): @@ -228,7 +228,7 @@ class SSHClient (object): def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=None, key_filename=None, timeout=None, allow_agent=True, look_for_keys=True, - compress=False): + compress=False, sock=None): """ Connect to an SSH server and authenticate to it. The server's host key is checked against the system host keys (see L{load_system_host_keys}) @@ -271,6 +271,9 @@ class SSHClient (object): @type look_for_keys: bool @param compress: set to True to turn on compression @type compress: bool + @param sock: an open socket or socket-like object (such as a + L{Channel}) to use for communication to the target host + @type sock: socket @raise BadHostKeyException: if the server's host key could not be verified @@ -279,21 +282,23 @@ class SSHClient (object): establishing an SSH session @raise socket.error: if a socket error occurred while connecting """ - for (family, socktype, proto, canonname, sockaddr) in socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM): - if socktype == socket.SOCK_STREAM: - af = family - addr = sockaddr - break - else: - # some OS like AIX don't indicate SOCK_STREAM support, so just guess. :( - af, _, _, _, addr = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM) - sock = socket.socket(af, socket.SOCK_STREAM) - if timeout is not None: - try: - sock.settimeout(timeout) - except: - pass - sock.connect(addr) + if not sock: + for (family, socktype, proto, canonname, sockaddr) in socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM): + if socktype == socket.SOCK_STREAM: + af = family + addr = sockaddr + break + else: + # some OS like AIX don't indicate SOCK_STREAM support, so just guess. :( + af, _, _, _, addr = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM) + sock = socket.socket(af, socket.SOCK_STREAM) + if timeout is not None: + try: + sock.settimeout(timeout) + except: + pass + retry_on_signal(lambda: sock.connect(addr)) + t = self._transport = Transport(sock) t.use_compression(compress=compress) if self._log_channel is not None: @@ -344,7 +349,7 @@ class SSHClient (object): self._agent.close() self._agent = None - def exec_command(self, command, bufsize=-1): + def exec_command(self, command, bufsize=-1, timeout=None, get_pty=False): """ Execute a command on the SSH server. A new L{Channel} is opened and the requested command is executed. The command's input and output @@ -355,19 +360,25 @@ class SSHClient (object): @type command: str @param bufsize: interpreted the same way as by the built-in C{file()} function in python @type bufsize: int + @param timeout: set command's channel timeout. See L{Channel.settimeout}.settimeout + @type timeout: int @return: the stdin, stdout, and stderr of the executing command @rtype: tuple(L{ChannelFile}, L{ChannelFile}, L{ChannelFile}) @raise SSHException: if the server fails to execute the command """ chan = self._transport.open_session() + if(get_pty): + chan.get_pty() + chan.settimeout(timeout) chan.exec_command(command) stdin = chan.makefile('wb', bufsize) stdout = chan.makefile('rb', bufsize) stderr = chan.makefile_stderr('rb', bufsize) return stdin, stdout, stderr - def invoke_shell(self, term='vt100', width=80, height=24): + def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=0, + height_pixels=0): """ Start an interactive shell session on the SSH server. A new L{Channel} is opened and connected to a pseudo-terminal using the requested @@ -379,13 +390,17 @@ class SSHClient (object): @type width: int @param height: the height (in characters) of the terminal window @type height: int + @param width_pixels: the width (in pixels) of the terminal window + @type width_pixels: int + @param height_pixels: the height (in pixels) of the terminal window + @type height_pixels: int @return: a new channel connected to the remote shell @rtype: L{Channel} @raise SSHException: if the server fails to invoke a shell """ chan = self._transport.open_session() - chan.get_pty(term, width, height) + chan.get_pty(term, width, height, width_pixels, height_pixels) chan.invoke_shell() return chan @@ -418,68 +433,86 @@ class SSHClient (object): - Any "id_rsa" or "id_dsa" key discoverable in ~/.ssh/ (if allowed). - Plain username/password auth, if a password was given. - (The password might be needed to unlock a private key.) + (The password might be needed to unlock a private key, or for + two-factor authentication [for which it is required].) """ saved_exception = None + two_factor = False + allowed_types = [] if pkey is not None: try: self._log(DEBUG, 'Trying SSH key %s' % hexlify(pkey.get_fingerprint())) - self._transport.auth_publickey(username, pkey) - return + allowed_types = self._transport.auth_publickey(username, pkey) + two_factor = (allowed_types == ['password']) + if not two_factor: + return except SSHException, e: saved_exception = e - for key_filename in key_filenames: - for pkey_class in (RSAKey, DSSKey): - try: - key = pkey_class.from_private_key_file(key_filename, password) - self._log(DEBUG, 'Trying key %s from %s' % (hexlify(key.get_fingerprint()), key_filename)) - self._transport.auth_publickey(username, key) - return - except SSHException, e: - saved_exception = e - - if allow_agent: + if not two_factor: + for key_filename in key_filenames: + for pkey_class in (RSAKey, DSSKey): + try: + key = pkey_class.from_private_key_file(key_filename, password) + self._log(DEBUG, 'Trying key %s from %s' % (hexlify(key.get_fingerprint()), key_filename)) + self._transport.auth_publickey(username, key) + two_factor = (allowed_types == ['password']) + if not two_factor: + return + break + except SSHException, e: + saved_exception = e + + if not two_factor and allow_agent: if self._agent == None: self._agent = Agent() for key in self._agent.get_keys(): try: self._log(DEBUG, 'Trying SSH agent key %s' % hexlify(key.get_fingerprint())) - self._transport.auth_publickey(username, key) - return + # for 2-factor auth a successfully auth'd key will result in ['password'] + allowed_types = self._transport.auth_publickey(username, key) + two_factor = (allowed_types == ['password']) + if not two_factor: + return + break except SSHException, e: saved_exception = e - keyfiles = [] - rsa_key = os.path.expanduser('~/.ssh/id_rsa') - dsa_key = os.path.expanduser('~/.ssh/id_dsa') - if os.path.isfile(rsa_key): - keyfiles.append((RSAKey, rsa_key)) - if os.path.isfile(dsa_key): - keyfiles.append((DSSKey, dsa_key)) - # look in ~/ssh/ for windows users: - rsa_key = os.path.expanduser('~/ssh/id_rsa') - dsa_key = os.path.expanduser('~/ssh/id_dsa') - if os.path.isfile(rsa_key): - keyfiles.append((RSAKey, rsa_key)) - if os.path.isfile(dsa_key): - keyfiles.append((DSSKey, dsa_key)) - - if not look_for_keys: + if not two_factor: keyfiles = [] - - for pkey_class, filename in keyfiles: - try: - key = pkey_class.from_private_key_file(filename, password) - self._log(DEBUG, 'Trying discovered key %s in %s' % (hexlify(key.get_fingerprint()), filename)) - self._transport.auth_publickey(username, key) - return - except SSHException, e: - saved_exception = e - except IOError, e: - saved_exception = e + rsa_key = os.path.expanduser('~/.ssh/id_rsa') + dsa_key = os.path.expanduser('~/.ssh/id_dsa') + if os.path.isfile(rsa_key): + keyfiles.append((RSAKey, rsa_key)) + if os.path.isfile(dsa_key): + keyfiles.append((DSSKey, dsa_key)) + # look in ~/ssh/ for windows users: + rsa_key = os.path.expanduser('~/ssh/id_rsa') + dsa_key = os.path.expanduser('~/ssh/id_dsa') + if os.path.isfile(rsa_key): + keyfiles.append((RSAKey, rsa_key)) + if os.path.isfile(dsa_key): + keyfiles.append((DSSKey, dsa_key)) + + if not look_for_keys: + keyfiles = [] + + for pkey_class, filename in keyfiles: + try: + key = pkey_class.from_private_key_file(filename, password) + self._log(DEBUG, 'Trying discovered key %s in %s' % (hexlify(key.get_fingerprint()), filename)) + # for 2-factor auth a successfully auth'd key will result in ['password'] + allowed_types = self._transport.auth_publickey(username, key) + two_factor = (allowed_types == ['password']) + if not two_factor: + return + break + except SSHException, e: + saved_exception = e + except IOError, e: + saved_exception = e if password is not None: try: @@ -487,6 +520,8 @@ class SSHClient (object): return except SSHException, e: saved_exception = e + elif two_factor: + raise SSHException('Two-factor authentication requires a password') # if we got an auth-failed exception earlier, re-raise it if saved_exception is not None: diff --git a/paramiko/common.py b/paramiko/common.py index 3323f0a..25d5457 100644 --- a/paramiko/common.py +++ b/paramiko/common.py @@ -124,3 +124,6 @@ INFO = logging.INFO WARNING = logging.WARNING ERROR = logging.ERROR CRITICAL = logging.CRITICAL + +# Common IO/select/etc sleep period, in seconds +io_sleep = 0.01 diff --git a/paramiko/config.py b/paramiko/config.py index 2a2cbff..e41bae4 100644 --- a/paramiko/config.py +++ b/paramiko/config.py @@ -1,4 +1,5 @@ # Copyright (C) 2006-2007 Robey Pointer <robeypointer@gmail.com> +# Copyright (C) 2012 Olle Lundberg <geek@nerd.sh> # # This file is part of paramiko. # @@ -21,6 +22,57 @@ L{SSHConfig}. """ import fnmatch +import os +import re +import socket + +SSH_PORT = 22 +proxy_re = re.compile(r"^(proxycommand)\s*=*\s*(.*)", re.I) + + +class LazyFqdn(object): + """ + Returns the host's fqdn on request as string. + """ + + def __init__(self, config): + self.fqdn = None + self.config = config + + def __str__(self): + if self.fqdn is None: + # + # If the SSH config contains AddressFamily, use that when + # determining the local host's FQDN. Using socket.getfqdn() from + # the standard library is the most general solution, but can + # result in noticeable delays on some platforms when IPv6 is + # misconfigured or not available, as it calls getaddrinfo with no + # address family specified, so both IPv4 and IPv6 are checked. + # + + # Handle specific option + fqdn = None + address_family = self.config.get('addressfamily', 'any').lower() + if address_family != 'any': + family = socket.AF_INET if address_family == 'inet' \ + else socket.AF_INET6 + results = socket.getaddrinfo(host, + None, + family, + socket.SOCK_DGRAM, + socket.IPPROTO_IP, + socket.AI_CANONNAME) + for res in results: + af, socktype, proto, canonname, sa = res + if canonname and '.' in canonname: + fqdn = canonname + break + # Handle 'any' / unspecified + if fqdn is None: + fqdn = socket.getfqdn() + # Cache + self.fqdn = fqdn + return self.fqdn class SSHConfig (object): @@ -38,7 +90,7 @@ class SSHConfig (object): """ Create a new OpenSSH config object. """ - self._config = [ { 'host': '*' } ] + self._config = [] def parse(self, file_obj): """ @@ -47,14 +99,19 @@ class SSHConfig (object): @param file_obj: a file-like object to read the config file from @type file_obj: file """ - configs = [self._config[0]] + host = {"host": ['*'], "config": {}} for line in file_obj: line = line.rstrip('\n').lstrip() if (line == '') or (line[0] == '#'): continue if '=' in line: - key, value = line.split('=', 1) - key = key.strip().lower() + # Ensure ProxyCommand gets properly split + if line.lower().strip().startswith('proxycommand'): + match = proxy_re.match(line) + key, value = match.group(1).lower(), match.group(2) + else: + key, value = line.split('=', 1) + key = key.strip().lower() else: # find first whitespace, and split there i = 0 @@ -66,20 +123,20 @@ class SSHConfig (object): value = line[i:].lstrip() if key == 'host': - del configs[:] - # the value may be multiple hosts, space-delimited - for host in value.split(): - # do we have a pre-existing host config to append to? - matches = [c for c in self._config if c['host'] == host] - if len(matches) > 0: - configs.append(matches[0]) - else: - config = { 'host': host } - self._config.append(config) - configs.append(config) - else: - for config in configs: - config[key] = value + self._config.append(host) + value = value.split() + host = {key: value, 'config': {}} + #identityfile is a special case, since it is allowed to be + # specified multiple times and they should be tried in order + # of specification. + elif key == 'identityfile': + if key in host['config']: + host['config']['identityfile'].append(value) + else: + host['config']['identityfile'] = [value] + elif key not in host['config']: + host['config'].update({key: value}) + self._config.append(host) def lookup(self, hostname): """ @@ -94,17 +151,106 @@ class SSHConfig (object): will win out. The keys in the returned dict are all normalized to lowercase (look for - C{"port"}, not C{"Port"}. No other processing is done to the keys or - values. + C{"port"}, not C{"Port"}. The values are processed according to the + rules for substitution variable expansion in C{ssh_config}. @param hostname: the hostname to lookup @type hostname: str """ - matches = [x for x in self._config if fnmatch.fnmatch(hostname, x['host'])] - # sort in order of shortest match (usually '*') to longest - matches.sort(lambda x,y: cmp(len(x['host']), len(y['host']))) + + matches = [config for config in self._config if + self._allowed(hostname, config['host'])] + ret = {} - for m in matches: - ret.update(m) - del ret['host'] + for match in matches: + for key, value in match['config'].iteritems(): + if key not in ret: + # Create a copy of the original value, + # else it will reference the original list + # in self._config and update that value too + # when the extend() is being called. + ret[key] = value[:] + elif key == 'identityfile': + ret[key].extend(value) + ret = self._expand_variables(ret, hostname) return ret + + def _allowed(self, hostname, hosts): + match = False + for host in hosts: + if host.startswith('!') and fnmatch.fnmatch(hostname, host[1:]): + return False + elif fnmatch.fnmatch(hostname, host): + match = True + return match + + def _expand_variables(self, config, hostname): + """ + Return a dict of config options with expanded substitutions + for a given hostname. + + Please refer to man C{ssh_config} for the parameters that + are replaced. + + @param config: the config for the hostname + @type hostname: dict + @param hostname: the hostname that the config belongs to + @type hostname: str + """ + + if 'hostname' in config: + config['hostname'] = config['hostname'].replace('%h', hostname) + else: + config['hostname'] = hostname + + if 'port' in config: + port = config['port'] + else: + port = SSH_PORT + + user = os.getenv('USER') + if 'user' in config: + remoteuser = config['user'] + else: + remoteuser = user + + host = socket.gethostname().split('.')[0] + fqdn = LazyFqdn(config) + homedir = os.path.expanduser('~') + replacements = {'controlpath': + [ + ('%h', config['hostname']), + ('%l', fqdn), + ('%L', host), + ('%n', hostname), + ('%p', port), + ('%r', remoteuser), + ('%u', user) + ], + 'identityfile': + [ + ('~', homedir), + ('%d', homedir), + ('%h', config['hostname']), + ('%l', fqdn), + ('%u', user), + ('%r', remoteuser) + ], + 'proxycommand': + [ + ('%h', config['hostname']), + ('%p', port), + ('%r', remoteuser) + ] + } + + for k in config: + if k in replacements: + for find, replace in replacements[k]: + if isinstance(config[k], list): + for item in range(len(config[k])): + config[k][item] = config[k][item].\ + replace(find, str(replace)) + else: + config[k] = config[k].replace(find, str(replace)) + return config diff --git a/paramiko/file.py b/paramiko/file.py index d4aec8e..7e2904e 100644 --- a/paramiko/file.py +++ b/paramiko/file.py @@ -354,6 +354,10 @@ class BufferedFile (object): """ return self + @property + def closed(self): + return self._closed + ### overrides... diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py index 70ccf43..e739312 100644 --- a/paramiko/hostkeys.py +++ b/paramiko/hostkeys.py @@ -21,6 +21,7 @@ L{HostKeys} """ import base64 +import binascii from Crypto.Hash import SHA, HMAC import UserDict @@ -29,6 +30,14 @@ from paramiko.dsskey import DSSKey from paramiko.rsakey import RSAKey +class InvalidHostKey(Exception): + + def __init__(self, line, exc): + self.line = line + self.exc = exc + self.args = (line, exc) + + class HostKeyEntry: """ Representation of a line in an OpenSSH-style "known hosts" file. @@ -63,12 +72,15 @@ class HostKeyEntry: # Decide what kind of key we're looking at and create an object # to hold it accordingly. - if keytype == 'ssh-rsa': - key = RSAKey(data=base64.decodestring(key)) - elif keytype == 'ssh-dss': - key = DSSKey(data=base64.decodestring(key)) - else: - return None + try: + if keytype == 'ssh-rsa': + key = RSAKey(data=base64.decodestring(key)) + elif keytype == 'ssh-dss': + key = DSSKey(data=base64.decodestring(key)) + else: + return None + except binascii.Error, e: + raise InvalidHostKey(line, e) return cls(names, key) from_line = classmethod(from_line) diff --git a/paramiko/message.py b/paramiko/message.py index 366c43c..47acc34 100644 --- a/paramiko/message.py +++ b/paramiko/message.py @@ -110,7 +110,8 @@ class Message (object): @rtype: string """ b = self.packet.read(n) - if len(b) < n: + max_pad_size = 1<<20 # Limit padding to 1 MB + if len(b) < n and n < max_pad_size: return b + '\x00' * (n - len(b)) return b diff --git a/paramiko/packet.py b/paramiko/packet.py index 391c5d5..38a6d4b 100644 --- a/paramiko/packet.py +++ b/paramiko/packet.py @@ -29,7 +29,7 @@ import time from paramiko.common import * from paramiko import util -from paramiko.ssh_exception import SSHException +from paramiko.ssh_exception import SSHException, ProxyCommandFailure from paramiko.message import Message @@ -57,8 +57,11 @@ class Packetizer (object): # READ the secsh RFC's before raising these values. if anything, # they should probably be lower. - REKEY_PACKETS = pow(2, 30) - REKEY_BYTES = pow(2, 30) + REKEY_PACKETS = pow(2, 29) + REKEY_BYTES = pow(2, 29) + + REKEY_PACKETS_OVERFLOW_MAX = pow(2,29) # Allow receiving this many packets after a re-key request before terminating + REKEY_BYTES_OVERFLOW_MAX = pow(2,29) # Allow receiving this many bytes after a re-key request before terminating def __init__(self, socket): self.__socket = socket @@ -74,6 +77,7 @@ class Packetizer (object): self.__sent_packets = 0 self.__received_bytes = 0 self.__received_packets = 0 + self.__received_bytes_overflow = 0 self.__received_packets_overflow = 0 # current inbound/outbound ciphering: @@ -83,6 +87,7 @@ class Packetizer (object): self.__mac_size_in = 0 self.__block_engine_out = None self.__block_engine_in = None + self.__sdctr_out = False self.__mac_engine_out = None self.__mac_engine_in = None self.__mac_key_out = '' @@ -106,11 +111,12 @@ class Packetizer (object): """ self.__logger = log - def set_outbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key): + def set_outbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key, sdctr=False): """ Switch outbound data cipher. """ self.__block_engine_out = block_engine + self.__sdctr_out = sdctr self.__block_size_out = block_size self.__mac_engine_out = mac_engine self.__mac_size_out = mac_size @@ -134,6 +140,7 @@ class Packetizer (object): self.__mac_key_in = mac_key self.__received_bytes = 0 self.__received_packets = 0 + self.__received_bytes_overflow = 0 self.__received_packets_overflow = 0 # wait until the reset happens in both directions before clearing rekey flag self.__init_count |= 2 @@ -236,23 +243,25 @@ class Packetizer (object): def write_all(self, out): self.__keepalive_last = time.time() while len(out) > 0: - got_timeout = False + retry_write = False try: n = self.__socket.send(out) except socket.timeout: - got_timeout = True + retry_write = True except socket.error, e: if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN): - got_timeout = True + retry_write = True elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR): # syscall interrupted; try again - pass + retry_write = True else: n = -1 + except ProxyCommandFailure: + raise # so it doesn't get swallowed by the below catchall except Exception: # could be: (32, 'Broken pipe') n = -1 - if got_timeout: + if retry_write: n = 0 if self.__closed: n = -1 @@ -316,6 +325,7 @@ class Packetizer (object): # only ask once for rekeying self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' % (self.__sent_packets, self.__sent_bytes)) + self.__received_bytes_overflow = 0 self.__received_packets_overflow = 0 self._trigger_rekey() finally: @@ -368,19 +378,23 @@ class Packetizer (object): self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL # check for rekey - self.__received_bytes += packet_size + self.__mac_size_in + 4 + raw_packet_size = packet_size + self.__mac_size_in + 4 + self.__received_bytes += raw_packet_size self.__received_packets += 1 if self.__need_rekey: - # we've asked to rekey -- give them 20 packets to comply before + # we've asked to rekey -- give them some packets to comply before # dropping the connection + self.__received_bytes_overflow += raw_packet_size self.__received_packets_overflow += 1 - if self.__received_packets_overflow >= 20: + if (self.__received_packets_overflow >= self.REKEY_PACKETS_OVERFLOW_MAX) or \ + (self.__received_bytes_overflow >= self.REKEY_BYTES_OVERFLOW_MAX): raise SSHException('Remote transport is ignoring rekey requests') elif (self.__received_packets >= self.REKEY_PACKETS) or \ (self.__received_bytes >= self.REKEY_BYTES): # only ask once for rekeying self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' % (self.__received_packets, self.__received_bytes)) + self.__received_bytes_overflow = 0 self.__received_packets_overflow = 0 self._trigger_rekey() @@ -459,6 +473,12 @@ class Packetizer (object): break except socket.timeout: pass + except EnvironmentError, e: + if ((type(e.args) is tuple) and (len(e.args) > 0) and + (e.args[0] == errno.EINTR)): + pass + else: + raise if self.__closed: raise EOFError() now = time.time() @@ -472,12 +492,12 @@ class Packetizer (object): padding = 3 + bsize - ((len(payload) + 8) % bsize) packet = struct.pack('>IB', len(payload) + padding + 1, padding) packet += payload - if self.__block_engine_out is not None: - packet += rng.read(padding) - else: - # cute trick i caught openssh doing: if we're not encrypting, + if self.__sdctr_out or self.__block_engine_out is None: + # cute trick i caught openssh doing: if we're not encrypting or SDCTR mode (RFC4344), # don't waste random bytes for the padding packet += (chr(0) * padding) + else: + packet += rng.read(padding) return packet def _trigger_rekey(self): diff --git a/paramiko/proxy.py b/paramiko/proxy.py new file mode 100644 index 0000000..218b76e --- /dev/null +++ b/paramiko/proxy.py @@ -0,0 +1,91 @@ +# Copyright (C) 2012 Yipit, Inc <coders@yipit.com> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{ProxyCommand}. +""" + +import os +from shlex import split as shlsplit +import signal +from subprocess import Popen, PIPE + +from paramiko.ssh_exception import ProxyCommandFailure + + +class ProxyCommand(object): + """ + Wraps a subprocess running ProxyCommand-driven programs. + + This class implements a the socket-like interface needed by the + L{Transport} and L{Packetizer} classes. Using this class instead of a + regular socket makes it possible to talk with a Popen'd command that will + proxy traffic between the client and a server hosted in another machine. + """ + def __init__(self, command_line): + """ + Create a new CommandProxy instance. The instance created by this + class can be passed as an argument to the L{Transport} class. + + @param command_line: the command that should be executed and + used as the proxy. + @type command_line: str + """ + self.cmd = shlsplit(command_line) + self.process = Popen(self.cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE) + + def send(self, content): + """ + Write the content received from the SSH client to the standard + input of the forked command. + + @param content: string to be sent to the forked command + @type content: str + """ + try: + self.process.stdin.write(content) + except IOError, e: + # There was a problem with the child process. It probably + # died and we can't proceed. The best option here is to + # raise an exception informing the user that the informed + # ProxyCommand is not working. + raise BadProxyCommand(' '.join(self.cmd), e.strerror) + return len(content) + + def recv(self, size): + """ + Read from the standard output of the forked program. + + @param size: how many chars should be read + @type size: int + + @return: the length of the read content + @rtype: int + """ + try: + return os.read(self.process.stdout.fileno(), size) + except IOError, e: + raise BadProxyCommand(' '.join(self.cmd), e.strerror) + + def close(self): + os.kill(self.process.pid, signal.SIGTERM) + + def settimeout(self, timeout): + # Timeouts are meaningless for this implementation, but are part of the + # spec, so must be present. + pass diff --git a/paramiko/server.py b/paramiko/server.py index 6424b63..dac9bf1 100644 --- a/paramiko/server.py +++ b/paramiko/server.py @@ -93,6 +93,7 @@ class ServerInterface (object): - L{check_channel_subsystem_request} - L{check_channel_window_change_request} - L{check_channel_x11_request} + - L{check_channel_forward_agent_request} The C{chanid} parameter is a small number that uniquely identifies the channel within a L{Transport}. A L{Channel} object is not created @@ -492,7 +493,22 @@ class ServerInterface (object): @rtype: bool """ return False - + + def check_channel_forward_agent_request(self, channel): + """ + Determine if the client will be provided with an forward agent session. + If this method returns C{True}, the server will allow SSH Agent + forwarding. + + The default implementation always returns C{False}. + + @param channel: the L{Channel} the request arrived on + @type channel: L{Channel} + @return: C{True} if the AgentForward was loaded; C{False} if not + @rtype: bool + """ + return False + def check_channel_direct_tcpip_request(self, chanid, origin, destination): """ Determine if a local port forwarding channel will be granted, and diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index 79a7761..17ea493 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -198,7 +198,7 @@ class SFTPClient (BaseSFTP): Open a file on the remote server. The arguments are the same as for python's built-in C{file} (aka C{open}). A file-like object is returned, which closely mimics the behavior of a normal python file - object. + object, including the ability to be used as a context manager. The mode indicates how the file is to be opened: C{'r'} for reading, C{'w'} for writing (truncating an existing file), C{'a'} for appending, @@ -533,6 +533,56 @@ class SFTPClient (BaseSFTP): """ return self._cwd + def putfo(self, fl, remotepath, file_size=0, callback=None, confirm=True): + """ + Copy the contents of an open file object (C{fl}) to the SFTP server as + C{remotepath}. Any exception raised by operations will be passed through. + + The SFTP operations use pipelining for speed. + + @param fl: opened file or file-like object to copy + @type localpath: object + @param remotepath: the destination path on the SFTP server + @type remotepath: str + @param file_size: optional size parameter passed to callback. If none is + specified, size defaults to 0 + @type file_size: int + @param callback: optional callback function that accepts the bytes + transferred so far and the total bytes to be transferred + (since 1.7.4) + @type callback: function(int, int) + @param confirm: whether to do a stat() on the file afterwards to + confirm the file size (since 1.7.7) + @type confirm: bool + + @return: an object containing attributes about the given file + (since 1.7.4) + @rtype: SFTPAttributes + + @since: 1.4 + """ + fr = self.file(remotepath, 'wb') + fr.set_pipelined(True) + size = 0 + try: + while True: + data = fl.read(32768) + fr.write(data) + size += len(data) + if callback is not None: + callback(size, file_size) + if len(data) == 0: + break + finally: + fr.close() + if confirm: + s = self.stat(remotepath) + if s.st_size != size: + raise IOError('size mismatch in put! %d != %d' % (s.st_size, size)) + else: + s = SFTPAttributes() + return s + def put(self, localpath, remotepath, callback=None, confirm=True): """ Copy a local file (C{localpath}) to the SFTP server as C{remotepath}. @@ -562,29 +612,46 @@ class SFTPClient (BaseSFTP): file_size = os.stat(localpath).st_size fl = file(localpath, 'rb') try: - fr = self.file(remotepath, 'wb') - fr.set_pipelined(True) - size = 0 - try: - while True: - data = fl.read(32768) - if len(data) == 0: - break - fr.write(data) - size += len(data) - if callback is not None: - callback(size, file_size) - finally: - fr.close() + return self.putfo(fl, remotepath, os.stat(localpath).st_size, callback, confirm) finally: fl.close() - if confirm: - s = self.stat(remotepath) - if s.st_size != size: - raise IOError('size mismatch in put! %d != %d' % (s.st_size, size)) - else: - s = SFTPAttributes() - return s + + def getfo(self, remotepath, fl, callback=None): + """ + Copy a remote file (C{remotepath}) from the SFTP server and write to + an open file or file-like object, C{fl}. Any exception raised by + operations will be passed through. This method is primarily provided + as a convenience. + + @param remotepath: opened file or file-like object to copy to + @type remotepath: object + @param fl: the destination path on the local host or open file + object + @type localpath: str + @param callback: optional callback function that accepts the bytes + transferred so far and the total bytes to be transferred + (since 1.7.4) + @type callback: function(int, int) + @return: the number of bytes written to the opened file object + + @since: 1.4 + """ + fr = self.file(remotepath, 'rb') + file_size = self.stat(remotepath).st_size + fr.prefetch() + try: + size = 0 + while True: + data = fr.read(32768) + fl.write(data) + size += len(data) + if callback is not None: + callback(size, file_size) + if len(data) == 0: + break + finally: + fr.close() + return size def get(self, remotepath, localpath, callback=None): """ @@ -603,25 +670,12 @@ class SFTPClient (BaseSFTP): @since: 1.4 """ - fr = self.file(remotepath, 'rb') file_size = self.stat(remotepath).st_size - fr.prefetch() + fl = file(localpath, 'wb') try: - fl = file(localpath, 'wb') - try: - size = 0 - while True: - data = fr.read(32768) - if len(data) == 0: - break - fl.write(data) - size += len(data) - if callback is not None: - callback(size, file_size) - finally: - fl.close() + size = self.getfo(remotepath, fl, callback) finally: - fr.close() + fl.close() s = os.stat(localpath) if s.st_size != size: raise IOError('size mismatch in get! %d != %d' % (s.st_size, size)) @@ -641,13 +695,13 @@ class SFTPClient (BaseSFTP): msg = Message() msg.add_int(self.request_number) for item in arg: - if type(item) is int: + if isinstance(item, int): msg.add_int(item) - elif type(item) is long: + elif isinstance(item, long): msg.add_int64(item) - elif type(item) is str: + elif isinstance(item, str): msg.add_string(item) - elif type(item) is SFTPAttributes: + elif isinstance(item, SFTPAttributes): item._pack(msg) else: raise Exception('unknown type for %r type %r' % (item, type(item))) diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py index 8c5c7ac..e056d70 100644 --- a/paramiko/sftp_file.py +++ b/paramiko/sftp_file.py @@ -21,6 +21,7 @@ L{SFTPFile} """ from binascii import hexlify +from collections import deque import socket import threading import time @@ -34,6 +35,9 @@ from paramiko.sftp_attr import SFTPAttributes class SFTPFile (BufferedFile): """ Proxy object for a file on the remote server, in client mode SFTP. + + Instances of this class may be used as context managers in the same way + that built-in Python file objects are. """ # Some sftp servers will choke if you send read/write requests larger than @@ -51,6 +55,7 @@ class SFTPFile (BufferedFile): self._prefetch_data = {} self._prefetch_reads = [] self._saved_exception = None + self._reqs = deque() def __del__(self): self._close(async=True) @@ -160,12 +165,14 @@ class SFTPFile (BufferedFile): def _write(self, data): # may write less than requested if it would exceed max packet size chunk = min(len(data), self.MAX_REQUEST_SIZE) - req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk])) - if not self.pipelined or self.sftp.sock.recv_ready(): - t, msg = self.sftp._read_response(req) - if t != CMD_STATUS: - raise SFTPError('Expected status') - # convert_status already called + self._reqs.append(self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk]))) + if not self.pipelined or (len(self._reqs) > 100 and self.sftp.sock.recv_ready()): + while len(self._reqs): + req = self._reqs.popleft() + t, msg = self.sftp._read_response(req) + if t != CMD_STATUS: + raise SFTPError('Expected status') + # convert_status already called return chunk def settimeout(self, timeout): @@ -474,3 +481,9 @@ class SFTPFile (BufferedFile): x = self._saved_exception self._saved_exception = None raise x + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py index 68924d0..f2406dc 100644 --- a/paramiko/ssh_exception.py +++ b/paramiko/ssh_exception.py @@ -113,3 +113,20 @@ class BadHostKeyException (SSHException): self.key = got_key self.expected_key = expected_key + +class ProxyCommandFailure (SSHException): + """ + The "ProxyCommand" found in the .ssh/config file returned an error. + + @ivar command: The command line that is generating this exception. + @type command: str + @ivar error: The error captured from the proxy command output. + @type error: str + """ + def __init__(self, command, error): + SSHException.__init__(self, + '"ProxyCommand (%s)" returned non-zero exit status: %s' % ( + command, error + ) + ) + self.error = error diff --git a/paramiko/transport.py b/paramiko/transport.py index 30de295..fd6dab7 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -29,6 +29,7 @@ import threading import time import weakref +import paramiko from paramiko import util from paramiko.auth_handler import AuthHandler from paramiko.channel import Channel @@ -43,7 +44,9 @@ from paramiko.primes import ModulusPack from paramiko.rsakey import RSAKey from paramiko.server import ServerInterface from paramiko.sftp_client import SFTPClient -from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException +from paramiko.ssh_exception import (SSHException, BadAuthenticationType, + ChannelException, ProxyCommandFailure) +from paramiko.util import retry_on_signal from Crypto import Random from Crypto.Cipher import Blowfish, AES, DES3, ARC4 @@ -194,7 +197,7 @@ class Transport (threading.Thread): """ _PROTO_ID = '2.0' - _CLIENT_ID = 'paramiko_1.7.7.1' + _CLIENT_ID = 'paramiko_%s' % (paramiko.__version__) _preferred_ciphers = ( 'aes128-ctr', 'aes256-ctr', 'aes128-cbc', 'blowfish-cbc', 'aes256-cbc', '3des-cbc', 'arcfour128', 'arcfour256' ) @@ -288,7 +291,7 @@ class Transport (threading.Thread): addr = sockaddr sock = socket.socket(af, socket.SOCK_STREAM) try: - sock.connect((hostname, port)) + retry_on_signal(lambda: sock.connect((hostname, port))) except socket.error, e: reason = str(e) else: @@ -341,6 +344,7 @@ class Transport (threading.Thread): self._channel_counter = 1 self.window_size = 65536 self.max_packet_size = 34816 + self._forward_agent_handler = None self._x11_handler = None self._tcp_handler = None @@ -673,6 +677,20 @@ class Transport (threading.Thread): """ return self.open_channel('x11', src_addr=src_addr) + def open_forward_agent_channel(self): + """ + Request a new channel to the client, of type + C{"auth-agent@openssh.com"}. + + This is just an alias for C{open_channel('auth-agent@openssh.com')}. + @return: a new L{Channel} + @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely + """ + return self.open_channel('auth-agent@openssh.com') + def open_forwarded_tcpip_channel(self, (src_addr, src_port), (dest_addr, dest_port)): """ Request a new channel back to the client, of type C{"forwarded-tcpip"}. @@ -1481,6 +1499,14 @@ class Transport (threading.Thread): else: return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv) + def _set_forward_agent_handler(self, handler): + if handler is None: + def default_handler(channel): + self._queue_incoming_channel(channel) + self._forward_agent_handler = default_handler + else: + self._forward_agent_handler = handler + def _set_x11_handler(self, handler): # only called if a channel has turned on x11 forwarding if handler is None: @@ -1505,6 +1531,18 @@ class Transport (threading.Thread): # indefinitely, creating a GC cycle and not letting Transport ever be # GC'd. it's a bug in Thread.) + # Hold reference to 'sys' so we can test sys.modules to detect + # interpreter shutdown. + self.sys = sys + + # Required to prevent RNG errors when running inside many subprocess + # containers. + Random.atfork() + + # Hold reference to 'sys' so we can test sys.modules to detect + # interpreter shutdown. + self.sys = sys + # active=True occurs before the thread is launched, to avoid a race _active_threads.append(self) if self.server_mode: @@ -1512,94 +1550,102 @@ class Transport (threading.Thread): else: self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL)) try: - self.packetizer.write_all(self.local_version + '\r\n') - self._check_banner() - self._send_kex_init() - self._expect_packet(MSG_KEXINIT) - - while self.active: - if self.packetizer.need_rekey() and not self.in_kex: - self._send_kex_init() - try: - ptype, m = self.packetizer.read_message() - except NeedRekeyException: - continue - if ptype == MSG_IGNORE: - continue - elif ptype == MSG_DISCONNECT: - self._parse_disconnect(m) - self.active = False - self.packetizer.close() - break - elif ptype == MSG_DEBUG: - self._parse_debug(m) - continue - if len(self._expected_packet) > 0: - if ptype not in self._expected_packet: - raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype)) - self._expected_packet = tuple() - if (ptype >= 30) and (ptype <= 39): - self.kex_engine.parse_next(ptype, m) + try: + self.packetizer.write_all(self.local_version + '\r\n') + self._check_banner() + self._send_kex_init() + self._expect_packet(MSG_KEXINIT) + + while self.active: + if self.packetizer.need_rekey() and not self.in_kex: + self._send_kex_init() + try: + ptype, m = self.packetizer.read_message() + except NeedRekeyException: continue - - if ptype in self._handler_table: - self._handler_table[ptype](self, m) - elif ptype in self._channel_handler_table: - chanid = m.get_int() - chan = self._channels.get(chanid) - if chan is not None: - self._channel_handler_table[ptype](chan, m) - elif chanid in self.channels_seen: - self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid) - else: - self._log(ERROR, 'Channel request for unknown channel %d' % chanid) + if ptype == MSG_IGNORE: + continue + elif ptype == MSG_DISCONNECT: + self._parse_disconnect(m) self.active = False self.packetizer.close() - elif (self.auth_handler is not None) and (ptype in self.auth_handler._handler_table): - self.auth_handler._handler_table[ptype](self.auth_handler, m) + break + elif ptype == MSG_DEBUG: + self._parse_debug(m) + continue + if len(self._expected_packet) > 0: + if ptype not in self._expected_packet: + raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype)) + self._expected_packet = tuple() + if (ptype >= 30) and (ptype <= 39): + self.kex_engine.parse_next(ptype, m) + continue + + if ptype in self._handler_table: + self._handler_table[ptype](self, m) + elif ptype in self._channel_handler_table: + chanid = m.get_int() + chan = self._channels.get(chanid) + if chan is not None: + self._channel_handler_table[ptype](chan, m) + elif chanid in self.channels_seen: + self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid) + else: + self._log(ERROR, 'Channel request for unknown channel %d' % chanid) + self.active = False + self.packetizer.close() + elif (self.auth_handler is not None) and (ptype in self.auth_handler._handler_table): + self.auth_handler._handler_table[ptype](self.auth_handler, m) + else: + self._log(WARNING, 'Oops, unhandled type %d' % ptype) + msg = Message() + msg.add_byte(chr(MSG_UNIMPLEMENTED)) + msg.add_int(m.seqno) + self._send_message(msg) + except SSHException, e: + self._log(ERROR, 'Exception: ' + str(e)) + self._log(ERROR, util.tb_strings()) + self.saved_exception = e + except EOFError, e: + self._log(DEBUG, 'EOF in transport thread') + #self._log(DEBUG, util.tb_strings()) + self.saved_exception = e + except socket.error, e: + if type(e.args) is tuple: + emsg = '%s (%d)' % (e.args[1], e.args[0]) else: - self._log(WARNING, 'Oops, unhandled type %d' % ptype) - msg = Message() - msg.add_byte(chr(MSG_UNIMPLEMENTED)) - msg.add_int(m.seqno) - self._send_message(msg) - except SSHException, e: - self._log(ERROR, 'Exception: ' + str(e)) - self._log(ERROR, util.tb_strings()) - self.saved_exception = e - except EOFError, e: - self._log(DEBUG, 'EOF in transport thread') - #self._log(DEBUG, util.tb_strings()) - self.saved_exception = e - except socket.error, e: - if type(e.args) is tuple: - emsg = '%s (%d)' % (e.args[1], e.args[0]) - else: - emsg = e.args - self._log(ERROR, 'Socket exception: ' + emsg) - self.saved_exception = e - except Exception, e: - self._log(ERROR, 'Unknown exception: ' + str(e)) - self._log(ERROR, util.tb_strings()) - self.saved_exception = e - _active_threads.remove(self) - for chan in self._channels.values(): - chan._unlink() - if self.active: - self.active = False - self.packetizer.close() - if self.completion_event != None: - self.completion_event.set() - if self.auth_handler is not None: - self.auth_handler.abort() - for event in self.channel_events.values(): - event.set() - try: - self.lock.acquire() - self.server_accept_cv.notify() - finally: - self.lock.release() - self.sock.close() + emsg = e.args + self._log(ERROR, 'Socket exception: ' + emsg) + self.saved_exception = e + except Exception, e: + self._log(ERROR, 'Unknown exception: ' + str(e)) + self._log(ERROR, util.tb_strings()) + self.saved_exception = e + _active_threads.remove(self) + for chan in self._channels.values(): + chan._unlink() + if self.active: + self.active = False + self.packetizer.close() + if self.completion_event != None: + self.completion_event.set() + if self.auth_handler is not None: + self.auth_handler.abort() + for event in self.channel_events.values(): + event.set() + try: + self.lock.acquire() + self.server_accept_cv.notify() + finally: + self.lock.release() + self.sock.close() + except: + # Don't raise spurious 'NoneType has no attribute X' errors when we + # wake up during interpreter shutdown. Or rather -- raise + # everything *if* sys.modules (used as a convenient sentinel) + # appears to still exist. + if self.sys.modules is not None: + raise ### protocol stages @@ -1629,6 +1675,8 @@ class Transport (threading.Thread): timeout = 2 try: buf = self.packetizer.readline(timeout) + except ProxyCommandFailure: + raise except Exception, x: raise SSHException('Error reading SSH protocol banner' + str(x)) if buf[:4] == 'SSH-': @@ -1837,7 +1885,8 @@ class Transport (threading.Thread): mac_key = self._compute_key('F', mac_engine.digest_size) else: mac_key = self._compute_key('E', mac_engine.digest_size) - self.packetizer.set_outbound_cipher(engine, block_size, mac_engine, mac_size, mac_key) + sdctr = self.local_cipher.endswith('-ctr') + self.packetizer.set_outbound_cipher(engine, block_size, mac_engine, mac_size, mac_key, sdctr) compress_out = self._compression_info[self.local_compression][0] if (compress_out is not None) and ((self.local_compression != 'zlib@openssh.com') or self.authenticated): self._log(DEBUG, 'Switching on outbound compression ...') @@ -1980,7 +2029,14 @@ class Transport (threading.Thread): initial_window_size = m.get_int() max_packet_size = m.get_int() reject = False - if (kind == 'x11') and (self._x11_handler is not None): + if (kind == 'auth-agent@openssh.com') and (self._forward_agent_handler is not None): + self._log(DEBUG, 'Incoming forward agent connection') + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + elif (kind == 'x11') and (self._x11_handler is not None): origin_addr = m.get_string() origin_port = m.get_int() self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port)) @@ -2052,7 +2108,9 @@ class Transport (threading.Thread): m.add_int(self.max_packet_size) self._send_message(m) self._log(INFO, 'Secsh channel %d (%s) opened.', my_chanid, kind) - if kind == 'x11': + if kind == 'auth-agent@openssh.com': + self._forward_agent_handler(chan) + elif kind == 'x11': self._x11_handler(chan, (origin_addr, origin_port)) elif kind == 'forwarded-tcpip': chan.origin_addr = (origin_addr, origin_port) diff --git a/paramiko/util.py b/paramiko/util.py index 0d6a534..f4bfbec 100644 --- a/paramiko/util.py +++ b/paramiko/util.py @@ -24,6 +24,7 @@ from __future__ import generators import array from binascii import hexlify, unhexlify +import errno import sys import struct import traceback @@ -270,6 +271,14 @@ def get_logger(name): l.addFilter(_pfilter) return l +def retry_on_signal(function): + """Retries function until it doesn't raise an EINTR error""" + while True: + try: + return function() + except EnvironmentError, e: + if e.errno != errno.EINTR: + raise class Counter (object): """Stateful counter for CTR mode crypto""" diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py index 787032b..d77d58f 100644 --- a/paramiko/win_pageant.py +++ b/paramiko/win_pageant.py @@ -26,6 +26,8 @@ import struct import tempfile import mmap import array +import platform +import ctypes.wintypes # if you're on windows, you should have one of these, i guess? # ctypes is part of standard library since Python 2.5 @@ -42,7 +44,6 @@ except ImportError: except ImportError: pass - _AGENT_COPYDATA_ID = 0x804e50ba _AGENT_MAX_MSGLEN = 8192 # Note: The WM_COPYDATA value is pulled from win32con, as a workaround @@ -74,6 +75,17 @@ def can_talk_to_agent(): return True return False +ULONG_PTR = ctypes.c_uint64 if platform.architecture()[0] == '64bit' else ctypes.c_uint32 +class COPYDATASTRUCT(ctypes.Structure): + """ + ctypes implementation of + http://msdn.microsoft.com/en-us/library/windows/desktop/ms649010%28v=vs.85%29.aspx + """ + _fields_ = [ + ('num_data', ULONG_PTR), + ('data_size', ctypes.wintypes.DWORD), + ('data_loc', ctypes.c_void_p), + ] def _query_pageant(msg): hwnd = _get_pageant_window_object() @@ -96,19 +108,17 @@ def _query_pageant(msg): char_buffer = array.array("c", map_filename + '\0') char_buffer_address, char_buffer_size = char_buffer.buffer_info() # Create a string to use for the SendMessage function call - cds = struct.pack("LLP", _AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address) + cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address) if _has_win32all: # win32gui.SendMessage should also allow the same pattern as # ctypes, but let's keep it like this for now... - response = win32gui.SendMessage(hwnd, win32con_WM_COPYDATA, len(cds), cds) + response = win32gui.SendMessage(hwnd, win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.addressof(cds)) elif _has_ctypes: - _buf = array.array('B', cds) - _addr, _size = _buf.buffer_info() - response = ctypes.windll.user32.SendMessageA(hwnd, win32con_WM_COPYDATA, _size, _addr) + response = ctypes.windll.user32.SendMessageA(hwnd, win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds)) else: response = 0 - + if response > 0: datalen = pymap.read(4) retlen = struct.unpack('>I', datalen)[0] @@ -131,10 +141,10 @@ class PageantConnection (object): def __init__(self): self._response = None - + def send(self, data): self._response = _query_pageant(data) - + def recv(self, n): if self._response is None: return '' |