diff options
-rw-r--r-- | factory/__init__.py | 1 | ||||
-rw-r--r-- | factory/containers.py | 2 | ||||
-rw-r--r-- | factory/declarations.py | 4 | ||||
-rw-r--r-- | tests/test_using.py | 100 |
4 files changed, 104 insertions, 3 deletions
diff --git a/factory/__init__.py b/factory/__init__.py index 950a64d..d2267f0 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -62,6 +62,7 @@ from declarations import ( SubFactory, CircularSubFactory, PostGeneration, + PostGenerationMethodCall, RelatedFactory, lazy_attribute, diff --git a/factory/containers.py b/factory/containers.py index 946fbd3..6834f60 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -167,7 +167,7 @@ class PostGenerationDeclarationDict(DeclarationDict): class LazyValue(object): """Some kind of "lazy evaluating" object.""" - def evaluate(self, obj, containers=()): + def evaluate(self, obj, containers=()): # pragma: no cover """Compute the value, using the given object.""" raise NotImplementedError("This is an abstract method.") diff --git a/factory/declarations.py b/factory/declarations.py index 5e45255..77000f2 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -250,7 +250,7 @@ class ParameteredAttribute(OrderedDeclaration): return self.generate(create, defaults) - def generate(self, create, params): + def generate(self, create, params): # pragma: no cover """Actually generate the related attribute. Args: @@ -352,7 +352,7 @@ class PostGenerationDeclaration(object): kwargs = utils.extract_dict(extract_prefix, attrs) return extracted, kwargs - def call(self, obj, create, extracted=None, **kwargs): + def call(self, obj, create, extracted=None, **kwargs): # pragma: no cover """Call this hook; no return value is expected. Args: diff --git a/tests/test_using.py b/tests/test_using.py index 7b8b3d6..8620127 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -892,6 +892,78 @@ class SubFactoryTestCase(unittest.TestCase): self.assertEqual(outer.side_a.inner_from_a.a, outer.foo * 2) self.assertEqual(outer.side_a.inner_from_a.b, 4) + def test_nonstrict_container_attribute(self): + class TestModel2(FakeModel): + pass + + class TestModelFactory(FakeModelFactory): + FACTORY_FOR = TestModel + one = 3 + two = factory.ContainerAttribute(lambda obj, containers: len(containers or []), strict=False) + + class TestModel2Factory(FakeModelFactory): + FACTORY_FOR = TestModel2 + one = 1 + two = factory.SubFactory(TestModelFactory, one=1) + + obj = TestModel2Factory.build() + self.assertEqual(1, obj.one) + self.assertEqual(1, obj.two.one) + self.assertEqual(1, obj.two.two) + + obj = TestModelFactory() + self.assertEqual(3, obj.one) + self.assertEqual(0, obj.two) + + def test_strict_container_attribute(self): + class TestModel2(FakeModel): + pass + + class TestModelFactory(FakeModelFactory): + FACTORY_FOR = TestModel + one = 3 + two = factory.ContainerAttribute(lambda obj, containers: len(containers or []), strict=True) + + class TestModel2Factory(FakeModelFactory): + FACTORY_FOR = TestModel2 + one = 1 + two = factory.SubFactory(TestModelFactory, one=1) + + obj = TestModel2Factory.build() + self.assertEqual(1, obj.one) + self.assertEqual(1, obj.two.one) + self.assertEqual(1, obj.two.two) + + self.assertRaises(TypeError, TestModelFactory.build) + + def test_function_container_attribute(self): + class TestModel2(FakeModel): + pass + + class TestModelFactory(FakeModelFactory): + FACTORY_FOR = TestModel + one = 3 + + @factory.container_attribute + def two(self, containers): + if containers: + return len(containers) + return 42 + + class TestModel2Factory(FakeModelFactory): + FACTORY_FOR = TestModel2 + one = 1 + two = factory.SubFactory(TestModelFactory, one=1) + + obj = TestModel2Factory.build() + self.assertEqual(1, obj.one) + self.assertEqual(1, obj.two.one) + self.assertEqual(1, obj.two.two) + + obj = TestModelFactory() + self.assertEqual(3, obj.one) + self.assertEqual(42, obj.two) + class IteratorTestCase(unittest.TestCase): @@ -1021,6 +1093,34 @@ class PostGenerationTestCase(unittest.TestCase): obj = TestObjectFactory.build(bar=42, bar__foo=13) + def test_post_generation_method_call(self): + calls = [] + + class TestObject(object): + def __init__(self, one=None, two=None): + self.one = one + self.two = two + self.extra = None + + def call(self, *args, **kwargs): + self.extra = (args, kwargs) + + class TestObjectFactory(factory.Factory): + FACTORY_FOR = TestObject + one = 3 + two = 2 + post_call = factory.PostGenerationMethodCall('call', one=1) + + obj = TestObjectFactory.build() + self.assertEqual(3, obj.one) + self.assertEqual(2, obj.two) + self.assertEqual(((), {'one': 1}), obj.extra) + + obj = TestObjectFactory.build(post_call__one=2, post_call__two=3) + self.assertEqual(3, obj.one) + self.assertEqual(2, obj.two) + self.assertEqual(((), {'one': 2, 'two': 3}), obj.extra) + def test_related_factory(self): class TestRelatedObject(object): def __init__(self, obj=None, one=None, two=None): |