summaryrefslogtreecommitdiff
path: root/factory/django.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/django.py')
-rw-r--r--factory/django.py36
1 files changed, 21 insertions, 15 deletions
diff --git a/factory/django.py b/factory/django.py
index b502923..6f39c34 100644
--- a/factory/django.py
+++ b/factory/django.py
@@ -221,22 +221,22 @@ class ImageField(FileField):
return thumb_io.getvalue()
-class PreventSignals(object):
+class mute_signals(object):
"""Temporarily disables and then restores any django signals.
Args:
*signals (django.dispatch.dispatcher.Signal): any django signals
Examples:
- with prevent_signals(pre_init):
+ with mute_signals(pre_init):
user = UserFactory.build()
...
- @prevent_signals(pre_save, post_save)
+ @mute_signals(pre_save, post_save)
class UserFactory(factory.Factory):
...
- @prevent_signals(post_save)
+ @mute_signals(post_save)
def generate_users():
UserFactory.create_batch(10)
"""
@@ -247,7 +247,7 @@ class PreventSignals(object):
def __enter__(self):
for signal in self.signals:
- logger.debug('PreventSignals: Disabling signal handlers %r',
+ logger.debug('mute_signals: Disabling signal handlers %r',
signal.receivers)
self.paused[signal] = signal.receivers
@@ -255,22 +255,28 @@ class PreventSignals(object):
def __exit__(self, exc_type, exc_value, traceback):
for signal, receivers in self.paused.items():
- logger.debug('PreventSignals: Restoring signal handlers %r',
+ logger.debug('mute_signals: Restoring signal handlers %r',
receivers)
signal.receivers = receivers
self.paused = {}
- def __call__(self, func):
- if isinstance(func, types.FunctionType):
- @functools.wraps(func)
+ def __call__(self, callable_obj):
+ if isinstance(callable_obj, base.FactoryMetaClass):
+ generate_method = getattr(callable_obj, '_generate')
+
+ @functools.wraps(generate_method)
+ def wrapped_generate(*args, **kwargs):
+ with self:
+ return generate_method(*args, **kwargs)
+
+ callable_obj._generate = wrapped_generate
+ return callable_obj
+
+ else:
+ @functools.wraps(callable_obj)
def wrapper(*args, **kwargs):
with self:
- return func(*args, **kwargs)
+ return callable_obj(*args, **kwargs)
return wrapper
- generate_method = getattr(func, '_generate', None)
- if generate_method:
- func._generate = classmethod(self(generate_method.__func__))
-
- return func \ No newline at end of file