summaryrefslogtreecommitdiff
path: root/paramiko/transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'paramiko/transport.py')
-rw-r--r--paramiko/transport.py574
1 files changed, 408 insertions, 166 deletions
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 8714a96..a18e05b 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
+# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
@@ -30,19 +30,20 @@ import time
import weakref
from paramiko import util
+from paramiko.auth_handler import AuthHandler
+from paramiko.channel import Channel
from paramiko.common import *
from paramiko.compress import ZlibCompressor, ZlibDecompressor
-from paramiko.ssh_exception import SSHException, BadAuthenticationType
-from paramiko.message import Message
-from paramiko.channel import Channel
-from paramiko.sftp_client import SFTPClient
-from paramiko.packet import Packetizer, NeedRekeyException
-from paramiko.rsakey import RSAKey
from paramiko.dsskey import DSSKey
-from paramiko.kex_group1 import KexGroup1
from paramiko.kex_gex import KexGex
+from paramiko.kex_group1 import KexGroup1
+from paramiko.message import Message
+from paramiko.packet import Packetizer, NeedRekeyException
from paramiko.primes import ModulusPack
-from paramiko.auth_handler import AuthHandler
+from paramiko.rsakey import RSAKey
+from paramiko.server import ServerInterface
+from paramiko.sftp_client import SFTPClient
+from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException
# these come from PyCrypt
# http://www.amk.ca/python/writing/pycrypt/
@@ -50,7 +51,7 @@ from paramiko.auth_handler import AuthHandler
# PyCrypt compiled for Win32 can be downloaded from the HashTar homepage:
# http://nitace.bsd.uchicago.edu:8080/hashtar
from Crypto.Cipher import Blowfish, AES, DES3
-from Crypto.Hash import SHA, MD5, HMAC
+from Crypto.Hash import SHA, MD5
# for thread cleanup
@@ -73,8 +74,6 @@ class SecurityOptions (object):
If you try to add an algorithm that paramiko doesn't recognize,
C{ValueError} will be raised. If you try to assign something besides a
tuple to one of the fields, C{TypeError} will be raised.
-
- @since: ivysaur
"""
__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ]
@@ -110,7 +109,8 @@ class SecurityOptions (object):
if type(x) is not tuple:
raise TypeError('expected tuple or list')
possible = getattr(self._transport, orig).keys()
- if len(filter(lambda n: n not in possible, x)) > 0:
+ forbidden = filter(lambda n: n not in possible, x)
+ if len(forbidden) > 0:
raise ValueError('unknown cipher')
setattr(self._transport, name, x)
@@ -140,6 +140,51 @@ class SecurityOptions (object):
"Compression algorithms")
+class ChannelMap (object):
+ def __init__(self):
+ # (id -> Channel)
+ self._map = weakref.WeakValueDictionary()
+ self._lock = threading.Lock()
+
+ def put(self, chanid, chan):
+ self._lock.acquire()
+ try:
+ self._map[chanid] = chan
+ finally:
+ self._lock.release()
+
+ def get(self, chanid):
+ self._lock.acquire()
+ try:
+ return self._map.get(chanid, None)
+ finally:
+ self._lock.release()
+
+ def delete(self, chanid):
+ self._lock.acquire()
+ try:
+ try:
+ del self._map[chanid]
+ except KeyError:
+ pass
+ finally:
+ self._lock.release()
+
+ def values(self):
+ self._lock.acquire()
+ try:
+ return self._map.values()
+ finally:
+ self._lock.release()
+
+ def __len__(self):
+ self._lock.acquire()
+ try:
+ return len(self._map)
+ finally:
+ self._lock.release()
+
+
class Transport (threading.Thread):
"""
An SSH Transport attaches to a stream (usually a socket), negotiates an
@@ -149,7 +194,7 @@ class Transport (threading.Thread):
"""
_PROTO_ID = '2.0'
- _CLIENT_ID = 'paramiko_1.5.2'
+ _CLIENT_ID = 'paramiko_1.7.4'
_preferred_ciphers = ( 'aes128-cbc', 'blowfish-cbc', 'aes256-cbc', '3des-cbc' )
_preferred_macs = ( 'hmac-sha1', 'hmac-md5', 'hmac-sha1-96', 'hmac-md5-96' )
@@ -245,25 +290,41 @@ class Transport (threading.Thread):
self.sock.settimeout(0.1)
except AttributeError:
pass
+
# negotiated crypto parameters
self.packetizer = Packetizer(sock)
self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID
self.remote_version = ''
self.local_cipher = self.remote_cipher = ''
self.local_kex_init = self.remote_kex_init = None
+ self.local_mac = self.remote_mac = None
+ self.local_compression = self.remote_compression = None
self.session_id = None
- # /negotiated crypto parameters
- self.expected_packet = 0
+ self.host_key_type = None
+ self.host_key = None
+
+ # state used during negotiation
+ self.kex_engine = None
+ self.H = None
+ self.K = None
+
self.active = False
self.initial_kex_done = False
self.in_kex = False
+ self.authenticated = False
+ self._expected_packet = tuple()
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
- self.channels = weakref.WeakValueDictionary() # (id -> Channel)
+
+ # tracking open channels
+ self._channels = ChannelMap()
self.channel_events = { } # (id -> Event)
self.channels_seen = { } # (id -> True)
- self.channel_counter = 1
+ self._channel_counter = 1
self.window_size = 65536
self.max_packet_size = 34816
+ self._x11_handler = None
+ self._tcp_handler = None
+
self.saved_exception = None
self.clear_to_send = threading.Event()
self.clear_to_send_lock = threading.Lock()
@@ -271,9 +332,10 @@ class Transport (threading.Thread):
self.logger = util.get_logger(self.log_name)
self.packetizer.set_log(self.logger)
self.auth_handler = None
- self.authenticated = False
- # user-defined event callbacks:
- self.completion_event = None
+ self.global_response = None # response Message from an arbitrary global request
+ self.completion_event = None # user-defined event callbacks
+ self.banner_timeout = 15 # how long (seconds) to wait for the SSH banner
+
# server mode:
self.server_mode = False
self.server_object = None
@@ -282,9 +344,6 @@ class Transport (threading.Thread):
self.server_accept_cv = threading.Condition(self.lock)
self.subsystem_table = { }
- def __del__(self):
- self.close()
-
def __repr__(self):
"""
Returns a string representation of this object, for debugging.
@@ -299,16 +358,26 @@ class Transport (threading.Thread):
out += ' (cipher %s, %d bits)' % (self.local_cipher,
self._cipher_info[self.local_cipher]['key-size'] * 8)
if self.is_authenticated():
- if len(self.channels) == 1:
- out += ' (active; 1 open channel)'
- else:
- out += ' (active; %d open channels)' % len(self.channels)
+ out += ' (active; %d open channel(s))' % len(self._channels)
elif self.initial_kex_done:
out += ' (connected; awaiting auth)'
else:
out += ' (connecting)'
out += '>'
return out
+
+ def atfork(self):
+ """
+ Terminate this Transport without closing the session. On posix
+ systems, if a Transport is open during process forking, both parent
+ and child will share the underlying socket, but only one process can
+ use the connection (without corrupting the session). Use this method
+ to clean up a Transport object without disrupting the other process.
+
+ @since: 1.5.3
+ """
+ self.sock.close()
+ self.close()
def get_security_options(self):
"""
@@ -319,8 +388,6 @@ class Transport (threading.Thread):
@return: an object that can be used to change the preferred algorithms
for encryption, digest (hash), public key, and key exchange.
@rtype: L{SecurityOptions}
-
- @since: ivysaur
"""
return SecurityOptions(self)
@@ -471,7 +538,8 @@ class Transport (threading.Thread):
try:
return self.server_key_dict[self.host_key_type]
except KeyError:
- return None
+ pass
+ return None
def load_server_moduli(filename=None):
"""
@@ -496,8 +564,6 @@ class Transport (threading.Thread):
@return: True if a moduli file was successfully loaded; False
otherwise.
@rtype: bool
-
- @since: doduo
@note: This has no effect when used in client mode.
"""
@@ -521,14 +587,13 @@ class Transport (threading.Thread):
"""
Close this session, and any open channels that are tied to it.
"""
+ if not self.active:
+ return
self.active = False
- # since this may be called from __del__, can't assume any attributes exist
- try:
- self.packetizer.close()
- for chan in self.channels.values():
- chan._unlink()
- except AttributeError:
- pass
+ self.packetizer.close()
+ self.join()
+ for chan in self._channels.values():
+ chan._unlink()
def get_remote_server_key(self):
"""
@@ -541,7 +606,7 @@ class Transport (threading.Thread):
@raise SSHException: if no session is currently active.
- @return: public key of the remote server.
+ @return: public key of the remote server
@rtype: L{PKey <pkey.PKey>}
"""
if (not self.active) or (not self.initial_kex_done):
@@ -553,7 +618,7 @@ class Transport (threading.Thread):
Return true if this session is active (open).
@return: True if the session is still active (open); False if the
- session is closed.
+ session is closed
@rtype: bool
"""
return self.active
@@ -563,12 +628,43 @@ class Transport (threading.Thread):
Request a new channel to the server, of type C{"session"}. This
is just an alias for C{open_channel('session')}.
- @return: a new L{Channel} on success, or C{None} if the request is
- rejected or the session ends prematurely.
+ @return: a new L{Channel}
@rtype: L{Channel}
+
+ @raise SSHException: if the request is rejected or the session ends
+ prematurely
"""
return self.open_channel('session')
+ def open_x11_channel(self, src_addr=None):
+ """
+ Request a new channel to the client, of type C{"x11"}. This
+ is just an alias for C{open_channel('x11', src_addr=src_addr)}.
+
+ @param src_addr: the source address of the x11 server (port is the
+ x11 port, ie. 6010)
+ @type src_addr: (str, int)
+ @return: a new L{Channel}
+ @rtype: L{Channel}
+
+ @raise SSHException: if the request is rejected or the session ends
+ prematurely
+ """
+ return self.open_channel('x11', src_addr=src_addr)
+
+ def open_forwarded_tcpip_channel(self, (src_addr, src_port), (dest_addr, dest_port)):
+ """
+ Request a new channel back to the client, of type C{"forwarded-tcpip"}.
+ This is used after a client has requested port forwarding, for sending
+ incoming connections back to the client.
+
+ @param src_addr: originator's address
+ @param src_port: originator's port
+ @param dest_addr: local (server) connected address
+ @param dest_port: local (server) connected port
+ """
+ return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port))
+
def open_channel(self, kind, dest_addr=None, src_addr=None):
"""
Request a new channel to the server. L{Channel}s are socket-like
@@ -577,18 +673,20 @@ class Transport (threading.Thread):
L{connect} or L{start_client}) and authenticating.
@param kind: the kind of channel requested (usually C{"session"},
- C{"forwarded-tcpip"} or C{"direct-tcpip"}).
+ C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"})
@type kind: str
@param dest_addr: the destination address of this port forwarding,
if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
- for other channel types).
+ for other channel types)
@type dest_addr: (str, int)
@param src_addr: the source address of this port forwarding, if
- C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}.
+ C{kind} is C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"}
@type src_addr: (str, int)
- @return: a new L{Channel} on success, or C{None} if the request is
- rejected or the session ends prematurely.
+ @return: a new L{Channel} on success
@rtype: L{Channel}
+
+ @raise SSHException: if the request is rejected or the session ends
+ prematurely
"""
chan = None
if not self.active:
@@ -596,11 +694,7 @@ class Transport (threading.Thread):
return None
self.lock.acquire()
try:
- chanid = self.channel_counter
- while self.channels.has_key(chanid):
- self.channel_counter = (self.channel_counter + 1) & 0xffffff
- chanid = self.channel_counter
- self.channel_counter = (self.channel_counter + 1) & 0xffffff
+ chanid = self._next_channel()
m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN))
m.add_string(kind)
@@ -612,7 +706,11 @@ class Transport (threading.Thread):
m.add_int(dest_addr[1])
m.add_string(src_addr[0])
m.add_int(src_addr[1])
- self.channels[chanid] = chan = Channel(chanid)
+ elif kind == 'x11':
+ m.add_string(src_addr[0])
+ m.add_int(src_addr[1])
+ chan = Channel(chanid)
+ self._channels.put(chanid, chan)
self.channel_events[chanid] = event = threading.Event()
self.channels_seen[chanid] = True
chan._set_transport(self)
@@ -620,20 +718,84 @@ class Transport (threading.Thread):
finally:
self.lock.release()
self._send_user_message(m)
- while 1:
+ while True:
event.wait(0.1);
if not self.active:
- return None
+ e = self.get_exception()
+ if e is None:
+ e = SSHException('Unable to open channel.')
+ raise e
if event.isSet():
break
- try:
- self.lock.acquire()
- if not self.channels.has_key(chanid):
- chan = None
- finally:
- self.lock.release()
- return chan
-
+ chan = self._channels.get(chanid)
+ if chan is not None:
+ return chan
+ e = self.get_exception()
+ if e is None:
+ e = SSHException('Unable to open channel.')
+ raise e
+
+ def request_port_forward(self, address, port, handler=None):
+ """
+ Ask the server to forward TCP connections from a listening port on
+ the server, across this SSH session.
+
+ If a handler is given, that handler is called from a different thread
+ whenever a forwarded connection arrives. The handler parameters are::
+
+ handler(channel, (origin_addr, origin_port), (server_addr, server_port))
+
+ where C{server_addr} and C{server_port} are the address and port that
+ the server was listening on.
+
+ If no handler is set, the default behavior is to send new incoming
+ forwarded connections into the accept queue, to be picked up via
+ L{accept}.
+
+ @param address: the address to bind when forwarding
+ @type address: str
+ @param port: the port to forward, or 0 to ask the server to allocate
+ any port
+ @type port: int
+ @param handler: optional handler for incoming forwarded connections
+ @type handler: function(Channel, (str, int), (str, int))
+ @return: the port # allocated by the server
+ @rtype: int
+
+ @raise SSHException: if the server refused the TCP forward request
+ """
+ if not self.active:
+ raise SSHException('SSH session not active')
+ address = str(address)
+ port = int(port)
+ response = self.global_request('tcpip-forward', (address, port), wait=True)
+ if response is None:
+ raise SSHException('TCP forwarding request denied')
+ if port == 0:
+ port = response.get_int()
+ if handler is None:
+ def default_handler(channel, (src_addr, src_port), (dest_addr, dest_port)):
+ self._queue_incoming_channel(channel)
+ handler = default_handler
+ self._tcp_handler = handler
+ return port
+
+ def cancel_port_forward(self, address, port):
+ """
+ Ask the server to cancel a previous port-forwarding request. No more
+ connections to the given address & port will be forwarded across this
+ ssh connection.
+
+ @param address: the address to stop forwarding
+ @type address: str
+ @param port: the port to stop forwarding
+ @type port: int
+ """
+ if not self.active:
+ return
+ self._tcp_handler = None
+ self.global_request('cancel-tcpip-forward', (address, port), wait=True)
+
def open_sftp_client(self):
"""
Create an SFTP client channel from an open transport. On success,
@@ -656,8 +818,6 @@ class Transport (threading.Thread):
@param bytes: the number of random bytes to send in the payload of the
ignored packet -- defaults to a random number from 10 to 41.
@type bytes: int
-
- @since: fearow
"""
m = Message()
m.add_byte(chr(MSG_IGNORE))
@@ -674,22 +834,23 @@ class Transport (threading.Thread):
bytes sent or received, but this method gives you the option of forcing
new keys whenever you want. Negotiating new keys causes a pause in
traffic both ways as the two sides swap keys and do computations. This
- method returns when the session has switched to new keys, or the
- session has died mid-negotiation.
+ method returns when the session has switched to new keys.
- @return: True if the renegotiation was successful, and the link is
- using new keys; False if the session dropped during renegotiation.
- @rtype: bool
+ @raise SSHException: if the key renegotiation failed (which causes the
+ session to end)
"""
self.completion_event = threading.Event()
self._send_kex_init()
- while 1:
- self.completion_event.wait(0.1);
+ while True:
+ self.completion_event.wait(0.1)
if not self.active:
- return False
+ e = self.get_exception()
+ if e is not None:
+ raise e
+ raise SSHException('Negotiation failed.')
if self.completion_event.isSet():
break
- return True
+ return
def set_keepalive(self, interval):
"""
@@ -701,11 +862,9 @@ class Transport (threading.Thread):
@param interval: seconds to wait before sending a keepalive packet (or
0 to disable keepalives).
@type interval: int
-
- @since: fearow
"""
self.packetizer.set_keepalive(interval,
- lambda x=self: x.global_request('keepalive@lag.net', wait=False))
+ lambda x=weakref.proxy(self): x.global_request('keepalive@lag.net', wait=False))
def global_request(self, kind, data=None, wait=True):
"""
@@ -724,8 +883,6 @@ class Transport (threading.Thread):
request was successful (or an empty L{Message} if C{wait} was
C{False}); C{None} if the request was denied.
@rtype: L{Message}
-
- @since: fearow
"""
if wait:
self.completion_event = threading.Event()
@@ -807,8 +964,6 @@ class Transport (threading.Thread):
@raise SSHException: if the SSH2 negotiation fails, the host key
supplied by the server is incorrect, or authentication fails.
-
- @since: doduo
"""
if hostkey is not None:
self._preferred_keys = [ hostkey.get_name() ]
@@ -896,8 +1051,6 @@ class Transport (threading.Thread):
@return: username that was authenticated, or C{None}.
@rtype: string
-
- @since: fearow
"""
if not self.active or (self.auth_handler is None):
return None
@@ -958,9 +1111,9 @@ class Transport (threading.Thread):
step. Otherwise, in the normal case, an empty list is returned.
@param username: the username to authenticate as
- @type username: string
+ @type username: str
@param password: the password to authenticate with
- @type password: string
+ @type password: str or unicode
@param event: an event to trigger when the authentication attempt is
complete (whether it was successful or not)
@type event: threading.Event
@@ -974,8 +1127,9 @@ class Transport (threading.Thread):
@raise BadAuthenticationType: if password authentication isn't
allowed by the server for this user (and no event was passed in)
- @raise SSHException: if the authentication failed (and no event was
- passed in)
+ @raise AuthenticationException: if the authentication failed (and no
+ event was passed in)
+ @raise SSHException: if there was a network error
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to send the password unless we're on a secure link
@@ -993,7 +1147,7 @@ class Transport (threading.Thread):
return self.auth_handler.wait_for_response(my_event)
except BadAuthenticationType, x:
# if password auth isn't allowed, but keyboard-interactive *is*, try to fudge it
- if not fallback or not 'keyboard-interactive' in x.allowed_types:
+ if not fallback or ('keyboard-interactive' not in x.allowed_types):
raise
try:
def handler(title, instructions, fields):
@@ -1010,6 +1164,7 @@ class Transport (threading.Thread):
except SSHException, ignored:
# attempt failed; just raise the original exception
raise x
+ return None
def auth_publickey(self, username, key, event=None):
"""
@@ -1037,13 +1192,14 @@ class Transport (threading.Thread):
complete (whether it was successful or not)
@type event: threading.Event
@return: list of auth types permissible for the next stage of
- authentication (normally empty).
+ authentication (normally empty)
@rtype: list
@raise BadAuthenticationType: if public-key authentication isn't
- allowed by the server for this user (and no event was passed in).
- @raise SSHException: if the authentication failed (and no event was
- passed in).
+ allowed by the server for this user (and no event was passed in)
+ @raise AuthenticationException: if the authentication failed (and no
+ event was passed in)
+ @raise SSHException: if there was a network error
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
@@ -1100,7 +1256,8 @@ class Transport (threading.Thread):
@raise BadAuthenticationType: if public-key authentication isn't
allowed by the server for this user
- @raise SSHException: if the authentication failed
+ @raise AuthenticationException: if the authentication failed
+ @raise SSHException: if there was a network error
@since: 1.5
"""
@@ -1119,13 +1276,14 @@ class Transport (threading.Thread):
(See the C{logging} module for more info.) SSH Channels will log
to a sub-channel of the one specified.
- @param name: new channel name for logging.
+ @param name: new channel name for logging
@type name: str
@since: 1.1
"""
self.log_name = name
self.logger = util.get_logger(name)
+ self.packetizer.set_log(self.logger)
def get_log_channel(self):
"""
@@ -1166,8 +1324,7 @@ class Transport (threading.Thread):
"""
Turn on/off compression. This will only have an affect before starting
the transport (ie before calling L{connect}, etc). By default,
- compression is off since it negatively affects interactive sessions
- and is not fully tested.
+ compression is off since it negatively affects interactive sessions.
@param compress: C{True} to ask the remote client/server to compress
traffic; C{False} to refuse compression
@@ -1179,6 +1336,21 @@ class Transport (threading.Thread):
self._preferred_compression = ( 'zlib@openssh.com', 'zlib', 'none' )
else:
self._preferred_compression = ( 'none', )
+
+ def getpeername(self):
+ """
+ Return the address of the remote side of this Transport, if possible.
+ This is effectively a wrapper around C{'getpeername'} on the underlying
+ socket. If the socket-like object has no C{'getpeername'} method,
+ then C{("unknown", 0)} is returned.
+
+ @return: the address if the remote host, if known
+ @rtype: tuple(str, int)
+ """
+ gp = getattr(self.sock, 'getpeername', None)
+ if gp is None:
+ return ('unknown', 0)
+ return gp()
def stop_thread(self):
self.active = False
@@ -1188,25 +1360,29 @@ class Transport (threading.Thread):
### internals...
- def _log(self, level, msg):
+ def _log(self, level, msg, *args):
if issubclass(type(msg), list):
for m in msg:
self.logger.log(level, m)
else:
- self.logger.log(level, msg)
+ self.logger.log(level, msg, *args)
def _get_modulus_pack(self):
"used by KexGex to find primes for group exchange"
return self._modulus_pack
+ def _next_channel(self):
+ "you are holding the lock"
+ chanid = self._channel_counter
+ while self._channels.get(chanid) is not None:
+ self._channel_counter = (self._channel_counter + 1) & 0xffffff
+ chanid = self._channel_counter
+ self._channel_counter = (self._channel_counter + 1) & 0xffffff
+ return chanid
+
def _unlink_channel(self, chanid):
"used by a Channel to remove itself from the active channel list"
- try:
- self.lock.acquire()
- if self.channels.has_key(chanid):
- del self.channels[chanid]
- finally:
- self.lock.release()
+ self._channels.delete(chanid)
def _send_message(self, data):
self.packetizer.send_message(data)
@@ -1237,9 +1413,9 @@ class Transport (threading.Thread):
if self.session_id == None:
self.session_id = h
- def _expect_packet(self, type):
+ def _expect_packet(self, *ptypes):
"used by a kex object to register the next packet type it expects to see"
- self.expected_packet = type
+ self._expected_packet = tuple(ptypes)
def _verify_key(self, host_key, sig):
key = self._key_info[self.host_key_type](Message(host_key))
@@ -1262,16 +1438,34 @@ class Transport (threading.Thread):
m.add_mpint(self.K)
m.add_bytes(self.H)
m.add_bytes(sofar)
- hash = SHA.new(str(m)).digest()
- out += hash
- sofar += hash
+ digest = SHA.new(str(m)).digest()
+ out += digest
+ sofar += digest
return out[:nbytes]
def _get_cipher(self, name, key, iv):
- if not self._cipher_info.has_key(name):
+ if name not in self._cipher_info:
raise SSHException('Unknown client cipher ' + name)
return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv)
+ def _set_x11_handler(self, handler):
+ # only called if a channel has turned on x11 forwarding
+ if handler is None:
+ # by default, use the same mechanism as accept()
+ def default_handler(channel, (src_addr, src_port)):
+ self._queue_incoming_channel(channel)
+ self._x11_handler = default_handler
+ else:
+ self._x11_handler = handler
+
+ def _queue_incoming_channel(self, channel):
+ self.lock.acquire()
+ try:
+ self.server_accepts.append(channel)
+ self.server_accept_cv.notify()
+ finally:
+ self.lock.release()
+
def run(self):
# (use the exposed "run" method, because if we specify a thread target
# of a private method, threading.Thread will keep a reference to it
@@ -1288,7 +1482,7 @@ class Transport (threading.Thread):
self.packetizer.write_all(self.local_version + '\r\n')
self._check_banner()
self._send_kex_init()
- self.expected_packet = MSG_KEXINIT
+ self._expect_packet(MSG_KEXINIT)
while self.active:
if self.packetizer.need_rekey() and not self.in_kex:
@@ -1307,27 +1501,28 @@ class Transport (threading.Thread):
elif ptype == MSG_DEBUG:
self._parse_debug(m)
continue
- if self.expected_packet != 0:
- if ptype != self.expected_packet:
- raise SSHException('Expecting packet %d, got %d' % (self.expected_packet, ptype))
- self.expected_packet = 0
+ if len(self._expected_packet) > 0:
+ if ptype not in self._expected_packet:
+ raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype))
+ self._expected_packet = tuple()
if (ptype >= 30) and (ptype <= 39):
self.kex_engine.parse_next(ptype, m)
continue
- if self._handler_table.has_key(ptype):
+ if ptype in self._handler_table:
self._handler_table[ptype](self, m)
- elif self._channel_handler_table.has_key(ptype):
+ elif ptype in self._channel_handler_table:
chanid = m.get_int()
- if self.channels.has_key(chanid):
- self._channel_handler_table[ptype](self.channels[chanid], m)
- elif self.channels_seen.has_key(chanid):
+ chan = self._channels.get(chanid)
+ if chan is not None:
+ self._channel_handler_table[ptype](chan, m)
+ elif chanid in self.channels_seen:
self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid)
else:
self._log(ERROR, 'Channel request for unknown channel %d' % chanid)
self.active = False
self.packetizer.close()
- elif (self.auth_handler is not None) and self.auth_handler._handler_table.has_key(ptype):
+ elif (self.auth_handler is not None) and (ptype in self.auth_handler._handler_table):
self.auth_handler._handler_table[ptype](self.auth_handler, m)
else:
self._log(WARNING, 'Oops, unhandled type %d' % ptype)
@@ -1355,7 +1550,7 @@ class Transport (threading.Thread):
self._log(ERROR, util.tb_strings())
self.saved_exception = e
_active_threads.remove(self)
- for chan in self.channels.values():
+ for chan in self._channels.values():
chan._unlink()
if self.active:
self.active = False
@@ -1366,6 +1561,11 @@ class Transport (threading.Thread):
self.auth_handler.abort()
for event in self.channel_events.values():
event.set()
+ try:
+ self.lock.acquire()
+ self.server_accept_cv.notify()
+ finally:
+ self.lock.release()
self.sock.close()
@@ -1388,30 +1588,31 @@ class Transport (threading.Thread):
def _check_banner(self):
# this is slow, but we only have to do it once
for i in range(5):
- # give them 5 seconds for the first line, then just 2 seconds each additional line
+ # give them 15 seconds for the first line, then just 2 seconds
+ # each additional line. (some sites have very high latency.)
if i == 0:
- timeout = 5
+ timeout = self.banner_timeout
else:
timeout = 2
try:
- buffer = self.packetizer.readline(timeout)
+ buf = self.packetizer.readline(timeout)
except Exception, x:
raise SSHException('Error reading SSH protocol banner' + str(x))
- if buffer[:4] == 'SSH-':
+ if buf[:4] == 'SSH-':
break
- self._log(DEBUG, 'Banner: ' + buffer)
- if buffer[:4] != 'SSH-':
- raise SSHException('Indecipherable protocol version "' + buffer + '"')
+ self._log(DEBUG, 'Banner: ' + buf)
+ if buf[:4] != 'SSH-':
+ raise SSHException('Indecipherable protocol version "' + buf + '"')
# save this server version string for later
- self.remote_version = buffer
+ self.remote_version = buf
# pull off any attached comment
comment = ''
- i = string.find(buffer, ' ')
+ i = string.find(buf, ' ')
if i >= 0:
- comment = buffer[i+1:]
- buffer = buffer[:i]
+ comment = buf[i+1:]
+ buf = buf[:i]
# parse out version string and make sure it matches
- segs = buffer.split('-', 2)
+ segs = buf.split('-', 2)
if len(segs) < 3:
raise SSHException('Invalid SSH banner')
version = segs[1]
@@ -1612,7 +1813,7 @@ class Transport (threading.Thread):
if not self.packetizer.need_rekey():
self.in_kex = False
# we always expect to receive NEWKEYS now
- self.expected_packet = MSG_NEWKEYS
+ self._expect_packet(MSG_NEWKEYS)
def _auth_trigger(self):
self.authenticated = True
@@ -1661,7 +1862,22 @@ class Transport (threading.Thread):
kind = m.get_string()
self._log(DEBUG, 'Received global request "%s"' % kind)
want_reply = m.get_boolean()
- ok = self.server_object.check_global_request(kind, m)
+ if not self.server_mode:
+ self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind)
+ ok = False
+ elif kind == 'tcpip-forward':
+ address = m.get_string()
+ port = m.get_int()
+ ok = self.server_object.check_port_forward_request(address, port)
+ if ok != False:
+ ok = (ok,)
+ elif kind == 'cancel-tcpip-forward':
+ address = m.get_string()
+ port = m.get_int()
+ self.server_object.cancel_port_forward_request(address, port)
+ ok = True
+ else:
+ ok = self.server_object.check_global_request(kind, m)
extra = ()
if type(ok) is tuple:
extra = ok
@@ -1692,15 +1908,15 @@ class Transport (threading.Thread):
server_chanid = m.get_int()
server_window_size = m.get_int()
server_max_packet_size = m.get_int()
- if not self.channels.has_key(chanid):
+ chan = self._channels.get(chanid)
+ if chan is None:
self._log(WARNING, 'Success for unrequested channel! [??]')
return
self.lock.acquire()
try:
- chan = self.channels[chanid]
chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
self._log(INFO, 'Secsh channel %d opened.' % chanid)
- if self.channel_events.has_key(chanid):
+ if chanid in self.channel_events:
self.channel_events[chanid].set()
del self.channel_events[chanid]
finally:
@@ -1712,16 +1928,14 @@ class Transport (threading.Thread):
reason = m.get_int()
reason_str = m.get_string()
lang = m.get_string()
- if CONNECTION_FAILED_CODE.has_key(reason):
- reason_text = CONNECTION_FAILED_CODE[reason]
- else:
- reason_text = '(unknown code)'
+ reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
+ self.lock.acquire()
try:
- self.lock.aquire()
- if self.channels.has_key(chanid):
- del self.channels[chanid]
- if self.channel_events.has_key(chanid):
+ self.saved_exception = ChannelException(reason, reason_text)
+ if chanid in self.channel_events:
+ self._channels.delete(chanid)
+ if chanid in self.channel_events:
self.channel_events[chanid].set()
del self.channel_events[chanid]
finally:
@@ -1734,21 +1948,47 @@ class Transport (threading.Thread):
initial_window_size = m.get_int()
max_packet_size = m.get_int()
reject = False
- if not self.server_mode:
+ if (kind == 'x11') and (self._x11_handler is not None):
+ origin_addr = m.get_string()
+ origin_port = m.get_int()
+ self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port))
+ self.lock.acquire()
+ try:
+ my_chanid = self._next_channel()
+ finally:
+ self.lock.release()
+ elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None):
+ server_addr = m.get_string()
+ server_port = m.get_int()
+ origin_addr = m.get_string()
+ origin_port = m.get_int()
+ self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port))
+ self.lock.acquire()
+ try:
+ my_chanid = self._next_channel()
+ finally:
+ self.lock.release()
+ elif not self.server_mode:
self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
reject = True
reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
else:
self.lock.acquire()
try:
- my_chanid = self.channel_counter
- while self.channels.has_key(my_chanid):
- self.channel_counter = (self.channel_counter + 1) & 0xffffff
- my_chanid = self.channel_counter
- self.channel_counter = (self.channel_counter + 1) & 0xffffff
+ my_chanid = self._next_channel()
finally:
self.lock.release()
- reason = self.server_object.check_channel_request(kind, my_chanid)
+ if kind == 'direct-tcpip':
+ # handle direct-tcpip requests comming from the client
+ dest_addr = m.get_string()
+ dest_port = m.get_int()
+ origin_addr = m.get_string()
+ origin_port = m.get_int()
+ reason = self.server_object.check_channel_direct_tcpip_request(
+ my_chanid, (origin_addr, origin_port),
+ (dest_addr, dest_port))
+ else:
+ reason = self.server_object.check_channel_request(kind, my_chanid)
if reason != OPEN_SUCCEEDED:
self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind)
reject = True
@@ -1761,10 +2001,11 @@ class Transport (threading.Thread):
msg.add_string('en')
self._send_message(msg)
return
+
chan = Channel(my_chanid)
+ self.lock.acquire()
try:
- self.lock.acquire()
- self.channels[my_chanid] = chan
+ self._channels.put(my_chanid, chan)
self.channels_seen[my_chanid] = True
chan._set_transport(self)
chan._set_window(self.window_size, self.max_packet_size)
@@ -1778,13 +2019,14 @@ class Transport (threading.Thread):
m.add_int(self.window_size)
m.add_int(self.max_packet_size)
self._send_message(m)
- self._log(INFO, 'Secsh channel %d opened.' % my_chanid)
- try:
- self.lock.acquire()
- self.server_accepts.append(chan)
- self.server_accept_cv.notify()
- finally:
- self.lock.release()
+ self._log(INFO, 'Secsh channel %d (%s) opened.', my_chanid, kind)
+ if kind == 'x11':
+ self._x11_handler(chan, (origin_addr, origin_port))
+ elif kind == 'forwarded-tcpip':
+ chan.origin_addr = (origin_addr, origin_port)
+ self._tcp_handler(chan, (origin_addr, origin_port), (server_addr, server_port))
+ else:
+ self._queue_incoming_channel(chan)
def _parse_debug(self, m):
always_display = m.get_boolean()
@@ -1795,7 +2037,7 @@ class Transport (threading.Thread):
def _get_subsystem_handler(self, name):
try:
self.lock.acquire()
- if not self.subsystem_table.has_key(name):
+ if name not in self.subsystem_table:
return (None, [], {})
return self.subsystem_table[name]
finally: