summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/utils.py23
-rw-r--r--tests/test_utils.py30
2 files changed, 46 insertions, 7 deletions
diff --git a/factory/utils.py b/factory/utils.py
index 6c6fd7d..2fcd7ff 100644
--- a/factory/utils.py
+++ b/factory/utils.py
@@ -25,7 +25,7 @@
#: (subfactory_name, subfactory_field) tuple.
ATTR_SPLITTER = '__'
-def extract_dict(prefix, kwargs, pop=True):
+def extract_dict(prefix, kwargs, pop=True, exclude=()):
"""Extracts all values beginning with a given prefix from a dict.
Can either 'pop' or 'get' them;
@@ -34,6 +34,7 @@ def extract_dict(prefix, kwargs, pop=True):
prefix (str): the prefix to use for lookups
kwargs (dict): the dict from which values should be extracted
pop (bool): whether to use pop (True) or get (False)
+ exclude (iterable): list of prefixed keys that shouldn't be extracted
Returns:
A new dict, containing values from kwargs and beginning with
@@ -43,6 +44,9 @@ def extract_dict(prefix, kwargs, pop=True):
prefix = prefix + ATTR_SPLITTER
extracted = {}
for key in kwargs.keys():
+ if key in exclude:
+ continue
+
if key.startswith(prefix):
new_key = key[len(prefix):]
if pop:
@@ -63,10 +67,21 @@ def declength_compare(a, b):
return cmp(a, b)
-def multi_extract_dict(prefixes, kwargs, pop=True):
- """Extracts all values from a given list of prefixes."""
+def multi_extract_dict(prefixes, kwargs, pop=True, exclude=()):
+ """Extracts all values from a given list of prefixes.
+
+ Arguments have the same meaning as for extract_dict.
+
+ Returns:
+ dict(str => dict): a dict mapping each prefix to the dict of extracted
+ key/value.
+ """
results = {}
+ exclude = list(exclude)
for prefix in sorted(prefixes, cmp=declength_compare):
- extracted = extract_dict(prefix, kwargs, pop=pop)
+ extracted = extract_dict(prefix, kwargs, pop=pop, exclude=exclude)
results[prefix] = extracted
+ exclude.extend(
+ ['%s%s%s' % (prefix, ATTR_SPLITTER, key) for key in extracted])
+
return results
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c3047d3..2c77c15 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -68,7 +68,17 @@ class ExtractDictTestCase(unittest.TestCase):
self.assertEqual({'baz': 42}, utils.extract_dict('foo', d, pop=True))
self.assertNotIn('foo__baz', d)
- def test_many_key(self):
+ def test_one_key_excluded(self):
+ d = {'foo': 13, 'foo__baz': 42, '__foo': 1}
+ self.assertEqual({},
+ utils.extract_dict('foo', d, pop=False, exclude=('foo__baz',)))
+ self.assertEqual(42, d['foo__baz'])
+
+ self.assertEqual({},
+ utils.extract_dict('foo', d, pop=True, exclude=('foo__baz',)))
+ self.assertIn('foo__baz', d)
+
+ def test_many_keys(self):
d = {'foo': 13, 'foo__baz': 42, 'foo__foo__bar': 2, 'foo__bar': 3, '__foo': 1}
self.assertEqual({'foo__bar': 2, 'bar': 3, 'baz': 42},
utils.extract_dict('foo', d, pop=False))
@@ -82,6 +92,20 @@ class ExtractDictTestCase(unittest.TestCase):
self.assertNotIn('foo__bar', d)
self.assertNotIn('foo__foo__bar', d)
+ def test_many_keys_excluded(self):
+ d = {'foo': 13, 'foo__baz': 42, 'foo__foo__bar': 2, 'foo__bar': 3, '__foo': 1}
+ self.assertEqual({'foo__bar': 2, 'baz': 42},
+ utils.extract_dict('foo', d, pop=False, exclude=('foo__bar', 'bar')))
+ self.assertEqual(42, d['foo__baz'])
+ self.assertEqual(3, d['foo__bar'])
+ self.assertEqual(2, d['foo__foo__bar'])
+
+ self.assertEqual({'foo__bar': 2, 'baz': 42},
+ utils.extract_dict('foo', d, pop=True, exclude=('foo__bar', 'bar')))
+ self.assertNotIn('foo__baz', d)
+ self.assertIn('foo__bar', d)
+ self.assertNotIn('foo__foo__bar', d)
+
class MultiExtractDictTestCase(unittest.TestCase):
def test_empty_dict(self):
self.assertEqual({'foo': {}}, utils.multi_extract_dict(['foo'], {}))
@@ -182,9 +206,9 @@ class MultiExtractDictTestCase(unittest.TestCase):
self.assertEqual(
{
'foo__foo': {'bar': 2},
- 'foo': {'bar': 3, 'baz': 42, 'foo__bar': 2},
+ 'foo': {'bar': 3, 'baz': 42},
'bar__bar': {'baz': 4},
- 'bar': {'foo': 1, 'bar__baz': 4},
+ 'bar': {'foo': 1},
'baz': {}
},
utils.multi_extract_dict(