aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/containers.py19
-rw-r--r--factory/utils.py72
-rw-r--r--tests/__init__.py1
-rw-r--r--tests/test_utils.py209
4 files changed, 284 insertions, 17 deletions
diff --git a/factory/containers.py b/factory/containers.py
index 2f92f62..b8557d6 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -22,11 +22,7 @@
from factory import declarations
-
-
-#: String for splitting an attribute name into a
-#: (subfactory_name, subfactory_field) tuple.
-ATTR_SPLITTER = '__'
+from factory import utils
class CyclicDefinitionError(Exception):
@@ -238,18 +234,7 @@ class AttributeBuilder(object):
self.factory = factory
self._containers = extra.pop('__containers', None)
self._attrs = factory.declarations(extra)
- self._subfields = self._extract_subfields()
-
- def _extract_subfields(self):
- """Extract the subfields from the declarations list."""
- sub_fields = {}
- for key in list(self._attrs):
- if ATTR_SPLITTER in key:
- # Trying to define a default of a subfactory
- cls_name, attr_name = key.split(ATTR_SPLITTER, 1)
- if cls_name in self._attrs:
- sub_fields.setdefault(cls_name, {})[attr_name] = self._attrs.pop(key)
- return sub_fields
+ self._subfields = utils.multi_extract_dict(self._attrs.keys(), self._attrs)
def build(self, create):
"""Build a dictionary of attributes.
diff --git a/factory/utils.py b/factory/utils.py
new file mode 100644
index 0000000..6c6fd7d
--- /dev/null
+++ b/factory/utils.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2010 Mark Sandstrom
+# Copyright (c) 2011 Raphaël Barrois
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+
+#: String for splitting an attribute name into a
+#: (subfactory_name, subfactory_field) tuple.
+ATTR_SPLITTER = '__'
+
+def extract_dict(prefix, kwargs, pop=True):
+ """Extracts all values beginning with a given prefix from a dict.
+
+ Can either 'pop' or 'get' them;
+
+ Args:
+ prefix (str): the prefix to use for lookups
+ kwargs (dict): the dict from which values should be extracted
+ pop (bool): whether to use pop (True) or get (False)
+
+ Returns:
+ A new dict, containing values from kwargs and beginning with
+ prefix + ATTR_SPLITTER. That full prefix is removed from the keys
+ of the returned dict.
+ """
+ prefix = prefix + ATTR_SPLITTER
+ extracted = {}
+ for key in kwargs.keys():
+ if key.startswith(prefix):
+ new_key = key[len(prefix):]
+ if pop:
+ value = kwargs.pop(key)
+ else:
+ value = kwargs[key]
+ extracted[new_key] = value
+ return extracted
+
+
+def declength_compare(a, b):
+ """Compare objects, choosing longest first."""
+ if len(a) > len(b):
+ return -1
+ elif len(a) < len(b):
+ return 1
+ else:
+ return cmp(a, b)
+
+
+def multi_extract_dict(prefixes, kwargs, pop=True):
+ """Extracts all values from a given list of prefixes."""
+ results = {}
+ for prefix in sorted(prefixes, cmp=declength_compare):
+ extracted = extract_dict(prefix, kwargs, pop=pop)
+ results[prefix] = extracted
+ return results
diff --git a/tests/__init__.py b/tests/__init__.py
index 7ab3567..80a96a4 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -5,3 +5,4 @@ from .test_base import *
from .test_containers import *
from .test_declarations import *
from .test_using import *
+from .test_utils import *
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..543a6c0
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,209 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2010 Mark Sandstrom
+# Copyright (c) 2011 Raphaël Barrois
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+
+import unittest
+
+from factory import utils
+
+
+class DecLengthCompareTestCase(unittest.TestCase):
+ def test_reciprocity(self):
+ self.assertEqual(1, utils.declength_compare('a', 'bb'))
+ self.assertEqual(-1, utils.declength_compare('aa', 'b'))
+
+ def test_not_lexical(self):
+ self.assertEqual(1, utils.declength_compare('abc', 'aaaa'))
+ self.assertEqual(-1, utils.declength_compare('aaaa', 'abc'))
+
+ def test_same_length(self):
+ self.assertEqual(-1, utils.declength_compare('abc', 'abd'))
+ self.assertEqual(1, utils.declength_compare('abe', 'abd'))
+
+ def test_equality(self):
+ self.assertEqual(0, utils.declength_compare('abc', 'abc'))
+ self.assertEqual(0, utils.declength_compare([1, 2, 3], [1, 2, 3]))
+
+
+class ExtractDictTestCase(unittest.TestCase):
+ def test_empty_dict(self):
+ self.assertEqual({}, utils.extract_dict('foo', {}))
+
+ def test_unused_key(self):
+ self.assertEqual({}, utils.extract_dict('foo', {'bar__baz': 42}))
+
+ def test_empty_key(self):
+ self.assertEqual({}, utils.extract_dict('', {'foo': 13, 'bar__baz': 42}))
+ d = {'foo': 13, 'bar__baz': 42, '__foo': 1}
+ self.assertEqual({'foo': 1}, utils.extract_dict('', d))
+ self.assertNotIn('__foo', d)
+
+ def test_one_key(self):
+ d = {'foo': 13, 'foo__baz': 42, '__foo': 1}
+ self.assertEqual({'baz': 42}, utils.extract_dict('foo', d, pop=False))
+ self.assertEqual(42, d['foo__baz'])
+
+ self.assertEqual({'baz': 42}, utils.extract_dict('foo', d, pop=True))
+ self.assertNotIn('foo__baz', d)
+
+ def test_many_key(self):
+ d = {'foo': 13, 'foo__baz': 42, 'foo__foo__bar': 2, 'foo__bar': 3, '__foo': 1}
+ self.assertEqual({'foo__bar': 2, 'bar': 3, 'baz': 42},
+ utils.extract_dict('foo', d, pop=False))
+ self.assertEqual(42, d['foo__baz'])
+ self.assertEqual(3, d['foo__bar'])
+ self.assertEqual(2, d['foo__foo__bar'])
+
+ self.assertEqual({'foo__bar': 2, 'bar': 3, 'baz': 42},
+ utils.extract_dict('foo', d, pop=True))
+ self.assertNotIn('foo__baz', d)
+ self.assertNotIn('foo__bar', d)
+ self.assertNotIn('foo__foo__bar', d)
+
+class MultiExtractDictTestCase(unittest.TestCase):
+ def test_empty_dict(self):
+ self.assertEqual({'foo': {}}, utils.multi_extract_dict(['foo'], {}))
+
+ def test_unused_key(self):
+ self.assertEqual({'foo': {}},
+ utils.multi_extract_dict(['foo'], {'bar__baz': 42}))
+ self.assertEqual({'foo': {}, 'baz': {}},
+ utils.multi_extract_dict(['foo', 'baz'], {'bar__baz': 42}))
+
+ def test_no_key(self):
+ self.assertEqual({}, utils.multi_extract_dict([], {'bar__baz': 42}))
+
+ def test_empty_key(self):
+ self.assertEqual({'': {}},
+ utils.multi_extract_dict([''], {'foo': 13, 'bar__baz': 42}))
+
+ d = {'foo': 13, 'bar__baz': 42, '__foo': 1}
+ self.assertEqual({'': {'foo': 1}},
+ utils.multi_extract_dict([''], d))
+ self.assertNotIn('__foo', d)
+
+ def test_one_extracted(self):
+ d = {'foo': 13, 'foo__baz': 42, '__foo': 1}
+ self.assertEqual({'foo': {'baz': 42}},
+ utils.multi_extract_dict(['foo'], d, pop=False))
+ self.assertEqual(42, d['foo__baz'])
+
+ self.assertEqual({'foo': {'baz': 42}},
+ utils.multi_extract_dict(['foo'], d, pop=True))
+ self.assertNotIn('foo__baz', d)
+
+ def test_many_extracted(self):
+ d = {'foo': 13, 'foo__baz': 42, 'foo__foo__bar': 2, 'foo__bar': 3, '__foo': 1}
+ self.assertEqual({'foo': {'foo__bar': 2, 'bar': 3, 'baz': 42}},
+ utils.multi_extract_dict(['foo'], d, pop=False))
+ self.assertEqual(42, d['foo__baz'])
+ self.assertEqual(3, d['foo__bar'])
+ self.assertEqual(2, d['foo__foo__bar'])
+
+ self.assertEqual({'foo': {'foo__bar': 2, 'bar': 3, 'baz': 42}},
+ utils.multi_extract_dict(['foo'], d, pop=True))
+ self.assertNotIn('foo__baz', d)
+ self.assertNotIn('foo__bar', d)
+ self.assertNotIn('foo__foo__bar', d)
+
+ def test_many_keys_one_extracted(self):
+ d = {'foo': 13, 'foo__baz': 42, '__foo': 1}
+ self.assertEqual({'foo': {'baz': 42}, 'baz': {}},
+ utils.multi_extract_dict(['foo', 'baz'], d, pop=False))
+ self.assertEqual(42, d['foo__baz'])
+
+ self.assertEqual({'foo': {'baz': 42}, 'baz': {}},
+ utils.multi_extract_dict(['foo', 'baz'], d, pop=True))
+ self.assertNotIn('foo__baz', d)
+
+ def test_many_keys_many_extracted(self):
+ d = {
+ 'foo': 13,
+ 'foo__baz': 42, 'foo__foo__bar': 2, 'foo__bar': 3,
+ 'bar__foo': 1, 'bar__bar__baz': 4,
+ }
+
+ self.assertEqual(
+ {
+ 'foo': {'foo__bar': 2, 'bar': 3, 'baz': 42},
+ 'bar': {'foo': 1, 'bar__baz': 4},
+ 'baz': {}
+ },
+ utils.multi_extract_dict(['foo', 'bar', 'baz'], d, pop=False))
+ self.assertEqual(42, d['foo__baz'])
+ self.assertEqual(3, d['foo__bar'])
+ self.assertEqual(2, d['foo__foo__bar'])
+ self.assertEqual(1, d['bar__foo'])
+ self.assertEqual(4, d['bar__bar__baz'])
+
+ self.assertEqual(
+ {
+ 'foo': {'foo__bar': 2, 'bar': 3, 'baz': 42},
+ 'bar': {'foo': 1, 'bar__baz': 4},
+ 'baz': {}
+ },
+ utils.multi_extract_dict(['foo', 'bar', 'baz'], d, pop=True))
+ self.assertNotIn('foo__baz', d)
+ self.assertNotIn('foo__bar', d)
+ self.assertNotIn('foo__foo__bar', d)
+ self.assertNotIn('bar__foo', d)
+ self.assertNotIn('bar__bar__baz', d)
+
+ def test_son_in_list(self):
+ """Make sure that prefixes are used in decreasing match length order."""
+ d = {
+ 'foo': 13,
+ 'foo__baz': 42, 'foo__foo__bar': 2, 'foo__bar': 3,
+ 'bar__foo': 1, 'bar__bar__baz': 4,
+ }
+
+ self.assertEqual(
+ {
+ 'foo__foo': {'bar': 2},
+ 'foo': {'bar': 3, 'baz': 42, 'foo__bar': 2},
+ 'bar__bar': {'baz': 4},
+ 'bar': {'foo': 1, 'bar__baz': 4},
+ 'baz': {}
+ },
+ utils.multi_extract_dict(
+ ['foo', 'bar', 'baz', 'foo__foo', 'bar__bar'], d, pop=False))
+ self.assertEqual(42, d['foo__baz'])
+ self.assertEqual(3, d['foo__bar'])
+ self.assertEqual(2, d['foo__foo__bar'])
+ self.assertEqual(1, d['bar__foo'])
+ self.assertEqual(4, d['bar__bar__baz'])
+
+ self.assertEqual(
+ {
+ 'foo__foo': {'bar': 2},
+ 'foo': {'bar': 3, 'baz': 42},
+ 'bar__bar': {'baz': 4},
+ 'bar': {'foo': 1},
+ 'baz': {}
+ },
+ utils.multi_extract_dict(
+ ['foo', 'bar', 'baz', 'foo__foo', 'bar__bar'], d, pop=True))
+ self.assertNotIn('foo__baz', d)
+ self.assertNotIn('foo__bar', d)
+ self.assertNotIn('foo__foo__bar', d)
+ self.assertNotIn('bar__foo', d)
+ self.assertNotIn('bar__bar__baz', d)