aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2014-01-21 23:21:03 +0100
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2014-01-21 23:33:22 +0100
commitdccb37f551d19d9dba68d35a888941cde64f861e (patch)
tree0984f0abf19c003a58d0025c84a848cc21ed5a3d
parent9323fbeea374394833987cb710ac9becb7726a44 (diff)
downloadfactory-boy-dccb37f551d19d9dba68d35a888941cde64f861e.tar
factory-boy-dccb37f551d19d9dba68d35a888941cde64f861e.tar.gz
Improve mute_signals (Closes #122).
-rw-r--r--factory/__init__.py2
-rw-r--r--factory/django.py36
-rw-r--r--factory/helpers.py4
-rw-r--r--tests/test_django.py32
4 files changed, 43 insertions, 31 deletions
diff --git a/factory/__init__.py b/factory/__init__.py
index 251306a..b4e63be 100644
--- a/factory/__init__.py
+++ b/factory/__init__.py
@@ -79,7 +79,5 @@ from .helpers import (
lazy_attribute_sequence,
container_attribute,
post_generation,
-
- prevent_signals,
)
diff --git a/factory/django.py b/factory/django.py
index b502923..6f39c34 100644
--- a/factory/django.py
+++ b/factory/django.py
@@ -221,22 +221,22 @@ class ImageField(FileField):
return thumb_io.getvalue()
-class PreventSignals(object):
+class mute_signals(object):
"""Temporarily disables and then restores any django signals.
Args:
*signals (django.dispatch.dispatcher.Signal): any django signals
Examples:
- with prevent_signals(pre_init):
+ with mute_signals(pre_init):
user = UserFactory.build()
...
- @prevent_signals(pre_save, post_save)
+ @mute_signals(pre_save, post_save)
class UserFactory(factory.Factory):
...
- @prevent_signals(post_save)
+ @mute_signals(post_save)
def generate_users():
UserFactory.create_batch(10)
"""
@@ -247,7 +247,7 @@ class PreventSignals(object):
def __enter__(self):
for signal in self.signals:
- logger.debug('PreventSignals: Disabling signal handlers %r',
+ logger.debug('mute_signals: Disabling signal handlers %r',
signal.receivers)
self.paused[signal] = signal.receivers
@@ -255,22 +255,28 @@ class PreventSignals(object):
def __exit__(self, exc_type, exc_value, traceback):
for signal, receivers in self.paused.items():
- logger.debug('PreventSignals: Restoring signal handlers %r',
+ logger.debug('mute_signals: Restoring signal handlers %r',
receivers)
signal.receivers = receivers
self.paused = {}
- def __call__(self, func):
- if isinstance(func, types.FunctionType):
- @functools.wraps(func)
+ def __call__(self, callable_obj):
+ if isinstance(callable_obj, base.FactoryMetaClass):
+ generate_method = getattr(callable_obj, '_generate')
+
+ @functools.wraps(generate_method)
+ def wrapped_generate(*args, **kwargs):
+ with self:
+ return generate_method(*args, **kwargs)
+
+ callable_obj._generate = wrapped_generate
+ return callable_obj
+
+ else:
+ @functools.wraps(callable_obj)
def wrapper(*args, **kwargs):
with self:
- return func(*args, **kwargs)
+ return callable_obj(*args, **kwargs)
return wrapper
- generate_method = getattr(func, '_generate', None)
- if generate_method:
- func._generate = classmethod(self(generate_method.__func__))
-
- return func \ No newline at end of file
diff --git a/factory/helpers.py b/factory/helpers.py
index 719d76d..4a2a254 100644
--- a/factory/helpers.py
+++ b/factory/helpers.py
@@ -140,7 +140,3 @@ def container_attribute(func):
def post_generation(fun):
return declarations.PostGeneration(fun)
-
-
-def prevent_signals(*signals):
- return django.PreventSignals(*signals) \ No newline at end of file
diff --git a/tests/test_django.py b/tests/test_django.py
index 18ffa6b..50a67a3 100644
--- a/tests/test_django.py
+++ b/tests/test_django.py
@@ -24,7 +24,6 @@ import os
import factory
import factory.django
-from factory.helpers import prevent_signals
try:
@@ -532,24 +531,24 @@ class PreventSignalsTestCase(unittest.TestCase):
signals.pre_save.disconnect(self.handlers.pre_save)
signals.post_save.disconnect(self.handlers.post_save)
- def test_signals(self):
+ def assertSignalsReactivated(self):
WithSignalsFactory()
self.assertEqual(self.handlers.pre_save.call_count, 1)
self.assertEqual(self.handlers.post_save.call_count, 1)
def test_context_manager(self):
- with prevent_signals(signals.pre_save, signals.post_save):
+ with factory.django.mute_signals(signals.pre_save, signals.post_save):
WithSignalsFactory()
self.assertEqual(self.handlers.pre_init.call_count, 1)
self.assertFalse(self.handlers.pre_save.called)
self.assertFalse(self.handlers.post_save.called)
- self.test_signals()
+ self.assertSignalsReactivated()
def test_class_decorator(self):
- @prevent_signals(signals.pre_save, signals.post_save)
+ @factory.django.mute_signals(signals.pre_save, signals.post_save)
class WithSignalsDecoratedFactory(factory.django.DjangoModelFactory):
FACTORY_FOR = models.WithSignals
@@ -559,10 +558,23 @@ class PreventSignalsTestCase(unittest.TestCase):
self.assertFalse(self.handlers.pre_save.called)
self.assertFalse(self.handlers.post_save.called)
- self.test_signals()
+ self.assertSignalsReactivated()
+
+ def test_class_decorator_build(self):
+ @factory.django.mute_signals(signals.pre_save, signals.post_save)
+ class WithSignalsDecoratedFactory(factory.django.DjangoModelFactory):
+ FACTORY_FOR = models.WithSignals
+
+ WithSignalsDecoratedFactory.build()
+
+ self.assertEqual(self.handlers.pre_init.call_count, 1)
+ self.assertFalse(self.handlers.pre_save.called)
+ self.assertFalse(self.handlers.post_save.called)
+
+ self.assertSignalsReactivated()
def test_function_decorator(self):
- @prevent_signals(signals.pre_save, signals.post_save)
+ @factory.django.mute_signals(signals.pre_save, signals.post_save)
def foo():
WithSignalsFactory()
@@ -572,12 +584,12 @@ class PreventSignalsTestCase(unittest.TestCase):
self.assertFalse(self.handlers.pre_save.called)
self.assertFalse(self.handlers.post_save.called)
- self.test_signals()
+ self.assertSignalsReactivated()
def test_classmethod_decorator(self):
class Foo(object):
@classmethod
- @prevent_signals(signals.pre_save, signals.post_save)
+ @factory.django.mute_signals(signals.pre_save, signals.post_save)
def generate(cls):
WithSignalsFactory()
@@ -587,7 +599,7 @@ class PreventSignalsTestCase(unittest.TestCase):
self.assertFalse(self.handlers.pre_save.called)
self.assertFalse(self.handlers.post_save.called)
- self.test_signals()
+ self.assertSignalsReactivated()
if __name__ == '__main__': # pragma: no cover
unittest.main()