summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/__init__.py1
-rw-r--r--factory/containers.py2
-rw-r--r--factory/declarations.py4
-rw-r--r--tests/test_using.py100
4 files changed, 104 insertions, 3 deletions
diff --git a/factory/__init__.py b/factory/__init__.py
index 950a64d..d2267f0 100644
--- a/factory/__init__.py
+++ b/factory/__init__.py
@@ -62,6 +62,7 @@ from declarations import (
SubFactory,
CircularSubFactory,
PostGeneration,
+ PostGenerationMethodCall,
RelatedFactory,
lazy_attribute,
diff --git a/factory/containers.py b/factory/containers.py
index 946fbd3..6834f60 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -167,7 +167,7 @@ class PostGenerationDeclarationDict(DeclarationDict):
class LazyValue(object):
"""Some kind of "lazy evaluating" object."""
- def evaluate(self, obj, containers=()):
+ def evaluate(self, obj, containers=()): # pragma: no cover
"""Compute the value, using the given object."""
raise NotImplementedError("This is an abstract method.")
diff --git a/factory/declarations.py b/factory/declarations.py
index 5e45255..77000f2 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -250,7 +250,7 @@ class ParameteredAttribute(OrderedDeclaration):
return self.generate(create, defaults)
- def generate(self, create, params):
+ def generate(self, create, params): # pragma: no cover
"""Actually generate the related attribute.
Args:
@@ -352,7 +352,7 @@ class PostGenerationDeclaration(object):
kwargs = utils.extract_dict(extract_prefix, attrs)
return extracted, kwargs
- def call(self, obj, create, extracted=None, **kwargs):
+ def call(self, obj, create, extracted=None, **kwargs): # pragma: no cover
"""Call this hook; no return value is expected.
Args:
diff --git a/tests/test_using.py b/tests/test_using.py
index 7b8b3d6..8620127 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -892,6 +892,78 @@ class SubFactoryTestCase(unittest.TestCase):
self.assertEqual(outer.side_a.inner_from_a.a, outer.foo * 2)
self.assertEqual(outer.side_a.inner_from_a.b, 4)
+ def test_nonstrict_container_attribute(self):
+ class TestModel2(FakeModel):
+ pass
+
+ class TestModelFactory(FakeModelFactory):
+ FACTORY_FOR = TestModel
+ one = 3
+ two = factory.ContainerAttribute(lambda obj, containers: len(containers or []), strict=False)
+
+ class TestModel2Factory(FakeModelFactory):
+ FACTORY_FOR = TestModel2
+ one = 1
+ two = factory.SubFactory(TestModelFactory, one=1)
+
+ obj = TestModel2Factory.build()
+ self.assertEqual(1, obj.one)
+ self.assertEqual(1, obj.two.one)
+ self.assertEqual(1, obj.two.two)
+
+ obj = TestModelFactory()
+ self.assertEqual(3, obj.one)
+ self.assertEqual(0, obj.two)
+
+ def test_strict_container_attribute(self):
+ class TestModel2(FakeModel):
+ pass
+
+ class TestModelFactory(FakeModelFactory):
+ FACTORY_FOR = TestModel
+ one = 3
+ two = factory.ContainerAttribute(lambda obj, containers: len(containers or []), strict=True)
+
+ class TestModel2Factory(FakeModelFactory):
+ FACTORY_FOR = TestModel2
+ one = 1
+ two = factory.SubFactory(TestModelFactory, one=1)
+
+ obj = TestModel2Factory.build()
+ self.assertEqual(1, obj.one)
+ self.assertEqual(1, obj.two.one)
+ self.assertEqual(1, obj.two.two)
+
+ self.assertRaises(TypeError, TestModelFactory.build)
+
+ def test_function_container_attribute(self):
+ class TestModel2(FakeModel):
+ pass
+
+ class TestModelFactory(FakeModelFactory):
+ FACTORY_FOR = TestModel
+ one = 3
+
+ @factory.container_attribute
+ def two(self, containers):
+ if containers:
+ return len(containers)
+ return 42
+
+ class TestModel2Factory(FakeModelFactory):
+ FACTORY_FOR = TestModel2
+ one = 1
+ two = factory.SubFactory(TestModelFactory, one=1)
+
+ obj = TestModel2Factory.build()
+ self.assertEqual(1, obj.one)
+ self.assertEqual(1, obj.two.one)
+ self.assertEqual(1, obj.two.two)
+
+ obj = TestModelFactory()
+ self.assertEqual(3, obj.one)
+ self.assertEqual(42, obj.two)
+
class IteratorTestCase(unittest.TestCase):
@@ -1021,6 +1093,34 @@ class PostGenerationTestCase(unittest.TestCase):
obj = TestObjectFactory.build(bar=42, bar__foo=13)
+ def test_post_generation_method_call(self):
+ calls = []
+
+ class TestObject(object):
+ def __init__(self, one=None, two=None):
+ self.one = one
+ self.two = two
+ self.extra = None
+
+ def call(self, *args, **kwargs):
+ self.extra = (args, kwargs)
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = 3
+ two = 2
+ post_call = factory.PostGenerationMethodCall('call', one=1)
+
+ obj = TestObjectFactory.build()
+ self.assertEqual(3, obj.one)
+ self.assertEqual(2, obj.two)
+ self.assertEqual(((), {'one': 1}), obj.extra)
+
+ obj = TestObjectFactory.build(post_call__one=2, post_call__two=3)
+ self.assertEqual(3, obj.one)
+ self.assertEqual(2, obj.two)
+ self.assertEqual(((), {'one': 2, 'two': 3}), obj.extra)
+
def test_related_factory(self):
class TestRelatedObject(object):
def __init__(self, obj=None, one=None, two=None):