summaryrefslogtreecommitdiff
path: root/paramiko/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'paramiko/util.py')
-rw-r--r--paramiko/util.py193
1 files changed, 102 insertions, 91 deletions
diff --git a/paramiko/util.py b/paramiko/util.py
index 85ee6b0..f4ee3ad 100644
--- a/paramiko/util.py
+++ b/paramiko/util.py
@@ -29,78 +29,65 @@ import sys
import struct
import traceback
import threading
+import logging
-from paramiko.common import *
+from paramiko.common import DEBUG, zero_byte, xffffffff, max_byte
+from paramiko.py3compat import PY2, long, byte_ord, b, byte_chr
from paramiko.config import SSHConfig
-# Change by RogerB - python < 2.3 doesn't have enumerate so we implement it
-if sys.version_info < (2,3):
- class enumerate:
- def __init__ (self, sequence):
- self.sequence = sequence
- def __iter__ (self):
- count = 0
- for item in self.sequence:
- yield (count, item)
- count += 1
-
-
def inflate_long(s, always_positive=False):
- "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
- out = 0L
+ """turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"""
+ out = long(0)
negative = 0
- if not always_positive and (len(s) > 0) and (ord(s[0]) >= 0x80):
+ if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80):
negative = 1
if len(s) % 4:
- filler = '\x00'
+ filler = zero_byte
if negative:
- filler = '\xff'
+ filler = max_byte
+ # never convert this to ``s +=`` because this is a string, not a number
+ # noinspection PyAugmentAssignment
s = filler * (4 - len(s) % 4) + s
for i in range(0, len(s), 4):
out = (out << 32) + struct.unpack('>I', s[i:i+4])[0]
if negative:
- out -= (1L << (8 * len(s)))
+ out -= (long(1) << (8 * len(s)))
return out
+deflate_zero = zero_byte if PY2 else 0
+deflate_ff = max_byte if PY2 else 0xff
+
+
def deflate_long(n, add_sign_padding=True):
- "turns a long-int into a normalized byte string (adapted from Crypto.Util.number)"
+ """turns a long-int into a normalized byte string (adapted from Crypto.Util.number)"""
# after much testing, this algorithm was deemed to be the fastest
- s = ''
+ s = bytes()
n = long(n)
while (n != 0) and (n != -1):
- s = struct.pack('>I', n & 0xffffffffL) + s
- n = n >> 32
+ s = struct.pack('>I', n & xffffffff) + s
+ n >>= 32
# strip off leading zeros, FFs
for i in enumerate(s):
- if (n == 0) and (i[1] != '\000'):
+ if (n == 0) and (i[1] != deflate_zero):
break
- if (n == -1) and (i[1] != '\xff'):
+ if (n == -1) and (i[1] != deflate_ff):
break
else:
# degenerate case, n was either 0 or -1
i = (0,)
if n == 0:
- s = '\000'
+ s = zero_byte
else:
- s = '\xff'
+ s = max_byte
s = s[i[0]:]
if add_sign_padding:
- if (n == 0) and (ord(s[0]) >= 0x80):
- s = '\x00' + s
- if (n == -1) and (ord(s[0]) < 0x80):
- s = '\xff' + s
+ if (n == 0) and (byte_ord(s[0]) >= 0x80):
+ s = zero_byte + s
+ if (n == -1) and (byte_ord(s[0]) < 0x80):
+ s = max_byte + s
return s
-def format_binary_weird(data):
- out = ''
- for i in enumerate(data):
- out += '%02X' % ord(i[1])
- if i[0] % 2:
- out += ' '
- if i[0] % 16 == 15:
- out += '\n'
- return out
def format_binary(data, prefix=''):
x = 0
@@ -112,69 +99,73 @@ def format_binary(data, prefix=''):
out.append(format_binary_line(data[x:]))
return [prefix + x for x in out]
+
def format_binary_line(data):
- left = ' '.join(['%02X' % ord(c) for c in data])
- right = ''.join([('.%c..' % c)[(ord(c)+63)//95] for c in data])
+ left = ' '.join(['%02X' % byte_ord(c) for c in data])
+ right = ''.join([('.%c..' % c)[(byte_ord(c)+63)//95] for c in data])
return '%-50s %s' % (left, right)
+
def hexify(s):
return hexlify(s).upper()
+
def unhexify(s):
return unhexlify(s)
+
def safe_string(s):
out = ''
for c in s:
- if (ord(c) >= 32) and (ord(c) <= 127):
+ if (byte_ord(c) >= 32) and (byte_ord(c) <= 127):
out += c
else:
- out += '%%%02X' % ord(c)
+ out += '%%%02X' % byte_ord(c)
return out
-# ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s])
def bit_length(n):
- norm = deflate_long(n, 0)
- hbyte = ord(norm[0])
- if hbyte == 0:
- return 1
- bitlen = len(norm) * 8
- while not (hbyte & 0x80):
- hbyte <<= 1
- bitlen -= 1
- return bitlen
+ try:
+ return n.bitlength()
+ except AttributeError:
+ norm = deflate_long(n, False)
+ hbyte = byte_ord(norm[0])
+ if hbyte == 0:
+ return 1
+ bitlen = len(norm) * 8
+ while not (hbyte & 0x80):
+ hbyte <<= 1
+ bitlen -= 1
+ return bitlen
+
def tb_strings():
return ''.join(traceback.format_exception(*sys.exc_info())).split('\n')
-def generate_key_bytes(hashclass, salt, key, nbytes):
+
+def generate_key_bytes(hash_alg, salt, key, nbytes):
"""
Given a password, passphrase, or other human-source key, scramble it
through a secure hash into some keyworthy bytes. This specific algorithm
is used for encrypting/decrypting private key files.
- @param hashclass: class from L{Crypto.Hash} that can be used as a secure
- hashing function (like C{MD5} or C{SHA}).
- @type hashclass: L{Crypto.Hash}
- @param salt: data to salt the hash with.
- @type salt: string
- @param key: human-entered password or passphrase.
- @type key: string
- @param nbytes: number of bytes to generate.
- @type nbytes: int
- @return: key data
- @rtype: string
+ :param function hash_alg: A function which creates a new hash object, such
+ as ``hashlib.sha256``.
+ :param salt: data to salt the hash with.
+ :type salt: byte string
+ :param str key: human-entered password or passphrase.
+ :param int nbytes: number of bytes to generate.
+ :return: Key data `str`
"""
- keydata = ''
- digest = ''
+ keydata = bytes()
+ digest = bytes()
if len(salt) > 8:
salt = salt[:8]
while nbytes > 0:
- hash_obj = hashclass.new()
+ hash_obj = hash_alg()
if len(digest) > 0:
hash_obj.update(digest)
- hash_obj.update(key)
+ hash_obj.update(b(key))
hash_obj.update(salt)
digest = hash_obj.digest()
size = min(nbytes, len(digest))
@@ -182,42 +173,45 @@ def generate_key_bytes(hashclass, salt, key, nbytes):
nbytes -= size
return keydata
+
def load_host_keys(filename):
"""
Read a file of known SSH host keys, in the format used by openssh, and
- return a compound dict of C{hostname -> keytype ->} L{PKey <paramiko.pkey.PKey>}.
- The hostname may be an IP address or DNS name. The keytype will be either
- C{"ssh-rsa"} or C{"ssh-dss"}.
+ return a compound dict of ``hostname -> keytype ->`` `PKey
+ <paramiko.pkey.PKey>`. The hostname may be an IP address or DNS name. The
+ keytype will be either ``"ssh-rsa"`` or ``"ssh-dss"``.
This type of file unfortunately doesn't exist on Windows, but on posix,
- it will usually be stored in C{os.path.expanduser("~/.ssh/known_hosts")}.
+ it will usually be stored in ``os.path.expanduser("~/.ssh/known_hosts")``.
- Since 1.5.3, this is just a wrapper around L{HostKeys}.
+ Since 1.5.3, this is just a wrapper around `.HostKeys`.
- @param filename: name of the file to read host keys from
- @type filename: str
- @return: dict of host keys, indexed by hostname and then keytype
- @rtype: dict(hostname, dict(keytype, L{PKey <paramiko.pkey.PKey>}))
+ :param str filename: name of the file to read host keys from
+ :return:
+ nested dict of `.PKey` objects, indexed by hostname and then keytype
"""
from paramiko.hostkeys import HostKeys
return HostKeys(filename)
+
def parse_ssh_config(file_obj):
"""
- Provided only as a backward-compatible wrapper around L{SSHConfig}.
+ Provided only as a backward-compatible wrapper around `.SSHConfig`.
"""
config = SSHConfig()
config.parse(file_obj)
return config
+
def lookup_ssh_host_config(hostname, config):
"""
- Provided only as a backward-compatible wrapper around L{SSHConfig}.
+ Provided only as a backward-compatible wrapper around `.SSHConfig`.
"""
return config.lookup(hostname)
+
def mod_inverse(x, m):
- # it's crazy how small python can make this function.
+ # it's crazy how small Python can make this function.
u1, u2, u3 = 1, 0, m
v1, v2, v3 = 0, 1, x
@@ -233,6 +227,8 @@ def mod_inverse(x, m):
_g_thread_ids = {}
_g_thread_counter = 0
_g_thread_lock = threading.Lock()
+
+
def get_thread_id():
global _g_thread_ids, _g_thread_counter, _g_thread_lock
tid = id(threading.currentThread())
@@ -247,8 +243,9 @@ def get_thread_id():
_g_thread_lock.release()
return ret
+
def log_to_file(filename, level=DEBUG):
- "send paramiko logs to a logfile, if they're not already going somewhere"
+ """send paramiko logs to a logfile, if they're not already going somewhere"""
l = logging.getLogger("paramiko")
if len(l.handlers) > 0:
return
@@ -259,6 +256,7 @@ def log_to_file(filename, level=DEBUG):
'%Y%m%d-%H:%M:%S'))
l.addHandler(lh)
+
# make only one filter object, so it doesn't get applied more than once
class PFilter (object):
def filter(self, record):
@@ -266,46 +264,59 @@ class PFilter (object):
return True
_pfilter = PFilter()
+
def get_logger(name):
l = logging.getLogger(name)
l.addFilter(_pfilter)
return l
+
def retry_on_signal(function):
"""Retries function until it doesn't raise an EINTR error"""
while True:
try:
return function()
- except EnvironmentError, e:
+ except EnvironmentError as e:
if e.errno != errno.EINTR:
raise
+
class Counter (object):
"""Stateful counter for CTR mode crypto"""
- def __init__(self, nbits, initial_value=1L, overflow=0L):
+ def __init__(self, nbits, initial_value=long(1), overflow=long(0)):
self.blocksize = nbits / 8
self.overflow = overflow
# start with value - 1 so we don't have to store intermediate values when counting
# could the iv be 0?
if initial_value == 0:
- self.value = array.array('c', '\xFF' * self.blocksize)
+ self.value = array.array('c', max_byte * self.blocksize)
else:
x = deflate_long(initial_value - 1, add_sign_padding=False)
- self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x)
+ self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x)
def __call__(self):
"""Increament the counter and return the new value"""
i = self.blocksize - 1
while i > -1:
- c = self.value[i] = chr((ord(self.value[i]) + 1) % 256)
- if c != '\x00':
+ c = self.value[i] = byte_chr((byte_ord(self.value[i]) + 1) % 256)
+ if c != zero_byte:
return self.value.tostring()
i -= 1
# counter reset
x = deflate_long(self.overflow, add_sign_padding=False)
- self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x)
+ self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x)
return self.value.tostring()
- def new(cls, nbits, initial_value=1L, overflow=0L):
+ def new(cls, nbits, initial_value=long(1), overflow=long(0)):
return cls(nbits, initial_value=initial_value, overflow=overflow)
new = classmethod(new)
+
+
+def constant_time_bytes_eq(a, b):
+ if len(a) != len(b):
+ return False
+ res = 0
+ # noinspection PyUnresolvedReferences
+ for i in (xrange if PY2 else range)(len(a)):
+ res |= byte_ord(a[i]) ^ byte_ord(b[i])
+ return res == 0