summaryrefslogtreecommitdiff
path: root/factory/django.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/django.py')
-rw-r--r--factory/django.py130
1 files changed, 113 insertions, 17 deletions
diff --git a/factory/django.py b/factory/django.py
index fee8e52..2b6c463 100644
--- a/factory/django.py
+++ b/factory/django.py
@@ -25,6 +25,9 @@ from __future__ import absolute_import
from __future__ import unicode_literals
import os
+import types
+import logging
+import functools
"""factory_boy extensions for use with the Django framework."""
@@ -39,6 +42,9 @@ from . import base
from . import declarations
from .compat import BytesIO, is_string
+logger = logging.getLogger('factory.generate')
+
+
def require_django():
"""Simple helper to ensure Django is available."""
@@ -46,6 +52,25 @@ def require_django():
raise import_failure
+class DjangoOptions(base.FactoryOptions):
+ def _build_default_options(self):
+ return super(DjangoOptions, self)._build_default_options() + [
+ base.OptionDefault('django_get_or_create', (), inherit=True),
+ ]
+
+ def _get_counter_reference(self):
+ counter_reference = super(DjangoOptions, self)._get_counter_reference()
+ if (counter_reference == self.base_factory
+ and self.base_factory._meta.model is not None
+ and self.base_factory._meta.model._meta.abstract
+ and self.model is not None
+ and not self.model._meta.abstract):
+ # Target factory is for an abstract model, yet we're for another,
+ # concrete subclass => don't reuse the counter.
+ return self.factory
+ return counter_reference
+
+
class DjangoModelFactory(base.Factory):
"""Factory for Django models.
@@ -55,11 +80,17 @@ class DjangoModelFactory(base.Factory):
handle those for non-numerical primary keys.
"""
- ABSTRACT_FACTORY = True # Optional, but explicit.
- FACTORY_DJANGO_GET_OR_CREATE = ()
+ _options_class = DjangoOptions
+ class Meta:
+ abstract = True # Optional, but explicit.
+
+ _OLDSTYLE_ATTRIBUTES = base.Factory._OLDSTYLE_ATTRIBUTES.copy()
+ _OLDSTYLE_ATTRIBUTES.update({
+ 'FACTORY_DJANGO_GET_OR_CREATE': 'django_get_or_create',
+ })
@classmethod
- def _load_target_class(cls, definition):
+ def _load_model_class(cls, definition):
if is_string(definition) and '.' in definition:
app, model = definition.split('.', 1)
@@ -69,17 +100,20 @@ class DjangoModelFactory(base.Factory):
return definition
@classmethod
- def _get_manager(cls, target_class):
+ def _get_manager(cls, model_class):
+ if model_class is None:
+ raise base.AssociatedClassError("No model set on %s.%s.Meta"
+ % (cls.__module__, cls.__name__))
try:
- return target_class._default_manager # pylint: disable=W0212
+ return model_class._default_manager # pylint: disable=W0212
except AttributeError:
- return target_class.objects
+ return model_class.objects
@classmethod
def _setup_next_sequence(cls):
"""Compute the next available PK, based on the 'pk' database field."""
- model = cls._get_target_class() # pylint: disable=E1101
+ model = cls._get_model_class() # pylint: disable=E1101
manager = cls._get_manager(model)
try:
@@ -91,17 +125,17 @@ class DjangoModelFactory(base.Factory):
return 1
@classmethod
- def _get_or_create(cls, target_class, *args, **kwargs):
+ def _get_or_create(cls, model_class, *args, **kwargs):
"""Create an instance of the model through objects.get_or_create."""
- manager = cls._get_manager(target_class)
+ manager = cls._get_manager(model_class)
- assert 'defaults' not in cls.FACTORY_DJANGO_GET_OR_CREATE, (
+ assert 'defaults' not in cls._meta.django_get_or_create, (
"'defaults' is a reserved keyword for get_or_create "
- "(in %s.FACTORY_DJANGO_GET_OR_CREATE=%r)"
- % (cls, cls.FACTORY_DJANGO_GET_OR_CREATE))
+ "(in %s._meta.django_get_or_create=%r)"
+ % (cls, cls._meta.django_get_or_create))
key_fields = {}
- for field in cls.FACTORY_DJANGO_GET_OR_CREATE:
+ for field in cls._meta.django_get_or_create:
key_fields[field] = kwargs.pop(field)
key_fields['defaults'] = kwargs
@@ -109,12 +143,12 @@ class DjangoModelFactory(base.Factory):
return obj
@classmethod
- def _create(cls, target_class, *args, **kwargs):
+ def _create(cls, model_class, *args, **kwargs):
"""Create an instance of the model, and save it to the database."""
- manager = cls._get_manager(target_class)
+ manager = cls._get_manager(model_class)
- if cls.FACTORY_DJANGO_GET_OR_CREATE:
- return cls._get_or_create(target_class, *args, **kwargs)
+ if cls._meta.django_get_or_create:
+ return cls._get_or_create(model_class, *args, **kwargs)
return manager.create(*args, **kwargs)
@@ -214,3 +248,65 @@ class ImageField(FileField):
thumb.save(thumb_io, format=image_format)
return thumb_io.getvalue()
+
+class mute_signals(object):
+ """Temporarily disables and then restores any django signals.
+
+ Args:
+ *signals (django.dispatch.dispatcher.Signal): any django signals
+
+ Examples:
+ with mute_signals(pre_init):
+ user = UserFactory.build()
+ ...
+
+ @mute_signals(pre_save, post_save)
+ class UserFactory(factory.Factory):
+ ...
+
+ @mute_signals(post_save)
+ def generate_users():
+ UserFactory.create_batch(10)
+ """
+
+ def __init__(self, *signals):
+ self.signals = signals
+ self.paused = {}
+
+ def __enter__(self):
+ for signal in self.signals:
+ logger.debug('mute_signals: Disabling signal handlers %r',
+ signal.receivers)
+
+ self.paused[signal] = signal.receivers
+ signal.receivers = []
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ for signal, receivers in self.paused.items():
+ logger.debug('mute_signals: Restoring signal handlers %r',
+ receivers)
+
+ signal.receivers = receivers
+ self.paused = {}
+
+ def __call__(self, callable_obj):
+ if isinstance(callable_obj, base.FactoryMetaClass):
+ # Retrieve __func__, the *actual* callable object.
+ generate_method = callable_obj._generate.__func__
+
+ @classmethod
+ @functools.wraps(generate_method)
+ def wrapped_generate(*args, **kwargs):
+ with self:
+ return generate_method(*args, **kwargs)
+
+ callable_obj._generate = wrapped_generate
+ return callable_obj
+
+ else:
+ @functools.wraps(callable_obj)
+ def wrapper(*args, **kwargs):
+ with self:
+ return callable_obj(*args, **kwargs)
+ return wrapper
+