summaryrefslogtreecommitdiff
path: root/factory
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2012-04-09 01:10:42 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2012-04-09 01:10:42 +0200
commite80cbdc3224297ee57667e4000f1a671af05f520 (patch)
tree94275d5add1734ef6d03b0bb42d5af06bd758eba /factory
parent250ce5bdc8b6067a28351f5b3bb4c418d3f1e731 (diff)
downloadfactory-boy-e80cbdc3224297ee57667e4000f1a671af05f520.tar
factory-boy-e80cbdc3224297ee57667e4000f1a671af05f520.tar.gz
Move ATTR_SPLITTER logic to a dedicated module.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
Diffstat (limited to 'factory')
-rw-r--r--factory/containers.py19
-rw-r--r--factory/utils.py72
2 files changed, 74 insertions, 17 deletions
diff --git a/factory/containers.py b/factory/containers.py
index 2f92f62..b8557d6 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -22,11 +22,7 @@
from factory import declarations
-
-
-#: String for splitting an attribute name into a
-#: (subfactory_name, subfactory_field) tuple.
-ATTR_SPLITTER = '__'
+from factory import utils
class CyclicDefinitionError(Exception):
@@ -238,18 +234,7 @@ class AttributeBuilder(object):
self.factory = factory
self._containers = extra.pop('__containers', None)
self._attrs = factory.declarations(extra)
- self._subfields = self._extract_subfields()
-
- def _extract_subfields(self):
- """Extract the subfields from the declarations list."""
- sub_fields = {}
- for key in list(self._attrs):
- if ATTR_SPLITTER in key:
- # Trying to define a default of a subfactory
- cls_name, attr_name = key.split(ATTR_SPLITTER, 1)
- if cls_name in self._attrs:
- sub_fields.setdefault(cls_name, {})[attr_name] = self._attrs.pop(key)
- return sub_fields
+ self._subfields = utils.multi_extract_dict(self._attrs.keys(), self._attrs)
def build(self, create):
"""Build a dictionary of attributes.
diff --git a/factory/utils.py b/factory/utils.py
new file mode 100644
index 0000000..6c6fd7d
--- /dev/null
+++ b/factory/utils.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2010 Mark Sandstrom
+# Copyright (c) 2011 Raphaël Barrois
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+
+#: String for splitting an attribute name into a
+#: (subfactory_name, subfactory_field) tuple.
+ATTR_SPLITTER = '__'
+
+def extract_dict(prefix, kwargs, pop=True):
+ """Extracts all values beginning with a given prefix from a dict.
+
+ Can either 'pop' or 'get' them;
+
+ Args:
+ prefix (str): the prefix to use for lookups
+ kwargs (dict): the dict from which values should be extracted
+ pop (bool): whether to use pop (True) or get (False)
+
+ Returns:
+ A new dict, containing values from kwargs and beginning with
+ prefix + ATTR_SPLITTER. That full prefix is removed from the keys
+ of the returned dict.
+ """
+ prefix = prefix + ATTR_SPLITTER
+ extracted = {}
+ for key in kwargs.keys():
+ if key.startswith(prefix):
+ new_key = key[len(prefix):]
+ if pop:
+ value = kwargs.pop(key)
+ else:
+ value = kwargs[key]
+ extracted[new_key] = value
+ return extracted
+
+
+def declength_compare(a, b):
+ """Compare objects, choosing longest first."""
+ if len(a) > len(b):
+ return -1
+ elif len(a) < len(b):
+ return 1
+ else:
+ return cmp(a, b)
+
+
+def multi_extract_dict(prefixes, kwargs, pop=True):
+ """Extracts all values from a given list of prefixes."""
+ results = {}
+ for prefix in sorted(prefixes, cmp=declength_compare):
+ extracted = extract_dict(prefix, kwargs, pop=pop)
+ results[prefix] = extracted
+ return results