summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2011-09-05 22:17:28 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2011-09-05 22:17:28 +0200
commit2821684dec8a71e93728fc3e036f83592324b518 (patch)
tree0b78ff41c7179687f940fb4b57e7d1fac10150ba
parent89631460a87d84790bc88766433e42f881937bf7 (diff)
downloadfactory-boy-2821684dec8a71e93728fc3e036f83592324b518.tar
factory-boy-2821684dec8a71e93728fc3e036f83592324b518.tar.gz
Better fix for nested subfactories.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r--factory/containers.py26
-rw-r--r--factory/test_base.py27
2 files changed, 42 insertions, 11 deletions
diff --git a/factory/containers.py b/factory/containers.py
index d3fbc5b..48a4015 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -161,21 +161,25 @@ class DeclarationsHolder(object):
for base in (self._unordered, self._ordered, extra):
sub_fields.update(self._extract_sub_fields(base))
+ def make_value(key, val):
+ if key in extra:
+ val = extra.pop(key)
+ if isinstance(val, SubFactory):
+ new_val = val.evaluate(factory, create, sub_fields.get(key, {}))
+ elif isinstance(val, OrderedDeclaration):
+ wrapper = ObjectParamsWrapper(attributes)
+ new_val = val.evaluate(factory, wrapper)
+ else:
+ new_val = val
+
+ return new_val
+
# For fields in _unordered, use the value from extra if any; otherwise,
# use the default value.
for key, value in self._unordered.iteritems():
- attributes[key] = extra.get(key, value)
+ attributes[key] = make_value(key, value)
for key, value in self._ordered.iteritems():
- if key in extra:
- attributes[key] = extra[key]
- else:
- if isinstance(value, SubFactory):
- new_value = value.evaluate(factory, create,
- sub_fields.get(key, {}))
- else:
- wrapper = ObjectParamsWrapper(attributes)
- new_value = value.evaluate(factory, wrapper)
- attributes[key] = new_value
+ attributes[key] = make_value(key, value)
attributes.update(extra)
return attributes
diff --git a/factory/test_base.py b/factory/test_base.py
index e28b5eb..753f116 100644
--- a/factory/test_base.py
+++ b/factory/test_base.py
@@ -352,6 +352,33 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
self.assertEqual(outer.wrap.wrapped.two, 2)
self.assertEqual(outer.wrap.wrapped_bis.one, 1)
+ def testNestedSubFactoryWithOverriddenSubFactories(self):
+ """Test nested sub-factories, with attributes overridden with subfactories."""
+
+ class TestObject(object):
+ def __init__(self, **kwargs):
+ for k, v in kwargs.iteritems():
+ setattr(self, k, v)
+
+ class TestObjectFactory(Factory):
+ FACTORY_FOR = TestObject
+ two = 'two'
+
+ class WrappingTestObjectFactory(Factory):
+ FACTORY_FOR = TestObject
+
+ wrapped = declarations.SubFactory(TestObjectFactory)
+
+ class OuterWrappingTestObjectFactory(Factory):
+ FACTORY_FOR = TestObject
+
+ wrap = declarations.SubFactory(WrappingTestObjectFactory,
+ wrapped__two=declarations.SubFactory(TestObjectFactory, four=4))
+
+
+ outer = OuterWrappingTestObjectFactory.build()
+ self.assertEqual(outer.wrap.wrapped.two.four, 4)
+
def testStubStrategy(self):
Factory.default_strategy = STUB_STRATEGY