diff options
author | Ilya Pirogov <ilja.pirogov@gmail.com> | 2014-01-13 17:53:03 +0400 |
---|---|---|
committer | Raphaƫl Barrois <raphael.barrois@polytechnique.org> | 2014-01-21 23:11:06 +0100 |
commit | 9323fbeea374394833987cb710ac9becb7726a44 (patch) | |
tree | 881093f65a2606fae8d8aee061504583e9e5b86c /tests/test_django.py | |
parent | 402138768870871ef38b5c01ccaec957aa771d26 (diff) | |
download | factory-boy-9323fbeea374394833987cb710ac9becb7726a44.tar factory-boy-9323fbeea374394833987cb710ac9becb7726a44.tar.gz |
Added "prevent_signals" decorator/context manager
Diffstat (limited to 'tests/test_django.py')
-rw-r--r-- | tests/test_django.py | 80 |
1 files changed, 79 insertions, 1 deletions
diff --git a/tests/test_django.py b/tests/test_django.py index e4bbc2b..18ffa6b 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -24,6 +24,7 @@ import os import factory import factory.django +from factory.helpers import prevent_signals try: @@ -42,7 +43,7 @@ except ImportError: # pragma: no cover Image = None -from .compat import is_python2, unittest +from .compat import is_python2, unittest, mock from . import testdata from . import tools @@ -55,6 +56,7 @@ if django is not None: from django.db import models as django_models from django.test import simple as django_test_simple from django.test import utils as django_test_utils + from django.db.models import signals from .djapp import models else: # pragma: no cover django_test = unittest @@ -70,6 +72,7 @@ else: # pragma: no cover models.NonIntegerPk = Fake models.WithFile = Fake models.WithImage = Fake + models.WithSignals = Fake test_state = {} @@ -142,6 +145,10 @@ class WithImageFactory(factory.django.DjangoModelFactory): animage = factory.django.ImageField() +class WithSignalsFactory(factory.django.DjangoModelFactory): + FACTORY_FOR = models.WithSignals + + @unittest.skipIf(django is None, "Django not installed.") class DjangoPkSequenceTestCase(django_test.TestCase): def setUp(self): @@ -511,5 +518,76 @@ class DjangoImageFieldTestCase(unittest.TestCase): self.assertFalse(o.animage) +@unittest.skipIf(django is None, "Django not installed.") +class PreventSignalsTestCase(unittest.TestCase): + def setUp(self): + self.handlers = mock.MagicMock() + + signals.pre_init.connect(self.handlers.pre_init) + signals.pre_save.connect(self.handlers.pre_save) + signals.post_save.connect(self.handlers.post_save) + + def tearDown(self): + signals.pre_init.disconnect(self.handlers.pre_init) + signals.pre_save.disconnect(self.handlers.pre_save) + signals.post_save.disconnect(self.handlers.post_save) + + def test_signals(self): + WithSignalsFactory() + + self.assertEqual(self.handlers.pre_save.call_count, 1) + self.assertEqual(self.handlers.post_save.call_count, 1) + + def test_context_manager(self): + with prevent_signals(signals.pre_save, signals.post_save): + WithSignalsFactory() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + + def test_class_decorator(self): + @prevent_signals(signals.pre_save, signals.post_save) + class WithSignalsDecoratedFactory(factory.django.DjangoModelFactory): + FACTORY_FOR = models.WithSignals + + WithSignalsDecoratedFactory() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + + def test_function_decorator(self): + @prevent_signals(signals.pre_save, signals.post_save) + def foo(): + WithSignalsFactory() + + foo() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + + def test_classmethod_decorator(self): + class Foo(object): + @classmethod + @prevent_signals(signals.pre_save, signals.post_save) + def generate(cls): + WithSignalsFactory() + + Foo.generate() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + if __name__ == '__main__': # pragma: no cover unittest.main() |