diff options
Diffstat (limited to 'factory/django.py')
-rw-r--r-- | factory/django.py | 130 |
1 files changed, 113 insertions, 17 deletions
diff --git a/factory/django.py b/factory/django.py index fee8e52..2b6c463 100644 --- a/factory/django.py +++ b/factory/django.py @@ -25,6 +25,9 @@ from __future__ import absolute_import from __future__ import unicode_literals import os +import types +import logging +import functools """factory_boy extensions for use with the Django framework.""" @@ -39,6 +42,9 @@ from . import base from . import declarations from .compat import BytesIO, is_string +logger = logging.getLogger('factory.generate') + + def require_django(): """Simple helper to ensure Django is available.""" @@ -46,6 +52,25 @@ def require_django(): raise import_failure +class DjangoOptions(base.FactoryOptions): + def _build_default_options(self): + return super(DjangoOptions, self)._build_default_options() + [ + base.OptionDefault('django_get_or_create', (), inherit=True), + ] + + def _get_counter_reference(self): + counter_reference = super(DjangoOptions, self)._get_counter_reference() + if (counter_reference == self.base_factory + and self.base_factory._meta.model is not None + and self.base_factory._meta.model._meta.abstract + and self.model is not None + and not self.model._meta.abstract): + # Target factory is for an abstract model, yet we're for another, + # concrete subclass => don't reuse the counter. + return self.factory + return counter_reference + + class DjangoModelFactory(base.Factory): """Factory for Django models. @@ -55,11 +80,17 @@ class DjangoModelFactory(base.Factory): handle those for non-numerical primary keys. """ - ABSTRACT_FACTORY = True # Optional, but explicit. - FACTORY_DJANGO_GET_OR_CREATE = () + _options_class = DjangoOptions + class Meta: + abstract = True # Optional, but explicit. + + _OLDSTYLE_ATTRIBUTES = base.Factory._OLDSTYLE_ATTRIBUTES.copy() + _OLDSTYLE_ATTRIBUTES.update({ + 'FACTORY_DJANGO_GET_OR_CREATE': 'django_get_or_create', + }) @classmethod - def _load_target_class(cls, definition): + def _load_model_class(cls, definition): if is_string(definition) and '.' in definition: app, model = definition.split('.', 1) @@ -69,17 +100,20 @@ class DjangoModelFactory(base.Factory): return definition @classmethod - def _get_manager(cls, target_class): + def _get_manager(cls, model_class): + if model_class is None: + raise base.AssociatedClassError("No model set on %s.%s.Meta" + % (cls.__module__, cls.__name__)) try: - return target_class._default_manager # pylint: disable=W0212 + return model_class._default_manager # pylint: disable=W0212 except AttributeError: - return target_class.objects + return model_class.objects @classmethod def _setup_next_sequence(cls): """Compute the next available PK, based on the 'pk' database field.""" - model = cls._get_target_class() # pylint: disable=E1101 + model = cls._get_model_class() # pylint: disable=E1101 manager = cls._get_manager(model) try: @@ -91,17 +125,17 @@ class DjangoModelFactory(base.Factory): return 1 @classmethod - def _get_or_create(cls, target_class, *args, **kwargs): + def _get_or_create(cls, model_class, *args, **kwargs): """Create an instance of the model through objects.get_or_create.""" - manager = cls._get_manager(target_class) + manager = cls._get_manager(model_class) - assert 'defaults' not in cls.FACTORY_DJANGO_GET_OR_CREATE, ( + assert 'defaults' not in cls._meta.django_get_or_create, ( "'defaults' is a reserved keyword for get_or_create " - "(in %s.FACTORY_DJANGO_GET_OR_CREATE=%r)" - % (cls, cls.FACTORY_DJANGO_GET_OR_CREATE)) + "(in %s._meta.django_get_or_create=%r)" + % (cls, cls._meta.django_get_or_create)) key_fields = {} - for field in cls.FACTORY_DJANGO_GET_OR_CREATE: + for field in cls._meta.django_get_or_create: key_fields[field] = kwargs.pop(field) key_fields['defaults'] = kwargs @@ -109,12 +143,12 @@ class DjangoModelFactory(base.Factory): return obj @classmethod - def _create(cls, target_class, *args, **kwargs): + def _create(cls, model_class, *args, **kwargs): """Create an instance of the model, and save it to the database.""" - manager = cls._get_manager(target_class) + manager = cls._get_manager(model_class) - if cls.FACTORY_DJANGO_GET_OR_CREATE: - return cls._get_or_create(target_class, *args, **kwargs) + if cls._meta.django_get_or_create: + return cls._get_or_create(model_class, *args, **kwargs) return manager.create(*args, **kwargs) @@ -214,3 +248,65 @@ class ImageField(FileField): thumb.save(thumb_io, format=image_format) return thumb_io.getvalue() + +class mute_signals(object): + """Temporarily disables and then restores any django signals. + + Args: + *signals (django.dispatch.dispatcher.Signal): any django signals + + Examples: + with mute_signals(pre_init): + user = UserFactory.build() + ... + + @mute_signals(pre_save, post_save) + class UserFactory(factory.Factory): + ... + + @mute_signals(post_save) + def generate_users(): + UserFactory.create_batch(10) + """ + + def __init__(self, *signals): + self.signals = signals + self.paused = {} + + def __enter__(self): + for signal in self.signals: + logger.debug('mute_signals: Disabling signal handlers %r', + signal.receivers) + + self.paused[signal] = signal.receivers + signal.receivers = [] + + def __exit__(self, exc_type, exc_value, traceback): + for signal, receivers in self.paused.items(): + logger.debug('mute_signals: Restoring signal handlers %r', + receivers) + + signal.receivers = receivers + self.paused = {} + + def __call__(self, callable_obj): + if isinstance(callable_obj, base.FactoryMetaClass): + # Retrieve __func__, the *actual* callable object. + generate_method = callable_obj._generate.__func__ + + @classmethod + @functools.wraps(generate_method) + def wrapped_generate(*args, **kwargs): + with self: + return generate_method(*args, **kwargs) + + callable_obj._generate = wrapped_generate + return callable_obj + + else: + @functools.wraps(callable_obj) + def wrapper(*args, **kwargs): + with self: + return callable_obj(*args, **kwargs) + return wrapper + |