# Copyright (C) 2003-2005 Robey Pointer # # This file is part of paramiko. # # Paramiko is free software; you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation; either version 2.1 of the License, or (at your option) # any later version. # # Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more # details. # # You should have received a copy of the GNU Lesser General Public License # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. """ Useful functions used by the rest of paramiko. """ from __future__ import generators import fnmatch import sys import struct import traceback import threading from paramiko.common import * # Change by RogerB - python < 2.3 doesn't have enumerate so we implement it if sys.version_info < (2,3): class enumerate: def __init__ (self, sequence): self.sequence = sequence def __iter__ (self): count = 0 for item in self.sequence: yield (count, item) count += 1 def inflate_long(s, always_positive=False): "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" out = 0L negative = 0 if not always_positive and (len(s) > 0) and (ord(s[0]) >= 0x80): negative = 1 if len(s) % 4: filler = '\x00' if negative: filler = '\xff' s = filler * (4 - len(s) % 4) + s for i in range(0, len(s), 4): out = (out << 32) + struct.unpack('>I', s[i:i+4])[0] if negative: out -= (1L << (8 * len(s))) return out def deflate_long(n, add_sign_padding=True): "turns a long-int into a normalized byte string (adapted from Crypto.Util.number)" # after much testing, this algorithm was deemed to be the fastest s = '' n = long(n) while (n != 0) and (n != -1): s = struct.pack('>I', n & 0xffffffffL) + s n = n >> 32 # strip off leading zeros, FFs for i in enumerate(s): if (n == 0) and (i[1] != '\000'): break if (n == -1) and (i[1] != '\xff'): break else: # degenerate case, n was either 0 or -1 i = (0,) if n == 0: s = '\000' else: s = '\xff' s = s[i[0]:] if add_sign_padding: if (n == 0) and (ord(s[0]) >= 0x80): s = '\x00' + s if (n == -1) and (ord(s[0]) < 0x80): s = '\xff' + s return s def format_binary_weird(data): out = '' for i in enumerate(data): out += '%02X' % ord(i[1]) if i[0] % 2: out += ' ' if i[0] % 16 == 15: out += '\n' return out def format_binary(data, prefix=''): x = 0 out = [] while len(data) > x + 16: out.append(format_binary_line(data[x:x+16])) x += 16 if x < len(data): out.append(format_binary_line(data[x:])) return [prefix + x for x in out] def format_binary_line(data): left = ' '.join(['%02X' % ord(c) for c in data]) right = ''.join([('.%c..' % c)[(ord(c)+63)//95] for c in data]) return '%-50s %s' % (left, right) def hexify(s): "turn a string into a hex sequence" return ''.join(['%02X' % ord(c) for c in s]) def unhexify(s): "turn a hex sequence back into a string" return ''.join([chr(int(s[i:i+2], 16)) for i in range(0, len(s), 2)]) def safe_string(s): out = '' for c in s: if (ord(c) >= 32) and (ord(c) <= 127): out += c else: out += '%%%02X' % ord(c) return out # ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s]) def bit_length(n): norm = deflate_long(n, 0) hbyte = ord(norm[0]) bitlen = len(norm) * 8 while not (hbyte & 0x80): hbyte <<= 1 bitlen -= 1 return bitlen def tb_strings(): return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') def generate_key_bytes(hashclass, salt, key, nbytes): """ Given a password, passphrase, or other human-source key, scramble it through a secure hash into some keyworthy bytes. This specific algorithm is used for encrypting/decrypting private key files. @param hashclass: class from L{Crypto.Hash} that can be used as a secure hashing function (like C{MD5} or C{SHA}). @type hashclass: L{Crypto.Hash} @param salt: data to salt the hash with. @type salt: string @param key: human-entered password or passphrase. @type key: string @param nbytes: number of bytes to generate. @type nbytes: int @return: key data @rtype: string """ keydata = '' digest = '' if len(salt) > 8: salt = salt[:8] while nbytes > 0: hash = hashclass.new() if len(digest) > 0: hash.update(digest) hash.update(key) hash.update(salt) digest = hash.digest() size = min(nbytes, len(digest)) keydata += digest[:size] nbytes -= size return keydata def load_host_keys(filename): """ Read a file of known SSH host keys, in the format used by openssh, and return a compound dict of C{hostname -> keytype ->} L{PKey }. The hostname may be an IP address or DNS name. The keytype will be either C{"ssh-rsa"} or C{"ssh-dss"}. This type of file unfortunately doesn't exist on Windows, but on posix, it will usually be stored in C{os.path.expanduser("~/.ssh/known_hosts")}. @param filename: name of the file to read host keys from @type filename: str @return: dict of host keys, indexed by hostname and then keytype @rtype: dict(hostname, dict(keytype, L{PKey })) """ import base64 from rsakey import RSAKey from dsskey import DSSKey keys = {} f = file(filename, 'r') for line in f: line = line.strip() if (len(line) == 0) or (line[0] == '#'): continue keylist = line.split(' ') if len(keylist) != 3: continue hostlist, keytype, key = keylist hosts = hostlist.split(',') for host in hosts: if not keys.has_key(host): keys[host] = {} if keytype == 'ssh-rsa': keys[host][keytype] = RSAKey(data=base64.decodestring(key)) elif keytype == 'ssh-dss': keys[host][keytype] = DSSKey(data=base64.decodestring(key)) f.close() return keys def parse_ssh_config(file_obj): """ Parse a config file of the format used by OpenSSH, and return an object that can be used to make queries to L{lookup_ssh_host_config}. The format is described in OpenSSH's C{ssh_config} man page. This method is provided primarily as a convenience to posix users (since the OpenSSH format is a de-facto standard on posix) but should work fine on Windows too. The return value is currently a list of dictionaries, each containing host-specific configuration, but this is considered an implementation detail and may be subject to change in later versions. @param file_obj: a file-like object to read the config file from @type file_obj: file @return: opaque configuration object @rtype: object """ ret = [] config = { 'host': '*' } ret.append(config) for line in file_obj: line = line.rstrip('\n').lstrip() if (line == '') or (line[0] == '#'): continue if '=' in line: key, value = line.split('=', 1) key = key.strip().lower() else: # find first whitespace, and split there i = 0 while (i < len(line)) and not line[i].isspace(): i += 1 if i == len(line): raise Exception('Unparsable line: %r' % line) key = line[:i].lower() value = line[i:].lstrip() if key == 'host': # do we have a pre-existing host config to append to? matches = [c for c in ret if c['host'] == value] if len(matches) > 0: config = matches[0] else: config = { 'host': value } ret.append(config) else: config[key] = value return ret def lookup_ssh_host_config(hostname, config): """ Return a dict of config options for a given hostname. The C{config} object must come from L{parse_ssh_config}. The host-matching rules of OpenSSH's C{ssh_config} man page are used, which means that all configuration options from matching host specifications are merged, with more specific hostmasks taking precedence. In other words, if C{"Port"} is set under C{"Host *"} and also C{"Host *.example.com"}, and the lookup is for C{"ssh.example.com"}, then the port entry for C{"Host *.example.com"} will win out. The keys in the returned dict are all normalized to lowercase (look for C{"port"}, not C{"Port"}. No other processing is done to the keys or values. @param hostname: the hostname to lookup @type hostname: str @param config: the config object to search @type config: object """ matches = [x for x in config if fnmatch.fnmatch(hostname, x['host'])] # sort in order of shortest match (usually '*') to longest matches.sort(key=lambda x: len(x['host'])) ret = {} for m in matches: ret.update(m) del ret['host'] return ret def mod_inverse(x, m): # it's crazy how small python can make this function. u1, u2, u3 = 1, 0, m v1, v2, v3 = 0, 1, x while v3 > 0: q = u3 // v3 u1, v1 = v1, u1 - v1 * q u2, v2 = v2, u2 - v2 * q u3, v3 = v3, u3 - v3 * q if u2 < 0: u2 += m return u2 _g_thread_ids = {} _g_thread_counter = 0 _g_thread_lock = threading.Lock() def get_thread_id(): global _g_thread_ids, _g_thread_counter, _g_thread_lock tid = id(threading.currentThread()) try: return _g_thread_ids[tid] except KeyError: _g_thread_lock.acquire() try: _g_thread_counter += 1 ret = _g_thread_ids[tid] = _g_thread_counter finally: _g_thread_lock.release() return ret def log_to_file(filename, level=DEBUG): "send paramiko logs to a logfile, if they're not already going somewhere" l = logging.getLogger("paramiko") if len(l.handlers) > 0: return l.setLevel(level) f = open(filename, 'w') lh = logging.StreamHandler(f) lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s', '%Y%m%d-%H:%M:%S')) l.addHandler(lh) # make only one filter object, so it doesn't get applied more than once class PFilter (object): def filter(self, record): record._threadid = get_thread_id() return True _pfilter = PFilter() def get_logger(name): l = logging.getLogger(name) l.addFilter(_pfilter) return l