summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/__init__.py2
-rw-r--r--factory/django.py60
-rw-r--r--factory/helpers.py5
-rw-r--r--tests/djapp/models.py4
-rw-r--r--tests/test_django.py80
5 files changed, 150 insertions, 1 deletions
diff --git a/factory/__init__.py b/factory/__init__.py
index b4e63be..251306a 100644
--- a/factory/__init__.py
+++ b/factory/__init__.py
@@ -79,5 +79,7 @@ from .helpers import (
lazy_attribute_sequence,
container_attribute,
post_generation,
+
+ prevent_signals,
)
diff --git a/factory/django.py b/factory/django.py
index fee8e52..b502923 100644
--- a/factory/django.py
+++ b/factory/django.py
@@ -25,6 +25,9 @@ from __future__ import absolute_import
from __future__ import unicode_literals
import os
+import types
+import logging
+import functools
"""factory_boy extensions for use with the Django framework."""
@@ -39,6 +42,9 @@ from . import base
from . import declarations
from .compat import BytesIO, is_string
+logger = logging.getLogger('factory.generate')
+
+
def require_django():
"""Simple helper to ensure Django is available."""
@@ -214,3 +220,57 @@ class ImageField(FileField):
thumb.save(thumb_io, format=image_format)
return thumb_io.getvalue()
+
+class PreventSignals(object):
+ """Temporarily disables and then restores any django signals.
+
+ Args:
+ *signals (django.dispatch.dispatcher.Signal): any django signals
+
+ Examples:
+ with prevent_signals(pre_init):
+ user = UserFactory.build()
+ ...
+
+ @prevent_signals(pre_save, post_save)
+ class UserFactory(factory.Factory):
+ ...
+
+ @prevent_signals(post_save)
+ def generate_users():
+ UserFactory.create_batch(10)
+ """
+
+ def __init__(self, *signals):
+ self.signals = signals
+ self.paused = {}
+
+ def __enter__(self):
+ for signal in self.signals:
+ logger.debug('PreventSignals: Disabling signal handlers %r',
+ signal.receivers)
+
+ self.paused[signal] = signal.receivers
+ signal.receivers = []
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ for signal, receivers in self.paused.items():
+ logger.debug('PreventSignals: Restoring signal handlers %r',
+ receivers)
+
+ signal.receivers = receivers
+ self.paused = {}
+
+ def __call__(self, func):
+ if isinstance(func, types.FunctionType):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ with self:
+ return func(*args, **kwargs)
+ return wrapper
+
+ generate_method = getattr(func, '_generate', None)
+ if generate_method:
+ func._generate = classmethod(self(generate_method.__func__))
+
+ return func \ No newline at end of file
diff --git a/factory/helpers.py b/factory/helpers.py
index 37b41bf..719d76d 100644
--- a/factory/helpers.py
+++ b/factory/helpers.py
@@ -28,6 +28,7 @@ import logging
from . import base
from . import declarations
+from . import django
@contextlib.contextmanager
@@ -139,3 +140,7 @@ def container_attribute(func):
def post_generation(fun):
return declarations.PostGeneration(fun)
+
+
+def prevent_signals(*signals):
+ return django.PreventSignals(*signals) \ No newline at end of file
diff --git a/tests/djapp/models.py b/tests/djapp/models.py
index e98279d..a65b50a 100644
--- a/tests/djapp/models.py
+++ b/tests/djapp/models.py
@@ -74,3 +74,7 @@ if Image is not None: # PIL is available
else:
class WithImage(models.Model):
pass
+
+
+class WithSignals(models.Model):
+ foo = models.CharField(max_length=20) \ No newline at end of file
diff --git a/tests/test_django.py b/tests/test_django.py
index e4bbc2b..18ffa6b 100644
--- a/tests/test_django.py
+++ b/tests/test_django.py
@@ -24,6 +24,7 @@ import os
import factory
import factory.django
+from factory.helpers import prevent_signals
try:
@@ -42,7 +43,7 @@ except ImportError: # pragma: no cover
Image = None
-from .compat import is_python2, unittest
+from .compat import is_python2, unittest, mock
from . import testdata
from . import tools
@@ -55,6 +56,7 @@ if django is not None:
from django.db import models as django_models
from django.test import simple as django_test_simple
from django.test import utils as django_test_utils
+ from django.db.models import signals
from .djapp import models
else: # pragma: no cover
django_test = unittest
@@ -70,6 +72,7 @@ else: # pragma: no cover
models.NonIntegerPk = Fake
models.WithFile = Fake
models.WithImage = Fake
+ models.WithSignals = Fake
test_state = {}
@@ -142,6 +145,10 @@ class WithImageFactory(factory.django.DjangoModelFactory):
animage = factory.django.ImageField()
+class WithSignalsFactory(factory.django.DjangoModelFactory):
+ FACTORY_FOR = models.WithSignals
+
+
@unittest.skipIf(django is None, "Django not installed.")
class DjangoPkSequenceTestCase(django_test.TestCase):
def setUp(self):
@@ -511,5 +518,76 @@ class DjangoImageFieldTestCase(unittest.TestCase):
self.assertFalse(o.animage)
+@unittest.skipIf(django is None, "Django not installed.")
+class PreventSignalsTestCase(unittest.TestCase):
+ def setUp(self):
+ self.handlers = mock.MagicMock()
+
+ signals.pre_init.connect(self.handlers.pre_init)
+ signals.pre_save.connect(self.handlers.pre_save)
+ signals.post_save.connect(self.handlers.post_save)
+
+ def tearDown(self):
+ signals.pre_init.disconnect(self.handlers.pre_init)
+ signals.pre_save.disconnect(self.handlers.pre_save)
+ signals.post_save.disconnect(self.handlers.post_save)
+
+ def test_signals(self):
+ WithSignalsFactory()
+
+ self.assertEqual(self.handlers.pre_save.call_count, 1)
+ self.assertEqual(self.handlers.post_save.call_count, 1)
+
+ def test_context_manager(self):
+ with prevent_signals(signals.pre_save, signals.post_save):
+ WithSignalsFactory()
+
+ self.assertEqual(self.handlers.pre_init.call_count, 1)
+ self.assertFalse(self.handlers.pre_save.called)
+ self.assertFalse(self.handlers.post_save.called)
+
+ self.test_signals()
+
+ def test_class_decorator(self):
+ @prevent_signals(signals.pre_save, signals.post_save)
+ class WithSignalsDecoratedFactory(factory.django.DjangoModelFactory):
+ FACTORY_FOR = models.WithSignals
+
+ WithSignalsDecoratedFactory()
+
+ self.assertEqual(self.handlers.pre_init.call_count, 1)
+ self.assertFalse(self.handlers.pre_save.called)
+ self.assertFalse(self.handlers.post_save.called)
+
+ self.test_signals()
+
+ def test_function_decorator(self):
+ @prevent_signals(signals.pre_save, signals.post_save)
+ def foo():
+ WithSignalsFactory()
+
+ foo()
+
+ self.assertEqual(self.handlers.pre_init.call_count, 1)
+ self.assertFalse(self.handlers.pre_save.called)
+ self.assertFalse(self.handlers.post_save.called)
+
+ self.test_signals()
+
+ def test_classmethod_decorator(self):
+ class Foo(object):
+ @classmethod
+ @prevent_signals(signals.pre_save, signals.post_save)
+ def generate(cls):
+ WithSignalsFactory()
+
+ Foo.generate()
+
+ self.assertEqual(self.handlers.pre_init.call_count, 1)
+ self.assertFalse(self.handlers.pre_save.called)
+ self.assertFalse(self.handlers.post_save.called)
+
+ self.test_signals()
+
if __name__ == '__main__': # pragma: no cover
unittest.main()