summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2012-05-04 01:17:59 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2012-05-04 01:17:59 +0200
commitd96c651f51b25988235ff79b50c7f9355fb16dd7 (patch)
tree02972dccdf0301df2e8781d44de47fd7d3b99598
parentfbd66ede5617a40f73dfb3f518c9887d48ab401e (diff)
downloadfactory-boy-d96c651f51b25988235ff79b50c7f9355fb16dd7.tar
factory-boy-d96c651f51b25988235ff79b50c7f9355fb16dd7.tar.gz
Only absorb dependant arguments for SubFactory fields (Closes #15).
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r--factory/containers.py8
-rw-r--r--tests/test_base.py18
-rw-r--r--tests/test_containers.py18
-rw-r--r--tests/test_declarations.py20
-rw-r--r--tests/test_using.py13
5 files changed, 75 insertions, 2 deletions
diff --git a/factory/containers.py b/factory/containers.py
index 9f480cc..946fbd3 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -242,7 +242,13 @@ class AttributeBuilder(object):
self.factory = factory
self._containers = extra.pop('__containers', None)
self._attrs = factory.declarations(extra)
- self._subfields = utils.multi_extract_dict(self._attrs.keys(), self._attrs)
+
+ attrs_with_subfields = [k for k, v in self._attrs.items() if self.has_subfields(v)]
+
+ self._subfields = utils.multi_extract_dict(attrs_with_subfields, self._attrs)
+
+ def has_subfields(self, value):
+ return isinstance(value, declarations.SubFactory)
def build(self, create):
"""Build a dictionary of attributes.
diff --git a/tests/test_base.py b/tests/test_base.py
index e0a6547..7575ee2 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -229,5 +229,23 @@ class FactoryCreationTestCase(unittest.TestCase):
self.assertTrue('autodiscovery' not in str(e))
+class PostGenerationParsingTestCase(unittest.TestCase):
+
+ def test_extraction(self):
+ class TestObjectFactory(base.Factory):
+ foo = declarations.PostGenerationDeclaration()
+
+ self.assertIn('foo', TestObjectFactory._postgen_declarations)
+
+ def test_classlevel_extraction(self):
+ class TestObjectFactory(base.Factory):
+ foo = declarations.PostGenerationDeclaration()
+ foo__bar = 42
+
+ self.assertIn('foo', TestObjectFactory._postgen_declarations)
+ self.assertIn('foo__bar', TestObjectFactory._declarations)
+
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_containers.py b/tests/test_containers.py
index 55fe576..797c480 100644
--- a/tests/test_containers.py
+++ b/tests/test_containers.py
@@ -316,6 +316,24 @@ class AttributeBuilderTestCase(unittest.TestCase):
ab = containers.AttributeBuilder(FakeFactory, {'one': 4, 'three': la})
self.assertEqual({'one': 4, 'two': 8, 'three': 8}, ab.build(create=False))
+ def test_subfields(self):
+ class FakeInnerFactory(object):
+ pass
+
+ sf = declarations.SubFactory(FakeInnerFactory)
+
+ class FakeFactory(object):
+ @classmethod
+ def declarations(cls, extra):
+ d = {'one': sf, 'two': 2}
+ d.update(extra)
+ return d
+
+ ab = containers.AttributeBuilder(FakeFactory, {'one__blah': 1, 'two__bar': 2})
+ self.assertTrue(ab.has_subfields(sf))
+ self.assertEqual(['one'], ab._subfields.keys())
+ self.assertEqual(2, ab._attrs['two__bar'])
+
def test_sub_factory(self):
pass
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index f15645f..1c0502b 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -21,7 +21,8 @@
# THE SOFTWARE.
-from factory.declarations import deepgetattr, OrderedDeclaration, Sequence
+from factory.declarations import deepgetattr, OrderedDeclaration, \
+ PostGenerationDeclaration, Sequence
from .compat import unittest
@@ -55,6 +56,23 @@ class DigTestCase(unittest.TestCase):
self.assertEqual(42, deepgetattr(obj, 'a.b.c.n.x', 42))
+class PostGenerationDeclarationTestCase(unittest.TestCase):
+ def test_extract_no_prefix(self):
+ decl = PostGenerationDeclaration()
+
+ extracted, kwargs = decl.extract('foo', {'foo': 13, 'foo__bar': 42})
+ self.assertEqual(extracted, 13)
+ self.assertEqual(kwargs, {'bar': 42})
+
+ def test_extract_with_prefix(self):
+ decl = PostGenerationDeclaration(extract_prefix='blah')
+
+ extracted, kwargs = decl.extract('foo',
+ {'foo': 13, 'foo__bar': 42, 'blah': 42, 'blah__baz': 1})
+ self.assertEqual(extracted, 42)
+ self.assertEqual(kwargs, {'baz': 1})
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_using.py b/tests/test_using.py
index 20acd79..3bb0959 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -868,6 +868,19 @@ class PostGenerationDeclarationTestCase(unittest.TestCase):
self.assertEqual(4, obj.one)
self.assertFalse(hasattr(obj, 'incr_one'))
+ def test_post_generation_extraction_lambda(self):
+
+ def my_lambda(obj, create, extracted, **kwargs):
+ self.assertTrue(isinstance(obj, TestObject))
+ self.assertFalse(create)
+ self.assertEqual(extracted, 42)
+ self.assertEqual(kwargs, {'foo': 13})
+
+ class TestObjectFactory(factory.Factory):
+ bar = factory.PostGeneration(my_lambda)
+
+ obj = TestObjectFactory.build(bar=42, bar__foo=13)
+
def test_related_factory(self):
class TestRelatedObject(object):
def __init__(self, obj=None, one=None, two=None):