diff options
Diffstat (limited to 'paramiko/util.py')
-rw-r--r-- | paramiko/util.py | 357 |
1 files changed, 357 insertions, 0 deletions
diff --git a/paramiko/util.py b/paramiko/util.py new file mode 100644 index 0000000..abab825 --- /dev/null +++ b/paramiko/util.py @@ -0,0 +1,357 @@ +# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Useful functions used by the rest of paramiko. +""" + +from __future__ import generators + +import fnmatch +import sys +import struct +import traceback +import threading + +from paramiko.common import * + + +# Change by RogerB - python < 2.3 doesn't have enumerate so we implement it +if sys.version_info < (2,3): + class enumerate: + def __init__ (self, sequence): + self.sequence = sequence + def __iter__ (self): + count = 0 + for item in self.sequence: + yield (count, item) + count += 1 + + +def inflate_long(s, always_positive=False): + "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" + out = 0L + negative = 0 + if not always_positive and (len(s) > 0) and (ord(s[0]) >= 0x80): + negative = 1 + if len(s) % 4: + filler = '\x00' + if negative: + filler = '\xff' + s = filler * (4 - len(s) % 4) + s + for i in range(0, len(s), 4): + out = (out << 32) + struct.unpack('>I', s[i:i+4])[0] + if negative: + out -= (1L << (8 * len(s))) + return out + +def deflate_long(n, add_sign_padding=True): + "turns a long-int into a normalized byte string (adapted from Crypto.Util.number)" + # after much testing, this algorithm was deemed to be the fastest + s = '' + n = long(n) + while (n != 0) and (n != -1): + s = struct.pack('>I', n & 0xffffffffL) + s + n = n >> 32 + # strip off leading zeros, FFs + for i in enumerate(s): + if (n == 0) and (i[1] != '\000'): + break + if (n == -1) and (i[1] != '\xff'): + break + else: + # degenerate case, n was either 0 or -1 + i = (0,) + if n == 0: + s = '\000' + else: + s = '\xff' + s = s[i[0]:] + if add_sign_padding: + if (n == 0) and (ord(s[0]) >= 0x80): + s = '\x00' + s + if (n == -1) and (ord(s[0]) < 0x80): + s = '\xff' + s + return s + +def format_binary_weird(data): + out = '' + for i in enumerate(data): + out += '%02X' % ord(i[1]) + if i[0] % 2: + out += ' ' + if i[0] % 16 == 15: + out += '\n' + return out + +def format_binary(data, prefix=''): + x = 0 + out = [] + while len(data) > x + 16: + out.append(format_binary_line(data[x:x+16])) + x += 16 + if x < len(data): + out.append(format_binary_line(data[x:])) + return [prefix + x for x in out] + +def format_binary_line(data): + left = ' '.join(['%02X' % ord(c) for c in data]) + right = ''.join([('.%c..' % c)[(ord(c)+63)//95] for c in data]) + return '%-50s %s' % (left, right) + +def hexify(s): + "turn a string into a hex sequence" + return ''.join(['%02X' % ord(c) for c in s]) + +def unhexify(s): + "turn a hex sequence back into a string" + return ''.join([chr(int(s[i:i+2], 16)) for i in range(0, len(s), 2)]) + +def safe_string(s): + out = '' + for c in s: + if (ord(c) >= 32) and (ord(c) <= 127): + out += c + else: + out += '%%%02X' % ord(c) + return out + +# ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s]) + +def bit_length(n): + norm = deflate_long(n, 0) + hbyte = ord(norm[0]) + bitlen = len(norm) * 8 + while not (hbyte & 0x80): + hbyte <<= 1 + bitlen -= 1 + return bitlen + +def tb_strings(): + return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') + +def generate_key_bytes(hashclass, salt, key, nbytes): + """ + Given a password, passphrase, or other human-source key, scramble it + through a secure hash into some keyworthy bytes. This specific algorithm + is used for encrypting/decrypting private key files. + + @param hashclass: class from L{Crypto.Hash} that can be used as a secure + hashing function (like C{MD5} or C{SHA}). + @type hashclass: L{Crypto.Hash} + @param salt: data to salt the hash with. + @type salt: string + @param key: human-entered password or passphrase. + @type key: string + @param nbytes: number of bytes to generate. + @type nbytes: int + @return: key data + @rtype: string + """ + keydata = '' + digest = '' + if len(salt) > 8: + salt = salt[:8] + while nbytes > 0: + hash = hashclass.new() + if len(digest) > 0: + hash.update(digest) + hash.update(key) + hash.update(salt) + digest = hash.digest() + size = min(nbytes, len(digest)) + keydata += digest[:size] + nbytes -= size + return keydata + +def load_host_keys(filename): + """ + Read a file of known SSH host keys, in the format used by openssh, and + return a compound dict of C{hostname -> keytype ->} L{PKey <paramiko.pkey.PKey>}. + The hostname may be an IP address or DNS name. The keytype will be either + C{"ssh-rsa"} or C{"ssh-dss"}. + + This type of file unfortunately doesn't exist on Windows, but on posix, + it will usually be stored in C{os.path.expanduser("~/.ssh/known_hosts")}. + + @param filename: name of the file to read host keys from + @type filename: str + @return: dict of host keys, indexed by hostname and then keytype + @rtype: dict(hostname, dict(keytype, L{PKey <paramiko.pkey.PKey>})) + """ + import base64 + from rsakey import RSAKey + from dsskey import DSSKey + + keys = {} + f = file(filename, 'r') + for line in f: + line = line.strip() + if (len(line) == 0) or (line[0] == '#'): + continue + keylist = line.split(' ') + if len(keylist) != 3: + continue + hostlist, keytype, key = keylist + hosts = hostlist.split(',') + for host in hosts: + if not keys.has_key(host): + keys[host] = {} + if keytype == 'ssh-rsa': + keys[host][keytype] = RSAKey(data=base64.decodestring(key)) + elif keytype == 'ssh-dss': + keys[host][keytype] = DSSKey(data=base64.decodestring(key)) + f.close() + return keys + +def parse_ssh_config(file_obj): + """ + Parse a config file of the format used by OpenSSH, and return an object + that can be used to make queries to L{lookup_ssh_host_config}. The + format is described in OpenSSH's C{ssh_config} man page. This method is + provided primarily as a convenience to posix users (since the OpenSSH + format is a de-facto standard on posix) but should work fine on Windows + too. + + The return value is currently a list of dictionaries, each containing + host-specific configuration, but this is considered an implementation + detail and may be subject to change in later versions. + + @param file_obj: a file-like object to read the config file from + @type file_obj: file + @return: opaque configuration object + @rtype: object + """ + ret = [] + config = { 'host': '*' } + ret.append(config) + + for line in file_obj: + line = line.rstrip('\n').lstrip() + if (line == '') or (line[0] == '#'): + continue + if '=' in line: + key, value = line.split('=', 1) + key = key.strip().lower() + else: + # find first whitespace, and split there + i = 0 + while (i < len(line)) and not line[i].isspace(): + i += 1 + if i == len(line): + raise Exception('Unparsable line: %r' % line) + key = line[:i].lower() + value = line[i:].lstrip() + + if key == 'host': + # do we have a pre-existing host config to append to? + matches = [c for c in ret if c['host'] == value] + if len(matches) > 0: + config = matches[0] + else: + config = { 'host': value } + ret.append(config) + else: + config[key] = value + + return ret + +def lookup_ssh_host_config(hostname, config): + """ + Return a dict of config options for a given hostname. The C{config} object + must come from L{parse_ssh_config}. + + The host-matching rules of OpenSSH's C{ssh_config} man page are used, which + means that all configuration options from matching host specifications are + merged, with more specific hostmasks taking precedence. In other words, if + C{"Port"} is set under C{"Host *"} and also C{"Host *.example.com"}, and + the lookup is for C{"ssh.example.com"}, then the port entry for + C{"Host *.example.com"} will win out. + + The keys in the returned dict are all normalized to lowercase (look for + C{"port"}, not C{"Port"}. No other processing is done to the keys or + values. + + @param hostname: the hostname to lookup + @type hostname: str + @param config: the config object to search + @type config: object + """ + matches = [x for x in config if fnmatch.fnmatch(hostname, x['host'])] + # sort in order of shortest match (usually '*') to longest + matches.sort(key=lambda x: len(x['host'])) + ret = {} + for m in matches: + ret.update(m) + del ret['host'] + return ret + +def mod_inverse(x, m): + # it's crazy how small python can make this function. + u1, u2, u3 = 1, 0, m + v1, v2, v3 = 0, 1, x + + while v3 > 0: + q = u3 // v3 + u1, v1 = v1, u1 - v1 * q + u2, v2 = v2, u2 - v2 * q + u3, v3 = v3, u3 - v3 * q + if u2 < 0: + u2 += m + return u2 + +_g_thread_ids = {} +_g_thread_counter = 0 +_g_thread_lock = threading.Lock() +def get_thread_id(): + global _g_thread_ids, _g_thread_counter, _g_thread_lock + tid = id(threading.currentThread()) + try: + return _g_thread_ids[tid] + except KeyError: + _g_thread_lock.acquire() + try: + _g_thread_counter += 1 + ret = _g_thread_ids[tid] = _g_thread_counter + finally: + _g_thread_lock.release() + return ret + +def log_to_file(filename, level=DEBUG): + "send paramiko logs to a logfile, if they're not already going somewhere" + l = logging.getLogger("paramiko") + if len(l.handlers) > 0: + return + l.setLevel(level) + f = open(filename, 'w') + lh = logging.StreamHandler(f) + lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s', + '%Y%m%d-%H:%M:%S')) + l.addHandler(lh) + +# make only one filter object, so it doesn't get applied more than once +class PFilter (object): + def filter(self, record): + record._threadid = get_thread_id() + return True +_pfilter = PFilter() + +def get_logger(name): + l = logging.getLogger(name) + l.addFilter(_pfilter) + return l |