diff options
Diffstat (limited to 'paramiko/util.py')
-rw-r--r-- | paramiko/util.py | 193 |
1 files changed, 102 insertions, 91 deletions
diff --git a/paramiko/util.py b/paramiko/util.py index 85ee6b0..f4ee3ad 100644 --- a/paramiko/util.py +++ b/paramiko/util.py @@ -29,78 +29,65 @@ import sys import struct import traceback import threading +import logging -from paramiko.common import * +from paramiko.common import DEBUG, zero_byte, xffffffff, max_byte +from paramiko.py3compat import PY2, long, byte_ord, b, byte_chr from paramiko.config import SSHConfig -# Change by RogerB - python < 2.3 doesn't have enumerate so we implement it -if sys.version_info < (2,3): - class enumerate: - def __init__ (self, sequence): - self.sequence = sequence - def __iter__ (self): - count = 0 - for item in self.sequence: - yield (count, item) - count += 1 - - def inflate_long(s, always_positive=False): - "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" - out = 0L + """turns a normalized byte string into a long-int (adapted from Crypto.Util.number)""" + out = long(0) negative = 0 - if not always_positive and (len(s) > 0) and (ord(s[0]) >= 0x80): + if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80): negative = 1 if len(s) % 4: - filler = '\x00' + filler = zero_byte if negative: - filler = '\xff' + filler = max_byte + # never convert this to ``s +=`` because this is a string, not a number + # noinspection PyAugmentAssignment s = filler * (4 - len(s) % 4) + s for i in range(0, len(s), 4): out = (out << 32) + struct.unpack('>I', s[i:i+4])[0] if negative: - out -= (1L << (8 * len(s))) + out -= (long(1) << (8 * len(s))) return out +deflate_zero = zero_byte if PY2 else 0 +deflate_ff = max_byte if PY2 else 0xff + + def deflate_long(n, add_sign_padding=True): - "turns a long-int into a normalized byte string (adapted from Crypto.Util.number)" + """turns a long-int into a normalized byte string (adapted from Crypto.Util.number)""" # after much testing, this algorithm was deemed to be the fastest - s = '' + s = bytes() n = long(n) while (n != 0) and (n != -1): - s = struct.pack('>I', n & 0xffffffffL) + s - n = n >> 32 + s = struct.pack('>I', n & xffffffff) + s + n >>= 32 # strip off leading zeros, FFs for i in enumerate(s): - if (n == 0) and (i[1] != '\000'): + if (n == 0) and (i[1] != deflate_zero): break - if (n == -1) and (i[1] != '\xff'): + if (n == -1) and (i[1] != deflate_ff): break else: # degenerate case, n was either 0 or -1 i = (0,) if n == 0: - s = '\000' + s = zero_byte else: - s = '\xff' + s = max_byte s = s[i[0]:] if add_sign_padding: - if (n == 0) and (ord(s[0]) >= 0x80): - s = '\x00' + s - if (n == -1) and (ord(s[0]) < 0x80): - s = '\xff' + s + if (n == 0) and (byte_ord(s[0]) >= 0x80): + s = zero_byte + s + if (n == -1) and (byte_ord(s[0]) < 0x80): + s = max_byte + s return s -def format_binary_weird(data): - out = '' - for i in enumerate(data): - out += '%02X' % ord(i[1]) - if i[0] % 2: - out += ' ' - if i[0] % 16 == 15: - out += '\n' - return out def format_binary(data, prefix=''): x = 0 @@ -112,69 +99,73 @@ def format_binary(data, prefix=''): out.append(format_binary_line(data[x:])) return [prefix + x for x in out] + def format_binary_line(data): - left = ' '.join(['%02X' % ord(c) for c in data]) - right = ''.join([('.%c..' % c)[(ord(c)+63)//95] for c in data]) + left = ' '.join(['%02X' % byte_ord(c) for c in data]) + right = ''.join([('.%c..' % c)[(byte_ord(c)+63)//95] for c in data]) return '%-50s %s' % (left, right) + def hexify(s): return hexlify(s).upper() + def unhexify(s): return unhexlify(s) + def safe_string(s): out = '' for c in s: - if (ord(c) >= 32) and (ord(c) <= 127): + if (byte_ord(c) >= 32) and (byte_ord(c) <= 127): out += c else: - out += '%%%02X' % ord(c) + out += '%%%02X' % byte_ord(c) return out -# ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s]) def bit_length(n): - norm = deflate_long(n, 0) - hbyte = ord(norm[0]) - if hbyte == 0: - return 1 - bitlen = len(norm) * 8 - while not (hbyte & 0x80): - hbyte <<= 1 - bitlen -= 1 - return bitlen + try: + return n.bitlength() + except AttributeError: + norm = deflate_long(n, False) + hbyte = byte_ord(norm[0]) + if hbyte == 0: + return 1 + bitlen = len(norm) * 8 + while not (hbyte & 0x80): + hbyte <<= 1 + bitlen -= 1 + return bitlen + def tb_strings(): return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') -def generate_key_bytes(hashclass, salt, key, nbytes): + +def generate_key_bytes(hash_alg, salt, key, nbytes): """ Given a password, passphrase, or other human-source key, scramble it through a secure hash into some keyworthy bytes. This specific algorithm is used for encrypting/decrypting private key files. - @param hashclass: class from L{Crypto.Hash} that can be used as a secure - hashing function (like C{MD5} or C{SHA}). - @type hashclass: L{Crypto.Hash} - @param salt: data to salt the hash with. - @type salt: string - @param key: human-entered password or passphrase. - @type key: string - @param nbytes: number of bytes to generate. - @type nbytes: int - @return: key data - @rtype: string + :param function hash_alg: A function which creates a new hash object, such + as ``hashlib.sha256``. + :param salt: data to salt the hash with. + :type salt: byte string + :param str key: human-entered password or passphrase. + :param int nbytes: number of bytes to generate. + :return: Key data `str` """ - keydata = '' - digest = '' + keydata = bytes() + digest = bytes() if len(salt) > 8: salt = salt[:8] while nbytes > 0: - hash_obj = hashclass.new() + hash_obj = hash_alg() if len(digest) > 0: hash_obj.update(digest) - hash_obj.update(key) + hash_obj.update(b(key)) hash_obj.update(salt) digest = hash_obj.digest() size = min(nbytes, len(digest)) @@ -182,42 +173,45 @@ def generate_key_bytes(hashclass, salt, key, nbytes): nbytes -= size return keydata + def load_host_keys(filename): """ Read a file of known SSH host keys, in the format used by openssh, and - return a compound dict of C{hostname -> keytype ->} L{PKey <paramiko.pkey.PKey>}. - The hostname may be an IP address or DNS name. The keytype will be either - C{"ssh-rsa"} or C{"ssh-dss"}. + return a compound dict of ``hostname -> keytype ->`` `PKey + <paramiko.pkey.PKey>`. The hostname may be an IP address or DNS name. The + keytype will be either ``"ssh-rsa"`` or ``"ssh-dss"``. This type of file unfortunately doesn't exist on Windows, but on posix, - it will usually be stored in C{os.path.expanduser("~/.ssh/known_hosts")}. + it will usually be stored in ``os.path.expanduser("~/.ssh/known_hosts")``. - Since 1.5.3, this is just a wrapper around L{HostKeys}. + Since 1.5.3, this is just a wrapper around `.HostKeys`. - @param filename: name of the file to read host keys from - @type filename: str - @return: dict of host keys, indexed by hostname and then keytype - @rtype: dict(hostname, dict(keytype, L{PKey <paramiko.pkey.PKey>})) + :param str filename: name of the file to read host keys from + :return: + nested dict of `.PKey` objects, indexed by hostname and then keytype """ from paramiko.hostkeys import HostKeys return HostKeys(filename) + def parse_ssh_config(file_obj): """ - Provided only as a backward-compatible wrapper around L{SSHConfig}. + Provided only as a backward-compatible wrapper around `.SSHConfig`. """ config = SSHConfig() config.parse(file_obj) return config + def lookup_ssh_host_config(hostname, config): """ - Provided only as a backward-compatible wrapper around L{SSHConfig}. + Provided only as a backward-compatible wrapper around `.SSHConfig`. """ return config.lookup(hostname) + def mod_inverse(x, m): - # it's crazy how small python can make this function. + # it's crazy how small Python can make this function. u1, u2, u3 = 1, 0, m v1, v2, v3 = 0, 1, x @@ -233,6 +227,8 @@ def mod_inverse(x, m): _g_thread_ids = {} _g_thread_counter = 0 _g_thread_lock = threading.Lock() + + def get_thread_id(): global _g_thread_ids, _g_thread_counter, _g_thread_lock tid = id(threading.currentThread()) @@ -247,8 +243,9 @@ def get_thread_id(): _g_thread_lock.release() return ret + def log_to_file(filename, level=DEBUG): - "send paramiko logs to a logfile, if they're not already going somewhere" + """send paramiko logs to a logfile, if they're not already going somewhere""" l = logging.getLogger("paramiko") if len(l.handlers) > 0: return @@ -259,6 +256,7 @@ def log_to_file(filename, level=DEBUG): '%Y%m%d-%H:%M:%S')) l.addHandler(lh) + # make only one filter object, so it doesn't get applied more than once class PFilter (object): def filter(self, record): @@ -266,46 +264,59 @@ class PFilter (object): return True _pfilter = PFilter() + def get_logger(name): l = logging.getLogger(name) l.addFilter(_pfilter) return l + def retry_on_signal(function): """Retries function until it doesn't raise an EINTR error""" while True: try: return function() - except EnvironmentError, e: + except EnvironmentError as e: if e.errno != errno.EINTR: raise + class Counter (object): """Stateful counter for CTR mode crypto""" - def __init__(self, nbits, initial_value=1L, overflow=0L): + def __init__(self, nbits, initial_value=long(1), overflow=long(0)): self.blocksize = nbits / 8 self.overflow = overflow # start with value - 1 so we don't have to store intermediate values when counting # could the iv be 0? if initial_value == 0: - self.value = array.array('c', '\xFF' * self.blocksize) + self.value = array.array('c', max_byte * self.blocksize) else: x = deflate_long(initial_value - 1, add_sign_padding=False) - self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x) + self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x) def __call__(self): """Increament the counter and return the new value""" i = self.blocksize - 1 while i > -1: - c = self.value[i] = chr((ord(self.value[i]) + 1) % 256) - if c != '\x00': + c = self.value[i] = byte_chr((byte_ord(self.value[i]) + 1) % 256) + if c != zero_byte: return self.value.tostring() i -= 1 # counter reset x = deflate_long(self.overflow, add_sign_padding=False) - self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x) + self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x) return self.value.tostring() - def new(cls, nbits, initial_value=1L, overflow=0L): + def new(cls, nbits, initial_value=long(1), overflow=long(0)): return cls(nbits, initial_value=initial_value, overflow=overflow) new = classmethod(new) + + +def constant_time_bytes_eq(a, b): + if len(a) != len(b): + return False + res = 0 + # noinspection PyUnresolvedReferences + for i in (xrange if PY2 else range)(len(a)): + res |= byte_ord(a[i]) ^ byte_ord(b[i]) + return res == 0 |