summaryrefslogtreecommitdiff
path: root/factory/django.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/django.py')
-rw-r--r--factory/django.py166
1 files changed, 108 insertions, 58 deletions
diff --git a/factory/django.py b/factory/django.py
index a3dfdfc..b3c508c 100644
--- a/factory/django.py
+++ b/factory/django.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2010 Mark Sandstrom
-# Copyright (c) 2011-2013 Raphaël Barrois
+# Copyright (c) 2011-2015 Raphaël Barrois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -32,8 +32,10 @@ import functools
"""factory_boy extensions for use with the Django framework."""
try:
+ import django
from django.core import files as django_files
except ImportError as e: # pragma: no cover
+ django = None
django_files = None
import_failure = e
@@ -45,6 +47,8 @@ from .compat import BytesIO, is_string
logger = logging.getLogger('factory.generate')
+DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS
+
def require_django():
"""Simple helper to ensure Django is available."""
@@ -52,6 +56,56 @@ def require_django():
raise import_failure
+_LAZY_LOADS = {}
+
+def get_model(app, model):
+ """Wrapper around django's get_model."""
+ if 'get_model' not in _LAZY_LOADS:
+ _lazy_load_get_model()
+
+ _get_model = _LAZY_LOADS['get_model']
+ return _get_model(app, model)
+
+
+def _lazy_load_get_model():
+ """Lazy loading of get_model.
+
+ get_model loads django.conf.settings, which may fail if
+ the settings haven't been configured yet.
+ """
+ if django is None:
+ def get_model(app, model):
+ raise import_failure
+
+ elif django.VERSION[:2] < (1, 7):
+ from django.db.models.loading import get_model
+
+ else:
+ from django import apps as django_apps
+ get_model = django_apps.apps.get_model
+ _LAZY_LOADS['get_model'] = get_model
+
+
+class DjangoOptions(base.FactoryOptions):
+ def _build_default_options(self):
+ return super(DjangoOptions, self)._build_default_options() + [
+ base.OptionDefault('django_get_or_create', (), inherit=True),
+ base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True),
+ ]
+
+ def _get_counter_reference(self):
+ counter_reference = super(DjangoOptions, self)._get_counter_reference()
+ if (counter_reference == self.base_factory
+ and self.base_factory._meta.model is not None
+ and self.base_factory._meta.model._meta.abstract
+ and self.model is not None
+ and not self.model._meta.abstract):
+ # Target factory is for an abstract model, yet we're for another,
+ # concrete subclass => don't reuse the counter.
+ return self.factory
+ return counter_reference
+
+
class DjangoModelFactory(base.Factory):
"""Factory for Django models.
@@ -61,53 +115,48 @@ class DjangoModelFactory(base.Factory):
handle those for non-numerical primary keys.
"""
- ABSTRACT_FACTORY = True # Optional, but explicit.
- FACTORY_DJANGO_GET_OR_CREATE = ()
+ _options_class = DjangoOptions
+ class Meta:
+ abstract = True # Optional, but explicit.
@classmethod
- def _load_target_class(cls, definition):
+ def _load_model_class(cls, definition):
if is_string(definition) and '.' in definition:
app, model = definition.split('.', 1)
- from django.db.models import loading as django_loading
- return django_loading.get_model(app, model)
+ return get_model(app, model)
return definition
@classmethod
- def _get_manager(cls, target_class):
+ def _get_manager(cls, model_class):
+ if model_class is None:
+ raise base.AssociatedClassError("No model set on %s.%s.Meta"
+ % (cls.__module__, cls.__name__))
+
try:
- return target_class._default_manager # pylint: disable=W0212
+ manager = model_class.objects
except AttributeError:
- return target_class.objects
-
- @classmethod
- def _setup_next_sequence(cls):
- """Compute the next available PK, based on the 'pk' database field."""
+ # When inheriting from an abstract model with a custom
+ # manager, the class has no 'objects' field.
+ manager = model_class._default_manager
- model = cls._get_target_class() # pylint: disable=E1101
- manager = cls._get_manager(model)
-
- try:
- return 1 + manager.values_list('pk', flat=True
- ).order_by('-pk')[0]
- except (IndexError, TypeError):
- # IndexError: No instance exist yet
- # TypeError: pk isn't an integer type
- return 1
+ if cls._meta.database != DEFAULT_DB_ALIAS:
+ manager = manager.using(cls._meta.database)
+ return manager
@classmethod
- def _get_or_create(cls, target_class, *args, **kwargs):
+ def _get_or_create(cls, model_class, *args, **kwargs):
"""Create an instance of the model through objects.get_or_create."""
- manager = cls._get_manager(target_class)
+ manager = cls._get_manager(model_class)
- assert 'defaults' not in cls.FACTORY_DJANGO_GET_OR_CREATE, (
+ assert 'defaults' not in cls._meta.django_get_or_create, (
"'defaults' is a reserved keyword for get_or_create "
- "(in %s.FACTORY_DJANGO_GET_OR_CREATE=%r)"
- % (cls, cls.FACTORY_DJANGO_GET_OR_CREATE))
+ "(in %s._meta.django_get_or_create=%r)"
+ % (cls, cls._meta.django_get_or_create))
key_fields = {}
- for field in cls.FACTORY_DJANGO_GET_OR_CREATE:
+ for field in cls._meta.django_get_or_create:
key_fields[field] = kwargs.pop(field)
key_fields['defaults'] = kwargs
@@ -115,12 +164,12 @@ class DjangoModelFactory(base.Factory):
return obj
@classmethod
- def _create(cls, target_class, *args, **kwargs):
+ def _create(cls, model_class, *args, **kwargs):
"""Create an instance of the model, and save it to the database."""
- manager = cls._get_manager(target_class)
+ manager = cls._get_manager(model_class)
- if cls.FACTORY_DJANGO_GET_OR_CREATE:
- return cls._get_or_create(target_class, *args, **kwargs)
+ if cls._meta.django_get_or_create:
+ return cls._get_or_create(model_class, *args, **kwargs)
return manager.create(*args, **kwargs)
@@ -132,24 +181,22 @@ class DjangoModelFactory(base.Factory):
obj.save()
-class FileField(declarations.PostGenerationDeclaration):
+class FileField(declarations.ParameteredAttribute):
"""Helper to fill in django.db.models.FileField from a Factory."""
DEFAULT_FILENAME = 'example.dat'
+ EXTEND_CONTAINERS = True
def __init__(self, **defaults):
require_django()
- self.defaults = defaults
- super(FileField, self).__init__()
+ super(FileField, self).__init__(**defaults)
def _make_data(self, params):
"""Create data for the field."""
return params.get('data', b'')
- def _make_content(self, extraction_context):
+ def _make_content(self, params):
path = ''
- params = dict(self.defaults)
- params.update(extraction_context.extra)
if params.get('from_path') and params.get('from_file'):
raise ValueError(
@@ -157,12 +204,7 @@ class FileField(declarations.PostGenerationDeclaration):
"be non-empty when calling factory.django.FileField."
)
- if extraction_context.did_extract:
- # Should be a django.core.files.File
- content = extraction_context.value
- path = content.name
-
- elif params.get('from_path'):
+ if params.get('from_path'):
path = params['from_path']
f = open(path, 'rb')
content = django_files.File(f, name=path)
@@ -184,19 +226,13 @@ class FileField(declarations.PostGenerationDeclaration):
filename = params.get('filename', default_filename)
return filename, content
- def call(self, obj, create, extraction_context):
+ def generate(self, sequence, obj, create, params):
"""Fill in the field."""
- if extraction_context.did_extract and extraction_context.value is None:
- # User passed an empty value, don't fill
- return
- filename, content = self._make_content(extraction_context)
- field_file = getattr(obj, extraction_context.for_field)
- try:
- field_file.save(filename, content, save=create)
- finally:
- content.file.close()
- return field_file
+ params.setdefault('__sequence', sequence)
+ params = base.DictFactory.simple_generate(create, **params)
+ filename, content = self._make_content(params)
+ return django_files.File(content.file, filename)
class ImageField(FileField):
@@ -250,6 +286,9 @@ class mute_signals(object):
logger.debug('mute_signals: Disabling signal handlers %r',
signal.receivers)
+ # Note that we're using implementation details of
+ # django.signals, since arguments to signal.connect()
+ # are lost in signal.receivers
self.paused[signal] = signal.receivers
signal.receivers = []
@@ -259,8 +298,17 @@ class mute_signals(object):
receivers)
signal.receivers = receivers
+ if django.VERSION[:2] >= (1, 6):
+ with signal.lock:
+ # Django uses some caching for its signals.
+ # Since we're bypassing signal.connect and signal.disconnect,
+ # we have to keep messing with django's internals.
+ signal.sender_receivers_cache.clear()
self.paused = {}
+ def copy(self):
+ return mute_signals(*self.signals)
+
def __call__(self, callable_obj):
if isinstance(callable_obj, base.FactoryMetaClass):
# Retrieve __func__, the *actual* callable object.
@@ -269,7 +317,8 @@ class mute_signals(object):
@classmethod
@functools.wraps(generate_method)
def wrapped_generate(*args, **kwargs):
- with self:
+ # A mute_signals() object is not reentrant; use a copy everytime.
+ with self.copy():
return generate_method(*args, **kwargs)
callable_obj._generate = wrapped_generate
@@ -278,7 +327,8 @@ class mute_signals(object):
else:
@functools.wraps(callable_obj)
def wrapper(*args, **kwargs):
- with self:
+ # A mute_signals() object is not reentrant; use a copy everytime.
+ with self.copy():
return callable_obj(*args, **kwargs)
return wrapper