diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/djapp/models.py | 4 | ||||
-rw-r--r-- | tests/test_django.py | 80 |
2 files changed, 83 insertions, 1 deletions
diff --git a/tests/djapp/models.py b/tests/djapp/models.py index e98279d..a65b50a 100644 --- a/tests/djapp/models.py +++ b/tests/djapp/models.py @@ -74,3 +74,7 @@ if Image is not None: # PIL is available else: class WithImage(models.Model): pass + + +class WithSignals(models.Model): + foo = models.CharField(max_length=20)
\ No newline at end of file diff --git a/tests/test_django.py b/tests/test_django.py index e4bbc2b..18ffa6b 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -24,6 +24,7 @@ import os import factory import factory.django +from factory.helpers import prevent_signals try: @@ -42,7 +43,7 @@ except ImportError: # pragma: no cover Image = None -from .compat import is_python2, unittest +from .compat import is_python2, unittest, mock from . import testdata from . import tools @@ -55,6 +56,7 @@ if django is not None: from django.db import models as django_models from django.test import simple as django_test_simple from django.test import utils as django_test_utils + from django.db.models import signals from .djapp import models else: # pragma: no cover django_test = unittest @@ -70,6 +72,7 @@ else: # pragma: no cover models.NonIntegerPk = Fake models.WithFile = Fake models.WithImage = Fake + models.WithSignals = Fake test_state = {} @@ -142,6 +145,10 @@ class WithImageFactory(factory.django.DjangoModelFactory): animage = factory.django.ImageField() +class WithSignalsFactory(factory.django.DjangoModelFactory): + FACTORY_FOR = models.WithSignals + + @unittest.skipIf(django is None, "Django not installed.") class DjangoPkSequenceTestCase(django_test.TestCase): def setUp(self): @@ -511,5 +518,76 @@ class DjangoImageFieldTestCase(unittest.TestCase): self.assertFalse(o.animage) +@unittest.skipIf(django is None, "Django not installed.") +class PreventSignalsTestCase(unittest.TestCase): + def setUp(self): + self.handlers = mock.MagicMock() + + signals.pre_init.connect(self.handlers.pre_init) + signals.pre_save.connect(self.handlers.pre_save) + signals.post_save.connect(self.handlers.post_save) + + def tearDown(self): + signals.pre_init.disconnect(self.handlers.pre_init) + signals.pre_save.disconnect(self.handlers.pre_save) + signals.post_save.disconnect(self.handlers.post_save) + + def test_signals(self): + WithSignalsFactory() + + self.assertEqual(self.handlers.pre_save.call_count, 1) + self.assertEqual(self.handlers.post_save.call_count, 1) + + def test_context_manager(self): + with prevent_signals(signals.pre_save, signals.post_save): + WithSignalsFactory() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + + def test_class_decorator(self): + @prevent_signals(signals.pre_save, signals.post_save) + class WithSignalsDecoratedFactory(factory.django.DjangoModelFactory): + FACTORY_FOR = models.WithSignals + + WithSignalsDecoratedFactory() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + + def test_function_decorator(self): + @prevent_signals(signals.pre_save, signals.post_save) + def foo(): + WithSignalsFactory() + + foo() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + + def test_classmethod_decorator(self): + class Foo(object): + @classmethod + @prevent_signals(signals.pre_save, signals.post_save) + def generate(cls): + WithSignalsFactory() + + Foo.generate() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + if __name__ == '__main__': # pragma: no cover unittest.main() |