summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/orms.rst5
-rw-r--r--docs/reference.rst13
-rw-r--r--factory/base.py24
-rw-r--r--factory/declarations.py2
-rw-r--r--tests/test_using.py21
5 files changed, 61 insertions, 4 deletions
diff --git a/docs/orms.rst b/docs/orms.rst
index eae31d9..d6ff3c3 100644
--- a/docs/orms.rst
+++ b/docs/orms.rst
@@ -32,5 +32,6 @@ All factories for a Django :class:`~django.db.models.Model` should use the
* :func:`~Factory.create()` uses :meth:`Model.objects.create() <django.db.models.query.QuerySet.create>`
* :func:`~Factory._setup_next_sequence()` selects the next unused primary key value
- * When using :class:`~factory.RelatedFactory` attributes, the base object will be
- :meth:`saved <django.db.models.Model.save>` once all post-generation hooks have run.
+ * When using :class:`~factory.RelatedFactory` or :class:`~factory.PostGeneration`
+ attributes, the base object will be :meth:`saved <django.db.models.Model.save>`
+ once all post-generation hooks have run.
diff --git a/docs/reference.rst b/docs/reference.rst
index e2246aa..d100b40 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -189,6 +189,19 @@ The :class:`Factory` class
.. OHAI_VIM*
+ .. classmethod:: _after_postgeneration(cls, obj, create, results=None)
+
+ :arg object obj: The object just generated
+ :arg bool create: Whether the object was 'built' or 'created'
+ :arg dict results: Map of post-generation declaration name to call
+ result
+
+ The :meth:`_after_postgeneration` is called once post-generation
+ declarations have been handled.
+
+ Its arguments allow to handle specifically some post-generation return
+ values, for instance.
+
.. _strategies:
diff --git a/factory/base.py b/factory/base.py
index 28d7cdb..3ebc746 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -616,12 +616,27 @@ class Factory(BaseFactory):
obj = cls._prepare(create, **attrs)
# Handle post-generation attributes
+ results = {}
for name, decl in sorted(postgen_declarations.items()):
extracted, extracted_kwargs = postgen_attributes[name]
- decl.call(obj, create, extracted, **extracted_kwargs)
+ results[name] = decl.call(obj, create, extracted, **extracted_kwargs)
+
+ cls._after_postgeneration(obj, create, results)
+
return obj
@classmethod
+ def _after_postgeneration(cls, obj, create, results=None):
+ """Hook called after post-generation declarations have been handled.
+
+ Args:
+ obj (object): the generated object
+ create (bool): whether the strategy was 'build' or 'create'
+ results (dict or None): result of post-generation declarations
+ """
+ pass
+
+ @classmethod
def build(cls, **kwargs):
attrs = cls.attributes(create=False, extra=kwargs)
return cls._generate(False, attrs)
@@ -657,6 +672,13 @@ class DjangoModelFactory(Factory):
"""Create an instance of the model, and save it to the database."""
return target_class._default_manager.create(*args, **kwargs)
+ @classmethod
+ def _after_postgeneration(cls, obj, create, results=None):
+ """Save again the instance if creating and at least one hook ran."""
+ if create and results:
+ # Some post-generation hooks ran, and may have modified us.
+ obj.save()
+
class MogoFactory(Factory):
"""Factory for mogo objects."""
diff --git a/factory/declarations.py b/factory/declarations.py
index d3d7659..366c2c8 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -412,7 +412,7 @@ class PostGeneration(PostGenerationDeclaration):
self.function = function
def call(self, obj, create, extracted=None, **kwargs):
- self.function(obj, create, extracted, **kwargs)
+ return self.function(obj, create, extracted, **kwargs)
def post_generation(*args, **kwargs):
diff --git a/tests/test_using.py b/tests/test_using.py
index e5af8fb..9bc466e 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -1216,6 +1216,27 @@ class PostGenerationTestCase(unittest.TestCase):
self.assertEqual(3, obj.one)
self.assertFalse(hasattr(obj, 'incr_one'))
+ def test_post_generation_hook(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+
+ one = 1
+
+ @factory.post_generation
+ def incr_one(self, _create, _increment):
+ self.one += 1
+ return 42
+
+ @classmethod
+ def _after_postgeneration(cls, obj, create, results):
+ obj.create = create
+ obj.results = results
+
+ obj = TestObjectFactory.build()
+ self.assertEqual(2, obj.one)
+ self.assertFalse(obj.create)
+ self.assertEqual({'incr_one': 42}, obj.results)
+
@tools.disable_warnings
def test_post_generation_calling(self):
class TestObjectFactory(factory.Factory):