summaryrefslogtreecommitdiff
path: root/factory
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2011-09-05 21:40:10 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2011-09-05 21:40:10 +0200
commit89631460a87d84790bc88766433e42f881937bf7 (patch)
tree00d00c1987a63ee1c74a8f5864c954d7cad5ff65 /factory
parent0d97937d994d4ec11f77661985be971a61daa6e3 (diff)
downloadfactory-boy-89631460a87d84790bc88766433e42f881937bf7.tar
factory-boy-89631460a87d84790bc88766433e42f881937bf7.tar.gz
Handle nested SubFactory with extra attributes.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
Diffstat (limited to 'factory')
-rw-r--r--factory/containers.py22
-rw-r--r--factory/test_base.py26
2 files changed, 43 insertions, 5 deletions
diff --git a/factory/containers.py b/factory/containers.py
index b46a19f..d3fbc5b 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -134,6 +134,21 @@ class DeclarationsHolder(object):
def items(self):
return list(self.iteritems())
+ def _extract_sub_fields(self, base):
+ """Extract all subfields declaration from a given dict-like object.
+
+ Will compare with attributes declared in the current object, and
+ will pop() values from the given base.
+ """
+ sub_fields = dict()
+
+ for key in list(base):
+ if ATTR_SPLITTER in key:
+ cls_name, attr_name = key.split(ATTR_SPLITTER, 1)
+ if cls_name in self:
+ sub_fields.setdefault(cls_name, {})[attr_name] = base.pop(key)
+ return sub_fields
+
def build_attributes(self, factory, create=False, extra=None):
"""Build the list of attributes based on class attributes."""
if not extra:
@@ -143,11 +158,8 @@ class DeclarationsHolder(object):
attributes = {}
sub_fields = {}
- for key in list(extra.keys()):
- if ATTR_SPLITTER in key:
- cls_name, attr_name = key.split(ATTR_SPLITTER, 1)
- if cls_name in self:
- sub_fields.setdefault(cls_name, {})[attr_name] = extra.pop(key)
+ for base in (self._unordered, self._ordered, extra):
+ sub_fields.update(self._extract_sub_fields(base))
# For fields in _unordered, use the value from extra if any; otherwise,
# use the default value.
diff --git a/factory/test_base.py b/factory/test_base.py
index 8772f8b..e28b5eb 100644
--- a/factory/test_base.py
+++ b/factory/test_base.py
@@ -326,6 +326,32 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
self.assertEqual('x0x', test_model.two.one)
self.assertEqual('x0xx0x', test_model.two.two)
+ def testNestedSubFactory(self):
+ """Test nested sub-factories."""
+
+ class TestObject(object):
+ def __init__(self, **kwargs):
+ for k, v in kwargs.iteritems():
+ setattr(self, k, v)
+
+ class TestObjectFactory(Factory):
+ FACTORY_FOR = TestObject
+
+ class WrappingTestObjectFactory(Factory):
+ FACTORY_FOR = TestObject
+
+ wrapped = declarations.SubFactory(TestObjectFactory)
+ wrapped_bis = declarations.SubFactory(TestObjectFactory, one=1)
+
+ class OuterWrappingTestObjectFactory(Factory):
+ FACTORY_FOR = TestObject
+
+ wrap = declarations.SubFactory(WrappingTestObjectFactory, wrapped__two=2)
+
+ outer = OuterWrappingTestObjectFactory.build()
+ self.assertEqual(outer.wrap.wrapped.two, 2)
+ self.assertEqual(outer.wrap.wrapped_bis.one, 1)
+
def testStubStrategy(self):
Factory.default_strategy = STUB_STRATEGY