diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-09-24 20:52:26 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-09-24 20:52:26 +0200 |
commit | 8e37debdc17c5649b8d6b2cf035fb58c4ad3c077 (patch) | |
tree | 439b52a21ad1f11e8b4abc37d20ecb190b2ef3e2 | |
parent | 7fe9dcaa8494e73d57613d1288b4f86c4cba5bf0 (diff) | |
download | factory-boy-8e37debdc17c5649b8d6b2cf035fb58c4ad3c077.tar factory-boy-8e37debdc17c5649b8d6b2cf035fb58c4ad3c077.tar.gz |
Lint
-rw-r--r-- | factory/containers.py | 11 | ||||
-rw-r--r-- | tests/test_using.py | 32 |
2 files changed, 35 insertions, 8 deletions
diff --git a/factory/containers.py b/factory/containers.py index 4975036..7a4c5db 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -240,12 +240,13 @@ class AttributeBuilder(object): self._attrs = factory.declarations(extra) self._log_ctx = log_ctx + initial_declarations = factory.declarations({}) attrs_with_subfields = [ - k for k, v in self._attrs.items() + k for k, v in initial_declarations.items() if self.has_subfields(v)] self._subfields = utils.multi_extract_dict( - attrs_with_subfields, self._attrs) + attrs_with_subfields, self._attrs) def has_subfields(self, value): return isinstance(value, declarations.ParameteredAttribute) @@ -270,9 +271,9 @@ class AttributeBuilder(object): for k, v in self._attrs.items(): if isinstance(v, declarations.OrderedDeclaration): v = OrderedDeclarationWrapper(v, - sequence=sequence, - create=create, - extra=self._subfields.get(k, {}), + sequence=sequence, + create=create, + extra=self._subfields.get(k, {}), ) wrapped_attrs[k] = v diff --git a/tests/test_using.py b/tests/test_using.py index 01e950f..3979cd0 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -1018,9 +1018,9 @@ class SubFactoryTestCase(unittest.TestCase): class TestModel2Factory(FakeModelFactory): FACTORY_FOR = TestModel2 two = factory.SubFactory(TestModelFactory, - one=factory.Sequence(lambda n: 'x%dx' % n), - two=factory.LazyAttribute( - lambda o: '%s%s' % (o.one, o.one))) + one=factory.Sequence(lambda n: 'x%dx' % n), + two=factory.LazyAttribute(lambda o: '%s%s' % (o.one, o.one)), + ) test_model = TestModel2Factory(one=42) self.assertEqual('x0x', test_model.two.one) @@ -1128,6 +1128,32 @@ class SubFactoryTestCase(unittest.TestCase): self.assertEqual(outer.wrap.wrapped.two.four, 4) self.assertEqual(outer.wrap.friend, 5) + def test_nested_subfactory_with_override(self): + """Tests replacing a SubFactory field with an actual value.""" + + # The test class + class TestObject(object): + def __init__(self, two='one', wrapped=None): + self.two = two + self.wrapped = wrapped + + # Innermost factory + class TestObjectFactory(factory.Factory): + FACTORY_FOR = TestObject + two = 'two' + + # Intermediary factory + class WrappingTestObjectFactory(factory.Factory): + FACTORY_FOR = TestObject + + wrapped = factory.SubFactory(TestObjectFactory) + wrapped__two = 'three' + + obj = TestObject(two='four') + outer = WrappingTestObjectFactory(wrapped=obj) + self.assertEqual(obj, outer.wrapped) + self.assertEqual('four', outer.wrapped.two) + def test_sub_factory_and_inheritance(self): """Test inheriting from a factory with subfactories, overriding.""" class TestObject(object): |