summaryrefslogtreecommitdiff
path: root/tests/test_django.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_django.py')
-rw-r--r--tests/test_django.py80
1 files changed, 79 insertions, 1 deletions
diff --git a/tests/test_django.py b/tests/test_django.py
index e4bbc2b..18ffa6b 100644
--- a/tests/test_django.py
+++ b/tests/test_django.py
@@ -24,6 +24,7 @@ import os
import factory
import factory.django
+from factory.helpers import prevent_signals
try:
@@ -42,7 +43,7 @@ except ImportError: # pragma: no cover
Image = None
-from .compat import is_python2, unittest
+from .compat import is_python2, unittest, mock
from . import testdata
from . import tools
@@ -55,6 +56,7 @@ if django is not None:
from django.db import models as django_models
from django.test import simple as django_test_simple
from django.test import utils as django_test_utils
+ from django.db.models import signals
from .djapp import models
else: # pragma: no cover
django_test = unittest
@@ -70,6 +72,7 @@ else: # pragma: no cover
models.NonIntegerPk = Fake
models.WithFile = Fake
models.WithImage = Fake
+ models.WithSignals = Fake
test_state = {}
@@ -142,6 +145,10 @@ class WithImageFactory(factory.django.DjangoModelFactory):
animage = factory.django.ImageField()
+class WithSignalsFactory(factory.django.DjangoModelFactory):
+ FACTORY_FOR = models.WithSignals
+
+
@unittest.skipIf(django is None, "Django not installed.")
class DjangoPkSequenceTestCase(django_test.TestCase):
def setUp(self):
@@ -511,5 +518,76 @@ class DjangoImageFieldTestCase(unittest.TestCase):
self.assertFalse(o.animage)
+@unittest.skipIf(django is None, "Django not installed.")
+class PreventSignalsTestCase(unittest.TestCase):
+ def setUp(self):
+ self.handlers = mock.MagicMock()
+
+ signals.pre_init.connect(self.handlers.pre_init)
+ signals.pre_save.connect(self.handlers.pre_save)
+ signals.post_save.connect(self.handlers.post_save)
+
+ def tearDown(self):
+ signals.pre_init.disconnect(self.handlers.pre_init)
+ signals.pre_save.disconnect(self.handlers.pre_save)
+ signals.post_save.disconnect(self.handlers.post_save)
+
+ def test_signals(self):
+ WithSignalsFactory()
+
+ self.assertEqual(self.handlers.pre_save.call_count, 1)
+ self.assertEqual(self.handlers.post_save.call_count, 1)
+
+ def test_context_manager(self):
+ with prevent_signals(signals.pre_save, signals.post_save):
+ WithSignalsFactory()
+
+ self.assertEqual(self.handlers.pre_init.call_count, 1)
+ self.assertFalse(self.handlers.pre_save.called)
+ self.assertFalse(self.handlers.post_save.called)
+
+ self.test_signals()
+
+ def test_class_decorator(self):
+ @prevent_signals(signals.pre_save, signals.post_save)
+ class WithSignalsDecoratedFactory(factory.django.DjangoModelFactory):
+ FACTORY_FOR = models.WithSignals
+
+ WithSignalsDecoratedFactory()
+
+ self.assertEqual(self.handlers.pre_init.call_count, 1)
+ self.assertFalse(self.handlers.pre_save.called)
+ self.assertFalse(self.handlers.post_save.called)
+
+ self.test_signals()
+
+ def test_function_decorator(self):
+ @prevent_signals(signals.pre_save, signals.post_save)
+ def foo():
+ WithSignalsFactory()
+
+ foo()
+
+ self.assertEqual(self.handlers.pre_init.call_count, 1)
+ self.assertFalse(self.handlers.pre_save.called)
+ self.assertFalse(self.handlers.post_save.called)
+
+ self.test_signals()
+
+ def test_classmethod_decorator(self):
+ class Foo(object):
+ @classmethod
+ @prevent_signals(signals.pre_save, signals.post_save)
+ def generate(cls):
+ WithSignalsFactory()
+
+ Foo.generate()
+
+ self.assertEqual(self.handlers.pre_init.call_count, 1)
+ self.assertFalse(self.handlers.pre_save.called)
+ self.assertFalse(self.handlers.post_save.called)
+
+ self.test_signals()
+
if __name__ == '__main__': # pragma: no cover
unittest.main()