From dccb37f551d19d9dba68d35a888941cde64f861e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Barrois?= Date: Tue, 21 Jan 2014 23:21:03 +0100 Subject: Improve mute_signals (Closes #122). --- factory/django.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) (limited to 'factory/django.py') diff --git a/factory/django.py b/factory/django.py index b502923..6f39c34 100644 --- a/factory/django.py +++ b/factory/django.py @@ -221,22 +221,22 @@ class ImageField(FileField): return thumb_io.getvalue() -class PreventSignals(object): +class mute_signals(object): """Temporarily disables and then restores any django signals. Args: *signals (django.dispatch.dispatcher.Signal): any django signals Examples: - with prevent_signals(pre_init): + with mute_signals(pre_init): user = UserFactory.build() ... - @prevent_signals(pre_save, post_save) + @mute_signals(pre_save, post_save) class UserFactory(factory.Factory): ... - @prevent_signals(post_save) + @mute_signals(post_save) def generate_users(): UserFactory.create_batch(10) """ @@ -247,7 +247,7 @@ class PreventSignals(object): def __enter__(self): for signal in self.signals: - logger.debug('PreventSignals: Disabling signal handlers %r', + logger.debug('mute_signals: Disabling signal handlers %r', signal.receivers) self.paused[signal] = signal.receivers @@ -255,22 +255,28 @@ class PreventSignals(object): def __exit__(self, exc_type, exc_value, traceback): for signal, receivers in self.paused.items(): - logger.debug('PreventSignals: Restoring signal handlers %r', + logger.debug('mute_signals: Restoring signal handlers %r', receivers) signal.receivers = receivers self.paused = {} - def __call__(self, func): - if isinstance(func, types.FunctionType): - @functools.wraps(func) + def __call__(self, callable_obj): + if isinstance(callable_obj, base.FactoryMetaClass): + generate_method = getattr(callable_obj, '_generate') + + @functools.wraps(generate_method) + def wrapped_generate(*args, **kwargs): + with self: + return generate_method(*args, **kwargs) + + callable_obj._generate = wrapped_generate + return callable_obj + + else: + @functools.wraps(callable_obj) def wrapper(*args, **kwargs): with self: - return func(*args, **kwargs) + return callable_obj(*args, **kwargs) return wrapper - generate_method = getattr(func, '_generate', None) - if generate_method: - func._generate = classmethod(self(generate_method.__func__)) - - return func \ No newline at end of file -- cgit v1.2.3