diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2011-09-05 21:40:10 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2011-09-05 21:40:10 +0200 |
commit | 89631460a87d84790bc88766433e42f881937bf7 (patch) | |
tree | 00d00c1987a63ee1c74a8f5864c954d7cad5ff65 | |
parent | 0d97937d994d4ec11f77661985be971a61daa6e3 (diff) | |
download | factory-boy-89631460a87d84790bc88766433e42f881937bf7.tar factory-boy-89631460a87d84790bc88766433e42f881937bf7.tar.gz |
Handle nested SubFactory with extra attributes.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r-- | factory/containers.py | 22 | ||||
-rw-r--r-- | factory/test_base.py | 26 |
2 files changed, 43 insertions, 5 deletions
diff --git a/factory/containers.py b/factory/containers.py index b46a19f..d3fbc5b 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -134,6 +134,21 @@ class DeclarationsHolder(object): def items(self): return list(self.iteritems()) + def _extract_sub_fields(self, base): + """Extract all subfields declaration from a given dict-like object. + + Will compare with attributes declared in the current object, and + will pop() values from the given base. + """ + sub_fields = dict() + + for key in list(base): + if ATTR_SPLITTER in key: + cls_name, attr_name = key.split(ATTR_SPLITTER, 1) + if cls_name in self: + sub_fields.setdefault(cls_name, {})[attr_name] = base.pop(key) + return sub_fields + def build_attributes(self, factory, create=False, extra=None): """Build the list of attributes based on class attributes.""" if not extra: @@ -143,11 +158,8 @@ class DeclarationsHolder(object): attributes = {} sub_fields = {} - for key in list(extra.keys()): - if ATTR_SPLITTER in key: - cls_name, attr_name = key.split(ATTR_SPLITTER, 1) - if cls_name in self: - sub_fields.setdefault(cls_name, {})[attr_name] = extra.pop(key) + for base in (self._unordered, self._ordered, extra): + sub_fields.update(self._extract_sub_fields(base)) # For fields in _unordered, use the value from extra if any; otherwise, # use the default value. diff --git a/factory/test_base.py b/factory/test_base.py index 8772f8b..e28b5eb 100644 --- a/factory/test_base.py +++ b/factory/test_base.py @@ -326,6 +326,32 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase): self.assertEqual('x0x', test_model.two.one) self.assertEqual('x0xx0x', test_model.two.two) + def testNestedSubFactory(self): + """Test nested sub-factories.""" + + class TestObject(object): + def __init__(self, **kwargs): + for k, v in kwargs.iteritems(): + setattr(self, k, v) + + class TestObjectFactory(Factory): + FACTORY_FOR = TestObject + + class WrappingTestObjectFactory(Factory): + FACTORY_FOR = TestObject + + wrapped = declarations.SubFactory(TestObjectFactory) + wrapped_bis = declarations.SubFactory(TestObjectFactory, one=1) + + class OuterWrappingTestObjectFactory(Factory): + FACTORY_FOR = TestObject + + wrap = declarations.SubFactory(WrappingTestObjectFactory, wrapped__two=2) + + outer = OuterWrappingTestObjectFactory.build() + self.assertEqual(outer.wrap.wrapped.two, 2) + self.assertEqual(outer.wrap.wrapped_bis.one, 1) + def testStubStrategy(self): Factory.default_strategy = STUB_STRATEGY |