diff options
Diffstat (limited to 'paramiko/kex_gex.py')
-rw-r--r-- | paramiko/kex_gex.py | 68 |
1 files changed, 55 insertions, 13 deletions
diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py index 994d76c..63a0c99 100644 --- a/paramiko/kex_gex.py +++ b/paramiko/kex_gex.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# Copyright (C) 2003-2007 Robey Pointer <robey@lag.net> # # This file is part of paramiko. # @@ -31,7 +31,8 @@ from paramiko.message import Message from paramiko.ssh_exception import SSHException -_MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(31, 35) +_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \ + _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35) class KexGex (object): @@ -43,19 +44,32 @@ class KexGex (object): def __init__(self, transport): self.transport = transport - - def start_kex(self): + self.p = None + self.q = None + self.g = None + self.x = None + self.e = None + self.f = None + self.old_style = False + + def start_kex(self, _test_old_style=False): if self.transport.server_mode: - self.transport._expect_packet(_MSG_KEXDH_GEX_REQUEST) + self.transport._expect_packet(_MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD) return # request a bit range: we accept (min_bits) to (max_bits), but prefer # (preferred_bits). according to the spec, we shouldn't pull the # minimum up above 1024. m = Message() - m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST)) - m.add_int(self.min_bits) - m.add_int(self.preferred_bits) - m.add_int(self.max_bits) + if _test_old_style: + # only used for unit tests: we shouldn't ever send this + m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST_OLD)) + m.add_int(self.preferred_bits) + self.old_style = True + else: + m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST)) + m.add_int(self.min_bits) + m.add_int(self.preferred_bits) + m.add_int(self.max_bits) self.transport._send_message(m) self.transport._expect_packet(_MSG_KEXDH_GEX_GROUP) @@ -68,6 +82,8 @@ class KexGex (object): return self._parse_kexdh_gex_init(m) elif ptype == _MSG_KEXDH_GEX_REPLY: return self._parse_kexdh_gex_reply(m) + elif ptype == _MSG_KEXDH_GEX_REQUEST_OLD: + return self._parse_kexdh_gex_request_old(m) raise SSHException('KexGex asked to handle packet type %d' % ptype) @@ -126,6 +142,28 @@ class KexGex (object): self.transport._send_message(m) self.transport._expect_packet(_MSG_KEXDH_GEX_INIT) + def _parse_kexdh_gex_request_old(self, m): + # same as above, but without min_bits or max_bits (used by older clients like putty) + self.preferred_bits = m.get_int() + # smoosh the user's preferred size into our own limits + if self.preferred_bits > self.max_bits: + self.preferred_bits = self.max_bits + if self.preferred_bits < self.min_bits: + self.preferred_bits = self.min_bits + # generate prime + pack = self.transport._get_modulus_pack() + if pack is None: + raise SSHException('Can\'t do server-side gex with no modulus pack') + self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,)) + self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits) + m = Message() + m.add_byte(chr(_MSG_KEXDH_GEX_GROUP)) + m.add_mpint(self.p) + m.add_mpint(self.g) + self.transport._send_message(m) + self.transport._expect_packet(_MSG_KEXDH_GEX_INIT) + self.old_style = True + def _parse_kexdh_gex_group(self, m): self.p = m.get_mpint() self.g = m.get_mpint() @@ -156,9 +194,11 @@ class KexGex (object): hm.add(self.transport.remote_version, self.transport.local_version, self.transport.remote_kex_init, self.transport.local_kex_init, key) - hm.add_int(self.min_bits) + if not self.old_style: + hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) - hm.add_int(self.max_bits) + if not self.old_style: + hm.add_int(self.max_bits) hm.add_mpint(self.p) hm.add_mpint(self.g) hm.add_mpint(self.e) @@ -189,9 +229,11 @@ class KexGex (object): hm.add(self.transport.local_version, self.transport.remote_version, self.transport.local_kex_init, self.transport.remote_kex_init, host_key) - hm.add_int(self.min_bits) + if not self.old_style: + hm.add_int(self.min_bits) hm.add_int(self.preferred_bits) - hm.add_int(self.max_bits) + if not self.old_style: + hm.add_int(self.max_bits) hm.add_mpint(self.p) hm.add_mpint(self.g) hm.add_mpint(self.e) |