diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2011-09-05 22:17:28 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2011-09-05 22:17:28 +0200 |
commit | 2821684dec8a71e93728fc3e036f83592324b518 (patch) | |
tree | 0b78ff41c7179687f940fb4b57e7d1fac10150ba | |
parent | 89631460a87d84790bc88766433e42f881937bf7 (diff) | |
download | factory-boy-2821684dec8a71e93728fc3e036f83592324b518.tar factory-boy-2821684dec8a71e93728fc3e036f83592324b518.tar.gz |
Better fix for nested subfactories.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r-- | factory/containers.py | 26 | ||||
-rw-r--r-- | factory/test_base.py | 27 |
2 files changed, 42 insertions, 11 deletions
diff --git a/factory/containers.py b/factory/containers.py index d3fbc5b..48a4015 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -161,21 +161,25 @@ class DeclarationsHolder(object): for base in (self._unordered, self._ordered, extra): sub_fields.update(self._extract_sub_fields(base)) + def make_value(key, val): + if key in extra: + val = extra.pop(key) + if isinstance(val, SubFactory): + new_val = val.evaluate(factory, create, sub_fields.get(key, {})) + elif isinstance(val, OrderedDeclaration): + wrapper = ObjectParamsWrapper(attributes) + new_val = val.evaluate(factory, wrapper) + else: + new_val = val + + return new_val + # For fields in _unordered, use the value from extra if any; otherwise, # use the default value. for key, value in self._unordered.iteritems(): - attributes[key] = extra.get(key, value) + attributes[key] = make_value(key, value) for key, value in self._ordered.iteritems(): - if key in extra: - attributes[key] = extra[key] - else: - if isinstance(value, SubFactory): - new_value = value.evaluate(factory, create, - sub_fields.get(key, {})) - else: - wrapper = ObjectParamsWrapper(attributes) - new_value = value.evaluate(factory, wrapper) - attributes[key] = new_value + attributes[key] = make_value(key, value) attributes.update(extra) return attributes diff --git a/factory/test_base.py b/factory/test_base.py index e28b5eb..753f116 100644 --- a/factory/test_base.py +++ b/factory/test_base.py @@ -352,6 +352,33 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase): self.assertEqual(outer.wrap.wrapped.two, 2) self.assertEqual(outer.wrap.wrapped_bis.one, 1) + def testNestedSubFactoryWithOverriddenSubFactories(self): + """Test nested sub-factories, with attributes overridden with subfactories.""" + + class TestObject(object): + def __init__(self, **kwargs): + for k, v in kwargs.iteritems(): + setattr(self, k, v) + + class TestObjectFactory(Factory): + FACTORY_FOR = TestObject + two = 'two' + + class WrappingTestObjectFactory(Factory): + FACTORY_FOR = TestObject + + wrapped = declarations.SubFactory(TestObjectFactory) + + class OuterWrappingTestObjectFactory(Factory): + FACTORY_FOR = TestObject + + wrap = declarations.SubFactory(WrappingTestObjectFactory, + wrapped__two=declarations.SubFactory(TestObjectFactory, four=4)) + + + outer = OuterWrappingTestObjectFactory.build() + self.assertEqual(outer.wrap.wrapped.two.four, 4) + def testStubStrategy(self): Factory.default_strategy = STUB_STRATEGY |