diff options
-rw-r--r-- | factory/__init__.py | 2 | ||||
-rw-r--r-- | factory/django.py | 36 | ||||
-rw-r--r-- | factory/helpers.py | 4 | ||||
-rw-r--r-- | tests/test_django.py | 32 |
4 files changed, 43 insertions, 31 deletions
diff --git a/factory/__init__.py b/factory/__init__.py index 251306a..b4e63be 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -79,7 +79,5 @@ from .helpers import ( lazy_attribute_sequence, container_attribute, post_generation, - - prevent_signals, ) diff --git a/factory/django.py b/factory/django.py index b502923..6f39c34 100644 --- a/factory/django.py +++ b/factory/django.py @@ -221,22 +221,22 @@ class ImageField(FileField): return thumb_io.getvalue() -class PreventSignals(object): +class mute_signals(object): """Temporarily disables and then restores any django signals. Args: *signals (django.dispatch.dispatcher.Signal): any django signals Examples: - with prevent_signals(pre_init): + with mute_signals(pre_init): user = UserFactory.build() ... - @prevent_signals(pre_save, post_save) + @mute_signals(pre_save, post_save) class UserFactory(factory.Factory): ... - @prevent_signals(post_save) + @mute_signals(post_save) def generate_users(): UserFactory.create_batch(10) """ @@ -247,7 +247,7 @@ class PreventSignals(object): def __enter__(self): for signal in self.signals: - logger.debug('PreventSignals: Disabling signal handlers %r', + logger.debug('mute_signals: Disabling signal handlers %r', signal.receivers) self.paused[signal] = signal.receivers @@ -255,22 +255,28 @@ class PreventSignals(object): def __exit__(self, exc_type, exc_value, traceback): for signal, receivers in self.paused.items(): - logger.debug('PreventSignals: Restoring signal handlers %r', + logger.debug('mute_signals: Restoring signal handlers %r', receivers) signal.receivers = receivers self.paused = {} - def __call__(self, func): - if isinstance(func, types.FunctionType): - @functools.wraps(func) + def __call__(self, callable_obj): + if isinstance(callable_obj, base.FactoryMetaClass): + generate_method = getattr(callable_obj, '_generate') + + @functools.wraps(generate_method) + def wrapped_generate(*args, **kwargs): + with self: + return generate_method(*args, **kwargs) + + callable_obj._generate = wrapped_generate + return callable_obj + + else: + @functools.wraps(callable_obj) def wrapper(*args, **kwargs): with self: - return func(*args, **kwargs) + return callable_obj(*args, **kwargs) return wrapper - generate_method = getattr(func, '_generate', None) - if generate_method: - func._generate = classmethod(self(generate_method.__func__)) - - return func
\ No newline at end of file diff --git a/factory/helpers.py b/factory/helpers.py index 719d76d..4a2a254 100644 --- a/factory/helpers.py +++ b/factory/helpers.py @@ -140,7 +140,3 @@ def container_attribute(func): def post_generation(fun): return declarations.PostGeneration(fun) - - -def prevent_signals(*signals): - return django.PreventSignals(*signals)
\ No newline at end of file diff --git a/tests/test_django.py b/tests/test_django.py index 18ffa6b..50a67a3 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -24,7 +24,6 @@ import os import factory import factory.django -from factory.helpers import prevent_signals try: @@ -532,24 +531,24 @@ class PreventSignalsTestCase(unittest.TestCase): signals.pre_save.disconnect(self.handlers.pre_save) signals.post_save.disconnect(self.handlers.post_save) - def test_signals(self): + def assertSignalsReactivated(self): WithSignalsFactory() self.assertEqual(self.handlers.pre_save.call_count, 1) self.assertEqual(self.handlers.post_save.call_count, 1) def test_context_manager(self): - with prevent_signals(signals.pre_save, signals.post_save): + with factory.django.mute_signals(signals.pre_save, signals.post_save): WithSignalsFactory() self.assertEqual(self.handlers.pre_init.call_count, 1) self.assertFalse(self.handlers.pre_save.called) self.assertFalse(self.handlers.post_save.called) - self.test_signals() + self.assertSignalsReactivated() def test_class_decorator(self): - @prevent_signals(signals.pre_save, signals.post_save) + @factory.django.mute_signals(signals.pre_save, signals.post_save) class WithSignalsDecoratedFactory(factory.django.DjangoModelFactory): FACTORY_FOR = models.WithSignals @@ -559,10 +558,23 @@ class PreventSignalsTestCase(unittest.TestCase): self.assertFalse(self.handlers.pre_save.called) self.assertFalse(self.handlers.post_save.called) - self.test_signals() + self.assertSignalsReactivated() + + def test_class_decorator_build(self): + @factory.django.mute_signals(signals.pre_save, signals.post_save) + class WithSignalsDecoratedFactory(factory.django.DjangoModelFactory): + FACTORY_FOR = models.WithSignals + + WithSignalsDecoratedFactory.build() + + self.assertEqual(self.handlers.pre_init.call_count, 1) + self.assertFalse(self.handlers.pre_save.called) + self.assertFalse(self.handlers.post_save.called) + + self.assertSignalsReactivated() def test_function_decorator(self): - @prevent_signals(signals.pre_save, signals.post_save) + @factory.django.mute_signals(signals.pre_save, signals.post_save) def foo(): WithSignalsFactory() @@ -572,12 +584,12 @@ class PreventSignalsTestCase(unittest.TestCase): self.assertFalse(self.handlers.pre_save.called) self.assertFalse(self.handlers.post_save.called) - self.test_signals() + self.assertSignalsReactivated() def test_classmethod_decorator(self): class Foo(object): @classmethod - @prevent_signals(signals.pre_save, signals.post_save) + @factory.django.mute_signals(signals.pre_save, signals.post_save) def generate(cls): WithSignalsFactory() @@ -587,7 +599,7 @@ class PreventSignalsTestCase(unittest.TestCase): self.assertFalse(self.handlers.pre_save.called) self.assertFalse(self.handlers.post_save.called) - self.test_signals() + self.assertSignalsReactivated() if __name__ == '__main__': # pragma: no cover unittest.main() |