diff options
Diffstat (limited to 'paramiko/transport.py')
-rw-r--r-- | paramiko/transport.py | 574 |
1 files changed, 408 insertions, 166 deletions
diff --git a/paramiko/transport.py b/paramiko/transport.py index 8714a96..a18e05b 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -30,19 +30,20 @@ import time import weakref from paramiko import util +from paramiko.auth_handler import AuthHandler +from paramiko.channel import Channel from paramiko.common import * from paramiko.compress import ZlibCompressor, ZlibDecompressor -from paramiko.ssh_exception import SSHException, BadAuthenticationType -from paramiko.message import Message -from paramiko.channel import Channel -from paramiko.sftp_client import SFTPClient -from paramiko.packet import Packetizer, NeedRekeyException -from paramiko.rsakey import RSAKey from paramiko.dsskey import DSSKey -from paramiko.kex_group1 import KexGroup1 from paramiko.kex_gex import KexGex +from paramiko.kex_group1 import KexGroup1 +from paramiko.message import Message +from paramiko.packet import Packetizer, NeedRekeyException from paramiko.primes import ModulusPack -from paramiko.auth_handler import AuthHandler +from paramiko.rsakey import RSAKey +from paramiko.server import ServerInterface +from paramiko.sftp_client import SFTPClient +from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException # these come from PyCrypt # http://www.amk.ca/python/writing/pycrypt/ @@ -50,7 +51,7 @@ from paramiko.auth_handler import AuthHandler # PyCrypt compiled for Win32 can be downloaded from the HashTar homepage: # http://nitace.bsd.uchicago.edu:8080/hashtar from Crypto.Cipher import Blowfish, AES, DES3 -from Crypto.Hash import SHA, MD5, HMAC +from Crypto.Hash import SHA, MD5 # for thread cleanup @@ -73,8 +74,6 @@ class SecurityOptions (object): If you try to add an algorithm that paramiko doesn't recognize, C{ValueError} will be raised. If you try to assign something besides a tuple to one of the fields, C{TypeError} will be raised. - - @since: ivysaur """ __slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ] @@ -110,7 +109,8 @@ class SecurityOptions (object): if type(x) is not tuple: raise TypeError('expected tuple or list') possible = getattr(self._transport, orig).keys() - if len(filter(lambda n: n not in possible, x)) > 0: + forbidden = filter(lambda n: n not in possible, x) + if len(forbidden) > 0: raise ValueError('unknown cipher') setattr(self._transport, name, x) @@ -140,6 +140,51 @@ class SecurityOptions (object): "Compression algorithms") +class ChannelMap (object): + def __init__(self): + # (id -> Channel) + self._map = weakref.WeakValueDictionary() + self._lock = threading.Lock() + + def put(self, chanid, chan): + self._lock.acquire() + try: + self._map[chanid] = chan + finally: + self._lock.release() + + def get(self, chanid): + self._lock.acquire() + try: + return self._map.get(chanid, None) + finally: + self._lock.release() + + def delete(self, chanid): + self._lock.acquire() + try: + try: + del self._map[chanid] + except KeyError: + pass + finally: + self._lock.release() + + def values(self): + self._lock.acquire() + try: + return self._map.values() + finally: + self._lock.release() + + def __len__(self): + self._lock.acquire() + try: + return len(self._map) + finally: + self._lock.release() + + class Transport (threading.Thread): """ An SSH Transport attaches to a stream (usually a socket), negotiates an @@ -149,7 +194,7 @@ class Transport (threading.Thread): """ _PROTO_ID = '2.0' - _CLIENT_ID = 'paramiko_1.5.2' + _CLIENT_ID = 'paramiko_1.7.4' _preferred_ciphers = ( 'aes128-cbc', 'blowfish-cbc', 'aes256-cbc', '3des-cbc' ) _preferred_macs = ( 'hmac-sha1', 'hmac-md5', 'hmac-sha1-96', 'hmac-md5-96' ) @@ -245,25 +290,41 @@ class Transport (threading.Thread): self.sock.settimeout(0.1) except AttributeError: pass + # negotiated crypto parameters self.packetizer = Packetizer(sock) self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID self.remote_version = '' self.local_cipher = self.remote_cipher = '' self.local_kex_init = self.remote_kex_init = None + self.local_mac = self.remote_mac = None + self.local_compression = self.remote_compression = None self.session_id = None - # /negotiated crypto parameters - self.expected_packet = 0 + self.host_key_type = None + self.host_key = None + + # state used during negotiation + self.kex_engine = None + self.H = None + self.K = None + self.active = False self.initial_kex_done = False self.in_kex = False + self.authenticated = False + self._expected_packet = tuple() self.lock = threading.Lock() # synchronization (always higher level than write_lock) - self.channels = weakref.WeakValueDictionary() # (id -> Channel) + + # tracking open channels + self._channels = ChannelMap() self.channel_events = { } # (id -> Event) self.channels_seen = { } # (id -> True) - self.channel_counter = 1 + self._channel_counter = 1 self.window_size = 65536 self.max_packet_size = 34816 + self._x11_handler = None + self._tcp_handler = None + self.saved_exception = None self.clear_to_send = threading.Event() self.clear_to_send_lock = threading.Lock() @@ -271,9 +332,10 @@ class Transport (threading.Thread): self.logger = util.get_logger(self.log_name) self.packetizer.set_log(self.logger) self.auth_handler = None - self.authenticated = False - # user-defined event callbacks: - self.completion_event = None + self.global_response = None # response Message from an arbitrary global request + self.completion_event = None # user-defined event callbacks + self.banner_timeout = 15 # how long (seconds) to wait for the SSH banner + # server mode: self.server_mode = False self.server_object = None @@ -282,9 +344,6 @@ class Transport (threading.Thread): self.server_accept_cv = threading.Condition(self.lock) self.subsystem_table = { } - def __del__(self): - self.close() - def __repr__(self): """ Returns a string representation of this object, for debugging. @@ -299,16 +358,26 @@ class Transport (threading.Thread): out += ' (cipher %s, %d bits)' % (self.local_cipher, self._cipher_info[self.local_cipher]['key-size'] * 8) if self.is_authenticated(): - if len(self.channels) == 1: - out += ' (active; 1 open channel)' - else: - out += ' (active; %d open channels)' % len(self.channels) + out += ' (active; %d open channel(s))' % len(self._channels) elif self.initial_kex_done: out += ' (connected; awaiting auth)' else: out += ' (connecting)' out += '>' return out + + def atfork(self): + """ + Terminate this Transport without closing the session. On posix + systems, if a Transport is open during process forking, both parent + and child will share the underlying socket, but only one process can + use the connection (without corrupting the session). Use this method + to clean up a Transport object without disrupting the other process. + + @since: 1.5.3 + """ + self.sock.close() + self.close() def get_security_options(self): """ @@ -319,8 +388,6 @@ class Transport (threading.Thread): @return: an object that can be used to change the preferred algorithms for encryption, digest (hash), public key, and key exchange. @rtype: L{SecurityOptions} - - @since: ivysaur """ return SecurityOptions(self) @@ -471,7 +538,8 @@ class Transport (threading.Thread): try: return self.server_key_dict[self.host_key_type] except KeyError: - return None + pass + return None def load_server_moduli(filename=None): """ @@ -496,8 +564,6 @@ class Transport (threading.Thread): @return: True if a moduli file was successfully loaded; False otherwise. @rtype: bool - - @since: doduo @note: This has no effect when used in client mode. """ @@ -521,14 +587,13 @@ class Transport (threading.Thread): """ Close this session, and any open channels that are tied to it. """ + if not self.active: + return self.active = False - # since this may be called from __del__, can't assume any attributes exist - try: - self.packetizer.close() - for chan in self.channels.values(): - chan._unlink() - except AttributeError: - pass + self.packetizer.close() + self.join() + for chan in self._channels.values(): + chan._unlink() def get_remote_server_key(self): """ @@ -541,7 +606,7 @@ class Transport (threading.Thread): @raise SSHException: if no session is currently active. - @return: public key of the remote server. + @return: public key of the remote server @rtype: L{PKey <pkey.PKey>} """ if (not self.active) or (not self.initial_kex_done): @@ -553,7 +618,7 @@ class Transport (threading.Thread): Return true if this session is active (open). @return: True if the session is still active (open); False if the - session is closed. + session is closed @rtype: bool """ return self.active @@ -563,12 +628,43 @@ class Transport (threading.Thread): Request a new channel to the server, of type C{"session"}. This is just an alias for C{open_channel('session')}. - @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + @return: a new L{Channel} @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely """ return self.open_channel('session') + def open_x11_channel(self, src_addr=None): + """ + Request a new channel to the client, of type C{"x11"}. This + is just an alias for C{open_channel('x11', src_addr=src_addr)}. + + @param src_addr: the source address of the x11 server (port is the + x11 port, ie. 6010) + @type src_addr: (str, int) + @return: a new L{Channel} + @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely + """ + return self.open_channel('x11', src_addr=src_addr) + + def open_forwarded_tcpip_channel(self, (src_addr, src_port), (dest_addr, dest_port)): + """ + Request a new channel back to the client, of type C{"forwarded-tcpip"}. + This is used after a client has requested port forwarding, for sending + incoming connections back to the client. + + @param src_addr: originator's address + @param src_port: originator's port + @param dest_addr: local (server) connected address + @param dest_port: local (server) connected port + """ + return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port)) + def open_channel(self, kind, dest_addr=None, src_addr=None): """ Request a new channel to the server. L{Channel}s are socket-like @@ -577,18 +673,20 @@ class Transport (threading.Thread): L{connect} or L{start_client}) and authenticating. @param kind: the kind of channel requested (usually C{"session"}, - C{"forwarded-tcpip"} or C{"direct-tcpip"}). + C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"}) @type kind: str @param dest_addr: the destination address of this port forwarding, if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored - for other channel types). + for other channel types) @type dest_addr: (str, int) @param src_addr: the source address of this port forwarding, if - C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}. + C{kind} is C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"} @type src_addr: (str, int) - @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + @return: a new L{Channel} on success @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely """ chan = None if not self.active: @@ -596,11 +694,7 @@ class Transport (threading.Thread): return None self.lock.acquire() try: - chanid = self.channel_counter - while self.channels.has_key(chanid): - self.channel_counter = (self.channel_counter + 1) & 0xffffff - chanid = self.channel_counter - self.channel_counter = (self.channel_counter + 1) & 0xffffff + chanid = self._next_channel() m = Message() m.add_byte(chr(MSG_CHANNEL_OPEN)) m.add_string(kind) @@ -612,7 +706,11 @@ class Transport (threading.Thread): m.add_int(dest_addr[1]) m.add_string(src_addr[0]) m.add_int(src_addr[1]) - self.channels[chanid] = chan = Channel(chanid) + elif kind == 'x11': + m.add_string(src_addr[0]) + m.add_int(src_addr[1]) + chan = Channel(chanid) + self._channels.put(chanid, chan) self.channel_events[chanid] = event = threading.Event() self.channels_seen[chanid] = True chan._set_transport(self) @@ -620,20 +718,84 @@ class Transport (threading.Thread): finally: self.lock.release() self._send_user_message(m) - while 1: + while True: event.wait(0.1); if not self.active: - return None + e = self.get_exception() + if e is None: + e = SSHException('Unable to open channel.') + raise e if event.isSet(): break - try: - self.lock.acquire() - if not self.channels.has_key(chanid): - chan = None - finally: - self.lock.release() - return chan - + chan = self._channels.get(chanid) + if chan is not None: + return chan + e = self.get_exception() + if e is None: + e = SSHException('Unable to open channel.') + raise e + + def request_port_forward(self, address, port, handler=None): + """ + Ask the server to forward TCP connections from a listening port on + the server, across this SSH session. + + If a handler is given, that handler is called from a different thread + whenever a forwarded connection arrives. The handler parameters are:: + + handler(channel, (origin_addr, origin_port), (server_addr, server_port)) + + where C{server_addr} and C{server_port} are the address and port that + the server was listening on. + + If no handler is set, the default behavior is to send new incoming + forwarded connections into the accept queue, to be picked up via + L{accept}. + + @param address: the address to bind when forwarding + @type address: str + @param port: the port to forward, or 0 to ask the server to allocate + any port + @type port: int + @param handler: optional handler for incoming forwarded connections + @type handler: function(Channel, (str, int), (str, int)) + @return: the port # allocated by the server + @rtype: int + + @raise SSHException: if the server refused the TCP forward request + """ + if not self.active: + raise SSHException('SSH session not active') + address = str(address) + port = int(port) + response = self.global_request('tcpip-forward', (address, port), wait=True) + if response is None: + raise SSHException('TCP forwarding request denied') + if port == 0: + port = response.get_int() + if handler is None: + def default_handler(channel, (src_addr, src_port), (dest_addr, dest_port)): + self._queue_incoming_channel(channel) + handler = default_handler + self._tcp_handler = handler + return port + + def cancel_port_forward(self, address, port): + """ + Ask the server to cancel a previous port-forwarding request. No more + connections to the given address & port will be forwarded across this + ssh connection. + + @param address: the address to stop forwarding + @type address: str + @param port: the port to stop forwarding + @type port: int + """ + if not self.active: + return + self._tcp_handler = None + self.global_request('cancel-tcpip-forward', (address, port), wait=True) + def open_sftp_client(self): """ Create an SFTP client channel from an open transport. On success, @@ -656,8 +818,6 @@ class Transport (threading.Thread): @param bytes: the number of random bytes to send in the payload of the ignored packet -- defaults to a random number from 10 to 41. @type bytes: int - - @since: fearow """ m = Message() m.add_byte(chr(MSG_IGNORE)) @@ -674,22 +834,23 @@ class Transport (threading.Thread): bytes sent or received, but this method gives you the option of forcing new keys whenever you want. Negotiating new keys causes a pause in traffic both ways as the two sides swap keys and do computations. This - method returns when the session has switched to new keys, or the - session has died mid-negotiation. + method returns when the session has switched to new keys. - @return: True if the renegotiation was successful, and the link is - using new keys; False if the session dropped during renegotiation. - @rtype: bool + @raise SSHException: if the key renegotiation failed (which causes the + session to end) """ self.completion_event = threading.Event() self._send_kex_init() - while 1: - self.completion_event.wait(0.1); + while True: + self.completion_event.wait(0.1) if not self.active: - return False + e = self.get_exception() + if e is not None: + raise e + raise SSHException('Negotiation failed.') if self.completion_event.isSet(): break - return True + return def set_keepalive(self, interval): """ @@ -701,11 +862,9 @@ class Transport (threading.Thread): @param interval: seconds to wait before sending a keepalive packet (or 0 to disable keepalives). @type interval: int - - @since: fearow """ self.packetizer.set_keepalive(interval, - lambda x=self: x.global_request('keepalive@lag.net', wait=False)) + lambda x=weakref.proxy(self): x.global_request('keepalive@lag.net', wait=False)) def global_request(self, kind, data=None, wait=True): """ @@ -724,8 +883,6 @@ class Transport (threading.Thread): request was successful (or an empty L{Message} if C{wait} was C{False}); C{None} if the request was denied. @rtype: L{Message} - - @since: fearow """ if wait: self.completion_event = threading.Event() @@ -807,8 +964,6 @@ class Transport (threading.Thread): @raise SSHException: if the SSH2 negotiation fails, the host key supplied by the server is incorrect, or authentication fails. - - @since: doduo """ if hostkey is not None: self._preferred_keys = [ hostkey.get_name() ] @@ -896,8 +1051,6 @@ class Transport (threading.Thread): @return: username that was authenticated, or C{None}. @rtype: string - - @since: fearow """ if not self.active or (self.auth_handler is None): return None @@ -958,9 +1111,9 @@ class Transport (threading.Thread): step. Otherwise, in the normal case, an empty list is returned. @param username: the username to authenticate as - @type username: string + @type username: str @param password: the password to authenticate with - @type password: string + @type password: str or unicode @param event: an event to trigger when the authentication attempt is complete (whether it was successful or not) @type event: threading.Event @@ -974,8 +1127,9 @@ class Transport (threading.Thread): @raise BadAuthenticationType: if password authentication isn't allowed by the server for this user (and no event was passed in) - @raise SSHException: if the authentication failed (and no event was - passed in) + @raise AuthenticationException: if the authentication failed (and no + event was passed in) + @raise SSHException: if there was a network error """ if (not self.active) or (not self.initial_kex_done): # we should never try to send the password unless we're on a secure link @@ -993,7 +1147,7 @@ class Transport (threading.Thread): return self.auth_handler.wait_for_response(my_event) except BadAuthenticationType, x: # if password auth isn't allowed, but keyboard-interactive *is*, try to fudge it - if not fallback or not 'keyboard-interactive' in x.allowed_types: + if not fallback or ('keyboard-interactive' not in x.allowed_types): raise try: def handler(title, instructions, fields): @@ -1010,6 +1164,7 @@ class Transport (threading.Thread): except SSHException, ignored: # attempt failed; just raise the original exception raise x + return None def auth_publickey(self, username, key, event=None): """ @@ -1037,13 +1192,14 @@ class Transport (threading.Thread): complete (whether it was successful or not) @type event: threading.Event @return: list of auth types permissible for the next stage of - authentication (normally empty). + authentication (normally empty) @rtype: list @raise BadAuthenticationType: if public-key authentication isn't - allowed by the server for this user (and no event was passed in). - @raise SSHException: if the authentication failed (and no event was - passed in). + allowed by the server for this user (and no event was passed in) + @raise AuthenticationException: if the authentication failed (and no + event was passed in) + @raise SSHException: if there was a network error """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link @@ -1100,7 +1256,8 @@ class Transport (threading.Thread): @raise BadAuthenticationType: if public-key authentication isn't allowed by the server for this user - @raise SSHException: if the authentication failed + @raise AuthenticationException: if the authentication failed + @raise SSHException: if there was a network error @since: 1.5 """ @@ -1119,13 +1276,14 @@ class Transport (threading.Thread): (See the C{logging} module for more info.) SSH Channels will log to a sub-channel of the one specified. - @param name: new channel name for logging. + @param name: new channel name for logging @type name: str @since: 1.1 """ self.log_name = name self.logger = util.get_logger(name) + self.packetizer.set_log(self.logger) def get_log_channel(self): """ @@ -1166,8 +1324,7 @@ class Transport (threading.Thread): """ Turn on/off compression. This will only have an affect before starting the transport (ie before calling L{connect}, etc). By default, - compression is off since it negatively affects interactive sessions - and is not fully tested. + compression is off since it negatively affects interactive sessions. @param compress: C{True} to ask the remote client/server to compress traffic; C{False} to refuse compression @@ -1179,6 +1336,21 @@ class Transport (threading.Thread): self._preferred_compression = ( 'zlib@openssh.com', 'zlib', 'none' ) else: self._preferred_compression = ( 'none', ) + + def getpeername(self): + """ + Return the address of the remote side of this Transport, if possible. + This is effectively a wrapper around C{'getpeername'} on the underlying + socket. If the socket-like object has no C{'getpeername'} method, + then C{("unknown", 0)} is returned. + + @return: the address if the remote host, if known + @rtype: tuple(str, int) + """ + gp = getattr(self.sock, 'getpeername', None) + if gp is None: + return ('unknown', 0) + return gp() def stop_thread(self): self.active = False @@ -1188,25 +1360,29 @@ class Transport (threading.Thread): ### internals... - def _log(self, level, msg): + def _log(self, level, msg, *args): if issubclass(type(msg), list): for m in msg: self.logger.log(level, m) else: - self.logger.log(level, msg) + self.logger.log(level, msg, *args) def _get_modulus_pack(self): "used by KexGex to find primes for group exchange" return self._modulus_pack + def _next_channel(self): + "you are holding the lock" + chanid = self._channel_counter + while self._channels.get(chanid) is not None: + self._channel_counter = (self._channel_counter + 1) & 0xffffff + chanid = self._channel_counter + self._channel_counter = (self._channel_counter + 1) & 0xffffff + return chanid + def _unlink_channel(self, chanid): "used by a Channel to remove itself from the active channel list" - try: - self.lock.acquire() - if self.channels.has_key(chanid): - del self.channels[chanid] - finally: - self.lock.release() + self._channels.delete(chanid) def _send_message(self, data): self.packetizer.send_message(data) @@ -1237,9 +1413,9 @@ class Transport (threading.Thread): if self.session_id == None: self.session_id = h - def _expect_packet(self, type): + def _expect_packet(self, *ptypes): "used by a kex object to register the next packet type it expects to see" - self.expected_packet = type + self._expected_packet = tuple(ptypes) def _verify_key(self, host_key, sig): key = self._key_info[self.host_key_type](Message(host_key)) @@ -1262,16 +1438,34 @@ class Transport (threading.Thread): m.add_mpint(self.K) m.add_bytes(self.H) m.add_bytes(sofar) - hash = SHA.new(str(m)).digest() - out += hash - sofar += hash + digest = SHA.new(str(m)).digest() + out += digest + sofar += digest return out[:nbytes] def _get_cipher(self, name, key, iv): - if not self._cipher_info.has_key(name): + if name not in self._cipher_info: raise SSHException('Unknown client cipher ' + name) return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv) + def _set_x11_handler(self, handler): + # only called if a channel has turned on x11 forwarding + if handler is None: + # by default, use the same mechanism as accept() + def default_handler(channel, (src_addr, src_port)): + self._queue_incoming_channel(channel) + self._x11_handler = default_handler + else: + self._x11_handler = handler + + def _queue_incoming_channel(self, channel): + self.lock.acquire() + try: + self.server_accepts.append(channel) + self.server_accept_cv.notify() + finally: + self.lock.release() + def run(self): # (use the exposed "run" method, because if we specify a thread target # of a private method, threading.Thread will keep a reference to it @@ -1288,7 +1482,7 @@ class Transport (threading.Thread): self.packetizer.write_all(self.local_version + '\r\n') self._check_banner() self._send_kex_init() - self.expected_packet = MSG_KEXINIT + self._expect_packet(MSG_KEXINIT) while self.active: if self.packetizer.need_rekey() and not self.in_kex: @@ -1307,27 +1501,28 @@ class Transport (threading.Thread): elif ptype == MSG_DEBUG: self._parse_debug(m) continue - if self.expected_packet != 0: - if ptype != self.expected_packet: - raise SSHException('Expecting packet %d, got %d' % (self.expected_packet, ptype)) - self.expected_packet = 0 + if len(self._expected_packet) > 0: + if ptype not in self._expected_packet: + raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype)) + self._expected_packet = tuple() if (ptype >= 30) and (ptype <= 39): self.kex_engine.parse_next(ptype, m) continue - if self._handler_table.has_key(ptype): + if ptype in self._handler_table: self._handler_table[ptype](self, m) - elif self._channel_handler_table.has_key(ptype): + elif ptype in self._channel_handler_table: chanid = m.get_int() - if self.channels.has_key(chanid): - self._channel_handler_table[ptype](self.channels[chanid], m) - elif self.channels_seen.has_key(chanid): + chan = self._channels.get(chanid) + if chan is not None: + self._channel_handler_table[ptype](chan, m) + elif chanid in self.channels_seen: self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid) else: self._log(ERROR, 'Channel request for unknown channel %d' % chanid) self.active = False self.packetizer.close() - elif (self.auth_handler is not None) and self.auth_handler._handler_table.has_key(ptype): + elif (self.auth_handler is not None) and (ptype in self.auth_handler._handler_table): self.auth_handler._handler_table[ptype](self.auth_handler, m) else: self._log(WARNING, 'Oops, unhandled type %d' % ptype) @@ -1355,7 +1550,7 @@ class Transport (threading.Thread): self._log(ERROR, util.tb_strings()) self.saved_exception = e _active_threads.remove(self) - for chan in self.channels.values(): + for chan in self._channels.values(): chan._unlink() if self.active: self.active = False @@ -1366,6 +1561,11 @@ class Transport (threading.Thread): self.auth_handler.abort() for event in self.channel_events.values(): event.set() + try: + self.lock.acquire() + self.server_accept_cv.notify() + finally: + self.lock.release() self.sock.close() @@ -1388,30 +1588,31 @@ class Transport (threading.Thread): def _check_banner(self): # this is slow, but we only have to do it once for i in range(5): - # give them 5 seconds for the first line, then just 2 seconds each additional line + # give them 15 seconds for the first line, then just 2 seconds + # each additional line. (some sites have very high latency.) if i == 0: - timeout = 5 + timeout = self.banner_timeout else: timeout = 2 try: - buffer = self.packetizer.readline(timeout) + buf = self.packetizer.readline(timeout) except Exception, x: raise SSHException('Error reading SSH protocol banner' + str(x)) - if buffer[:4] == 'SSH-': + if buf[:4] == 'SSH-': break - self._log(DEBUG, 'Banner: ' + buffer) - if buffer[:4] != 'SSH-': - raise SSHException('Indecipherable protocol version "' + buffer + '"') + self._log(DEBUG, 'Banner: ' + buf) + if buf[:4] != 'SSH-': + raise SSHException('Indecipherable protocol version "' + buf + '"') # save this server version string for later - self.remote_version = buffer + self.remote_version = buf # pull off any attached comment comment = '' - i = string.find(buffer, ' ') + i = string.find(buf, ' ') if i >= 0: - comment = buffer[i+1:] - buffer = buffer[:i] + comment = buf[i+1:] + buf = buf[:i] # parse out version string and make sure it matches - segs = buffer.split('-', 2) + segs = buf.split('-', 2) if len(segs) < 3: raise SSHException('Invalid SSH banner') version = segs[1] @@ -1612,7 +1813,7 @@ class Transport (threading.Thread): if not self.packetizer.need_rekey(): self.in_kex = False # we always expect to receive NEWKEYS now - self.expected_packet = MSG_NEWKEYS + self._expect_packet(MSG_NEWKEYS) def _auth_trigger(self): self.authenticated = True @@ -1661,7 +1862,22 @@ class Transport (threading.Thread): kind = m.get_string() self._log(DEBUG, 'Received global request "%s"' % kind) want_reply = m.get_boolean() - ok = self.server_object.check_global_request(kind, m) + if not self.server_mode: + self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind) + ok = False + elif kind == 'tcpip-forward': + address = m.get_string() + port = m.get_int() + ok = self.server_object.check_port_forward_request(address, port) + if ok != False: + ok = (ok,) + elif kind == 'cancel-tcpip-forward': + address = m.get_string() + port = m.get_int() + self.server_object.cancel_port_forward_request(address, port) + ok = True + else: + ok = self.server_object.check_global_request(kind, m) extra = () if type(ok) is tuple: extra = ok @@ -1692,15 +1908,15 @@ class Transport (threading.Thread): server_chanid = m.get_int() server_window_size = m.get_int() server_max_packet_size = m.get_int() - if not self.channels.has_key(chanid): + chan = self._channels.get(chanid) + if chan is None: self._log(WARNING, 'Success for unrequested channel! [??]') return self.lock.acquire() try: - chan = self.channels[chanid] chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size) self._log(INFO, 'Secsh channel %d opened.' % chanid) - if self.channel_events.has_key(chanid): + if chanid in self.channel_events: self.channel_events[chanid].set() del self.channel_events[chanid] finally: @@ -1712,16 +1928,14 @@ class Transport (threading.Thread): reason = m.get_int() reason_str = m.get_string() lang = m.get_string() - if CONNECTION_FAILED_CODE.has_key(reason): - reason_text = CONNECTION_FAILED_CODE[reason] - else: - reason_text = '(unknown code)' + reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)') self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) + self.lock.acquire() try: - self.lock.aquire() - if self.channels.has_key(chanid): - del self.channels[chanid] - if self.channel_events.has_key(chanid): + self.saved_exception = ChannelException(reason, reason_text) + if chanid in self.channel_events: + self._channels.delete(chanid) + if chanid in self.channel_events: self.channel_events[chanid].set() del self.channel_events[chanid] finally: @@ -1734,21 +1948,47 @@ class Transport (threading.Thread): initial_window_size = m.get_int() max_packet_size = m.get_int() reject = False - if not self.server_mode: + if (kind == 'x11') and (self._x11_handler is not None): + origin_addr = m.get_string() + origin_port = m.get_int() + self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port)) + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None): + server_addr = m.get_string() + server_port = m.get_int() + origin_addr = m.get_string() + origin_port = m.get_int() + self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port)) + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + elif not self.server_mode: self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) reject = True reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: self.lock.acquire() try: - my_chanid = self.channel_counter - while self.channels.has_key(my_chanid): - self.channel_counter = (self.channel_counter + 1) & 0xffffff - my_chanid = self.channel_counter - self.channel_counter = (self.channel_counter + 1) & 0xffffff + my_chanid = self._next_channel() finally: self.lock.release() - reason = self.server_object.check_channel_request(kind, my_chanid) + if kind == 'direct-tcpip': + # handle direct-tcpip requests comming from the client + dest_addr = m.get_string() + dest_port = m.get_int() + origin_addr = m.get_string() + origin_port = m.get_int() + reason = self.server_object.check_channel_direct_tcpip_request( + my_chanid, (origin_addr, origin_port), + (dest_addr, dest_port)) + else: + reason = self.server_object.check_channel_request(kind, my_chanid) if reason != OPEN_SUCCEEDED: self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) reject = True @@ -1761,10 +2001,11 @@ class Transport (threading.Thread): msg.add_string('en') self._send_message(msg) return + chan = Channel(my_chanid) + self.lock.acquire() try: - self.lock.acquire() - self.channels[my_chanid] = chan + self._channels.put(my_chanid, chan) self.channels_seen[my_chanid] = True chan._set_transport(self) chan._set_window(self.window_size, self.max_packet_size) @@ -1778,13 +2019,14 @@ class Transport (threading.Thread): m.add_int(self.window_size) m.add_int(self.max_packet_size) self._send_message(m) - self._log(INFO, 'Secsh channel %d opened.' % my_chanid) - try: - self.lock.acquire() - self.server_accepts.append(chan) - self.server_accept_cv.notify() - finally: - self.lock.release() + self._log(INFO, 'Secsh channel %d (%s) opened.', my_chanid, kind) + if kind == 'x11': + self._x11_handler(chan, (origin_addr, origin_port)) + elif kind == 'forwarded-tcpip': + chan.origin_addr = (origin_addr, origin_port) + self._tcp_handler(chan, (origin_addr, origin_port), (server_addr, server_port)) + else: + self._queue_incoming_channel(chan) def _parse_debug(self, m): always_display = m.get_boolean() @@ -1795,7 +2037,7 @@ class Transport (threading.Thread): def _get_subsystem_handler(self, name): try: self.lock.acquire() - if not self.subsystem_table.has_key(name): + if name not in self.subsystem_table: return (None, [], {}) return self.subsystem_table[name] finally: |