diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-03-04 23:18:02 +0100 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-03-04 23:18:02 +0100 |
commit | 7d792430e103984a91c102c33da79be2426bc632 (patch) | |
tree | 0044c449e933e93d4fecf1e80e6b361c6c3e8150 | |
parent | 9422cf12516143650f1014f34f996260c00d4c0a (diff) | |
download | factory-boy-7d792430e103984a91c102c33da79be2426bc632.tar factory-boy-7d792430e103984a91c102c33da79be2426bc632.tar.gz |
Add a 'after post_generation' hook to Factory.
Use it in DjangoModelFactory to save objects again if a post_generation hook ran.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r-- | docs/orms.rst | 5 | ||||
-rw-r--r-- | docs/reference.rst | 13 | ||||
-rw-r--r-- | factory/base.py | 24 | ||||
-rw-r--r-- | factory/declarations.py | 2 | ||||
-rw-r--r-- | tests/test_using.py | 21 |
5 files changed, 61 insertions, 4 deletions
diff --git a/docs/orms.rst b/docs/orms.rst index eae31d9..d6ff3c3 100644 --- a/docs/orms.rst +++ b/docs/orms.rst @@ -32,5 +32,6 @@ All factories for a Django :class:`~django.db.models.Model` should use the * :func:`~Factory.create()` uses :meth:`Model.objects.create() <django.db.models.query.QuerySet.create>` * :func:`~Factory._setup_next_sequence()` selects the next unused primary key value - * When using :class:`~factory.RelatedFactory` attributes, the base object will be - :meth:`saved <django.db.models.Model.save>` once all post-generation hooks have run. + * When using :class:`~factory.RelatedFactory` or :class:`~factory.PostGeneration` + attributes, the base object will be :meth:`saved <django.db.models.Model.save>` + once all post-generation hooks have run. diff --git a/docs/reference.rst b/docs/reference.rst index e2246aa..d100b40 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -189,6 +189,19 @@ The :class:`Factory` class .. OHAI_VIM* + .. classmethod:: _after_postgeneration(cls, obj, create, results=None) + + :arg object obj: The object just generated + :arg bool create: Whether the object was 'built' or 'created' + :arg dict results: Map of post-generation declaration name to call + result + + The :meth:`_after_postgeneration` is called once post-generation + declarations have been handled. + + Its arguments allow to handle specifically some post-generation return + values, for instance. + .. _strategies: diff --git a/factory/base.py b/factory/base.py index 28d7cdb..3ebc746 100644 --- a/factory/base.py +++ b/factory/base.py @@ -616,12 +616,27 @@ class Factory(BaseFactory): obj = cls._prepare(create, **attrs) # Handle post-generation attributes + results = {} for name, decl in sorted(postgen_declarations.items()): extracted, extracted_kwargs = postgen_attributes[name] - decl.call(obj, create, extracted, **extracted_kwargs) + results[name] = decl.call(obj, create, extracted, **extracted_kwargs) + + cls._after_postgeneration(obj, create, results) + return obj @classmethod + def _after_postgeneration(cls, obj, create, results=None): + """Hook called after post-generation declarations have been handled. + + Args: + obj (object): the generated object + create (bool): whether the strategy was 'build' or 'create' + results (dict or None): result of post-generation declarations + """ + pass + + @classmethod def build(cls, **kwargs): attrs = cls.attributes(create=False, extra=kwargs) return cls._generate(False, attrs) @@ -657,6 +672,13 @@ class DjangoModelFactory(Factory): """Create an instance of the model, and save it to the database.""" return target_class._default_manager.create(*args, **kwargs) + @classmethod + def _after_postgeneration(cls, obj, create, results=None): + """Save again the instance if creating and at least one hook ran.""" + if create and results: + # Some post-generation hooks ran, and may have modified us. + obj.save() + class MogoFactory(Factory): """Factory for mogo objects.""" diff --git a/factory/declarations.py b/factory/declarations.py index d3d7659..366c2c8 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -412,7 +412,7 @@ class PostGeneration(PostGenerationDeclaration): self.function = function def call(self, obj, create, extracted=None, **kwargs): - self.function(obj, create, extracted, **kwargs) + return self.function(obj, create, extracted, **kwargs) def post_generation(*args, **kwargs): diff --git a/tests/test_using.py b/tests/test_using.py index e5af8fb..9bc466e 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -1216,6 +1216,27 @@ class PostGenerationTestCase(unittest.TestCase): self.assertEqual(3, obj.one) self.assertFalse(hasattr(obj, 'incr_one')) + def test_post_generation_hook(self): + class TestObjectFactory(factory.Factory): + FACTORY_FOR = TestObject + + one = 1 + + @factory.post_generation + def incr_one(self, _create, _increment): + self.one += 1 + return 42 + + @classmethod + def _after_postgeneration(cls, obj, create, results): + obj.create = create + obj.results = results + + obj = TestObjectFactory.build() + self.assertEqual(2, obj.one) + self.assertFalse(obj.create) + self.assertEqual({'incr_one': 42}, obj.results) + @tools.disable_warnings def test_post_generation_calling(self): class TestObjectFactory(factory.Factory): |