diff options
-rw-r--r-- | factory/__init__.py | 2 | ||||
-rw-r--r-- | factory/django.py | 60 | ||||
-rw-r--r-- | factory/helpers.py | 5 | ||||
-rw-r--r-- | tests/djapp/models.py | 4 | ||||
-rw-r--r-- | tests/test_django.py | 80 |
5 files changed, 150 insertions, 1 deletions
diff --git a/factory/__init__.py b/factory/__init__.py index b4e63be..251306a 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -79,5 +79,7 @@ from .helpers import ( lazy_attribute_sequence, container_attribute, post_generation, + + prevent_signals, ) diff --git a/factory/django.py b/factory/django.py index fee8e52..b502923 100644 --- a/factory/django.py +++ b/factory/django.py @@ -25,6 +25,9 @@ from __future__ import absolute_import from __future__ import unicode_literals import os +import types +import logging +import functools """factory_boy extensions for use with the Django framework.""" @@ -39,6 +42,9 @@ from . import base from . import declarations from .compat import BytesIO, is_string +logger = logging.getLogger('factory.generate') + + def require_django(): """Simple helper to ensure Django is available.""" @@ -214,3 +220,57 @@ class ImageField(FileField): thumb.save(thumb_io, format=image_format) return thumb_io.getvalue() + +class PreventSignals(object): + """Temporarily disables and then restores any django signals. + + Args: + *signals (django.dispatch.dispatcher.Signal): any django signals + + Examples: + with prevent_signals(pre_init): + user = UserFactory.build() + ... + + @prevent_signals(pre_save, post_save) + class UserFactory(factory.Factory): + ... + + @prevent_signals(post_save) + def generate_users(): + UserFactory.create_batch(10) + """ + + def __init__(self, *signals): + self.signals = signals + self.paused = {} + + def __enter__(self): + for signal in self.signals: + logger.debug('PreventSignals: Disabling signal handlers %r', + signal.receivers) + + self.paused[signal] = signal.receivers + signal.receivers = [] + + def __exit__(self, exc_type, exc_value, traceback): + for signal, receivers in self.paused.items(): + logger.debug('PreventSignals: Restoring signal handlers %r', + receivers) + + signal.receivers = receivers + self.paused = {} + + def __call__(self, func): + if isinstance(func, types.FunctionType): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + return wrapper + + generate_method = getattr(func, '_generate', None) + if generate_method: + func._generate = classmethod(self(generate_method.__func__)) + + return func
\ No newline at end of file diff --git a/factory/helpers.py b/factory/helpers.py index 37b41bf..719d76d 100644 --- a/factory/helpers.py +++ b/factory/helpers.py @@ -28,6 +28,7 @@ import logging from . import base from . import declarations +from . import django @contextlib.contextmanager @@ -139,3 +140,7 @@ def container_attribute(func): def post_generation(fun): return declarations.PostGeneration(fun) + + +def prevent_signals(*signals): + return django.PreventSignals(*signals)
\ No newline at end of file diff --git a/tests/djapp/models.py b/tests/djapp/models.py index e98279d..a65b50a 100644 --- a/tests/djapp/models.py +++ b/tests/djapp/models.py @@ -74,3 +74,7 @@ if Image is not None: # PIL is available else: class WithImage(models.Model): pass + + +class WithSignals(models.Model): + foo = models.CharField(max_length=20)
\ No newline at end of file diff --git a/tests/test_django.py b/tests/test_django.py index e4bbc2b..18ffa6b 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -24,6 +24,7 @@ import os import factory import factory.django +from factory.helpers import prevent_signals try: @@ -42,7 +43,7 @@ except ImportError: # pragma: no cover Image = None -from .compat import is_python2, unittest +from .compat import is_python2, unittest, mock from . import testdata from . import tools @@ -55,6 +56,7 @@ if django is not None: from django.db import models as django_models from django.test import simple as django_test_simple from django.test import utils as django_test_utils + from django.db.models import signals from .djapp import models else: # pragma: no cover django_test = unittest @@ -70,6 +72,7 @@ else: # pragma: no cover models.NonIntegerPk = Fake models.WithFile = Fake models.WithImage = Fake + models.WithSignals = Fake test_state = {} @@ -142,6 +145,10 @@ class WithImageFactory(factory.django.DjangoModelFactory): animage = factory.django.ImageField() +class WithSignalsFactory(factory.django.DjangoModelFactory): + FACTORY_FOR = models.WithSignals + + @unittest.skipIf(django is None, "Django not installed.") class DjangoPkSequenceTestCase(django_test.TestCase): def setUp(self): @@ -511,5 +518,76 @@ class DjangoImageFieldTestCase(unittest.TestCase): self.assertFalse(o.animage) +@unittest.skipIf(django is None, "Django not installed.") +class PreventSignalsTestCase(unittest.TestCase): + def setUp(self): + self.handlers = mock.MagicMock() + + signals.pre_init.connect(self.handlers.pre_init) + signals.pre_save.connect(self.handlers.pre_save) + signals.post_save.connect(self.handlers.post_save) + + def tearDown(self): + signals.pre_init.disconnect(self.handlers.pre_init) + signals.pre_save.disconnect(self.handlers.pre_save) + signals.post_save.disconnect(self.handlers.post_save) + + def test_signals(self): + WithSignalsFactory() + + self.assertEqual(self.handlers.pre_save.call_count, 1) + self.assertEqual(self.handlers.post_save.call_count, 1) + + def test_context_manager(self): + with prevent_signals(signals.pre_save, signals.post_save): + WithSignalsFactory() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + + def test_class_decorator(self): + @prevent_signals(signals.pre_save, signals.post_save) + class WithSignalsDecoratedFactory(factory.django.DjangoModelFactory): + FACTORY_FOR = models.WithSignals + + WithSignalsDecoratedFactory() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + + def test_function_decorator(self): + @prevent_signals(signals.pre_save, signals.post_save) + def foo(): + WithSignalsFactory() + + foo() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + + def test_classmethod_decorator(self): + class Foo(object): + @classmethod + @prevent_signals(signals.pre_save, signals.post_save) + def generate(cls): + WithSignalsFactory() + + Foo.generate() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.test_signals() + if __name__ == '__main__': # pragma: no cover unittest.main() |