summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2013-03-04 23:18:02 +0100
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2013-03-04 23:18:02 +0100
commit7d792430e103984a91c102c33da79be2426bc632 (patch)
tree0044c449e933e93d4fecf1e80e6b361c6c3e8150
parent9422cf12516143650f1014f34f996260c00d4c0a (diff)
downloadfactory-boy-7d792430e103984a91c102c33da79be2426bc632.tar
factory-boy-7d792430e103984a91c102c33da79be2426bc632.tar.gz
Add a 'after post_generation' hook to Factory.
Use it in DjangoModelFactory to save objects again if a post_generation hook ran. Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r--docs/orms.rst5
-rw-r--r--docs/reference.rst13
-rw-r--r--factory/base.py24
-rw-r--r--factory/declarations.py2
-rw-r--r--tests/test_using.py21
5 files changed, 61 insertions, 4 deletions
diff --git a/docs/orms.rst b/docs/orms.rst
index eae31d9..d6ff3c3 100644
--- a/docs/orms.rst
+++ b/docs/orms.rst
@@ -32,5 +32,6 @@ All factories for a Django :class:`~django.db.models.Model` should use the
* :func:`~Factory.create()` uses :meth:`Model.objects.create() <django.db.models.query.QuerySet.create>`
* :func:`~Factory._setup_next_sequence()` selects the next unused primary key value
- * When using :class:`~factory.RelatedFactory` attributes, the base object will be
- :meth:`saved <django.db.models.Model.save>` once all post-generation hooks have run.
+ * When using :class:`~factory.RelatedFactory` or :class:`~factory.PostGeneration`
+ attributes, the base object will be :meth:`saved <django.db.models.Model.save>`
+ once all post-generation hooks have run.
diff --git a/docs/reference.rst b/docs/reference.rst
index e2246aa..d100b40 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -189,6 +189,19 @@ The :class:`Factory` class
.. OHAI_VIM*
+ .. classmethod:: _after_postgeneration(cls, obj, create, results=None)
+
+ :arg object obj: The object just generated
+ :arg bool create: Whether the object was 'built' or 'created'
+ :arg dict results: Map of post-generation declaration name to call
+ result
+
+ The :meth:`_after_postgeneration` is called once post-generation
+ declarations have been handled.
+
+ Its arguments allow to handle specifically some post-generation return
+ values, for instance.
+
.. _strategies:
diff --git a/factory/base.py b/factory/base.py
index 28d7cdb..3ebc746 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -616,12 +616,27 @@ class Factory(BaseFactory):
obj = cls._prepare(create, **attrs)
# Handle post-generation attributes
+ results = {}
for name, decl in sorted(postgen_declarations.items()):
extracted, extracted_kwargs = postgen_attributes[name]
- decl.call(obj, create, extracted, **extracted_kwargs)
+ results[name] = decl.call(obj, create, extracted, **extracted_kwargs)
+
+ cls._after_postgeneration(obj, create, results)
+
return obj
@classmethod
+ def _after_postgeneration(cls, obj, create, results=None):
+ """Hook called after post-generation declarations have been handled.
+
+ Args:
+ obj (object): the generated object
+ create (bool): whether the strategy was 'build' or 'create'
+ results (dict or None): result of post-generation declarations
+ """
+ pass
+
+ @classmethod
def build(cls, **kwargs):
attrs = cls.attributes(create=False, extra=kwargs)
return cls._generate(False, attrs)
@@ -657,6 +672,13 @@ class DjangoModelFactory(Factory):
"""Create an instance of the model, and save it to the database."""
return target_class._default_manager.create(*args, **kwargs)
+ @classmethod
+ def _after_postgeneration(cls, obj, create, results=None):
+ """Save again the instance if creating and at least one hook ran."""
+ if create and results:
+ # Some post-generation hooks ran, and may have modified us.
+ obj.save()
+
class MogoFactory(Factory):
"""Factory for mogo objects."""
diff --git a/factory/declarations.py b/factory/declarations.py
index d3d7659..366c2c8 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -412,7 +412,7 @@ class PostGeneration(PostGenerationDeclaration):
self.function = function
def call(self, obj, create, extracted=None, **kwargs):
- self.function(obj, create, extracted, **kwargs)
+ return self.function(obj, create, extracted, **kwargs)
def post_generation(*args, **kwargs):
diff --git a/tests/test_using.py b/tests/test_using.py
index e5af8fb..9bc466e 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -1216,6 +1216,27 @@ class PostGenerationTestCase(unittest.TestCase):
self.assertEqual(3, obj.one)
self.assertFalse(hasattr(obj, 'incr_one'))
+ def test_post_generation_hook(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+
+ one = 1
+
+ @factory.post_generation
+ def incr_one(self, _create, _increment):
+ self.one += 1
+ return 42
+
+ @classmethod
+ def _after_postgeneration(cls, obj, create, results):
+ obj.create = create
+ obj.results = results
+
+ obj = TestObjectFactory.build()
+ self.assertEqual(2, obj.one)
+ self.assertFalse(obj.create)
+ self.assertEqual({'incr_one': 42}, obj.results)
+
@tools.disable_warnings
def test_post_generation_calling(self):
class TestObjectFactory(factory.Factory):