diff options
Diffstat (limited to 'factory/django.py')
-rw-r--r-- | factory/django.py | 120 |
1 files changed, 71 insertions, 49 deletions
diff --git a/factory/django.py b/factory/django.py index 2b6c463..b3c508c 100644 --- a/factory/django.py +++ b/factory/django.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright (c) 2010 Mark Sandstrom -# Copyright (c) 2011-2013 Raphaël Barrois +# Copyright (c) 2011-2015 Raphaël Barrois # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,8 +32,10 @@ import functools """factory_boy extensions for use with the Django framework.""" try: + import django from django.core import files as django_files except ImportError as e: # pragma: no cover + django = None django_files = None import_failure = e @@ -45,6 +47,8 @@ from .compat import BytesIO, is_string logger = logging.getLogger('factory.generate') +DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS + def require_django(): """Simple helper to ensure Django is available.""" @@ -52,10 +56,41 @@ def require_django(): raise import_failure +_LAZY_LOADS = {} + +def get_model(app, model): + """Wrapper around django's get_model.""" + if 'get_model' not in _LAZY_LOADS: + _lazy_load_get_model() + + _get_model = _LAZY_LOADS['get_model'] + return _get_model(app, model) + + +def _lazy_load_get_model(): + """Lazy loading of get_model. + + get_model loads django.conf.settings, which may fail if + the settings haven't been configured yet. + """ + if django is None: + def get_model(app, model): + raise import_failure + + elif django.VERSION[:2] < (1, 7): + from django.db.models.loading import get_model + + else: + from django import apps as django_apps + get_model = django_apps.apps.get_model + _LAZY_LOADS['get_model'] = get_model + + class DjangoOptions(base.FactoryOptions): def _build_default_options(self): return super(DjangoOptions, self)._build_default_options() + [ base.OptionDefault('django_get_or_create', (), inherit=True), + base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True), ] def _get_counter_reference(self): @@ -84,18 +119,12 @@ class DjangoModelFactory(base.Factory): class Meta: abstract = True # Optional, but explicit. - _OLDSTYLE_ATTRIBUTES = base.Factory._OLDSTYLE_ATTRIBUTES.copy() - _OLDSTYLE_ATTRIBUTES.update({ - 'FACTORY_DJANGO_GET_OR_CREATE': 'django_get_or_create', - }) - @classmethod def _load_model_class(cls, definition): if is_string(definition) and '.' in definition: app, model = definition.split('.', 1) - from django.db.models import loading as django_loading - return django_loading.get_model(app, model) + return get_model(app, model) return definition @@ -104,25 +133,17 @@ class DjangoModelFactory(base.Factory): if model_class is None: raise base.AssociatedClassError("No model set on %s.%s.Meta" % (cls.__module__, cls.__name__)) + try: - return model_class._default_manager # pylint: disable=W0212 + manager = model_class.objects except AttributeError: - return model_class.objects - - @classmethod - def _setup_next_sequence(cls): - """Compute the next available PK, based on the 'pk' database field.""" - - model = cls._get_model_class() # pylint: disable=E1101 - manager = cls._get_manager(model) + # When inheriting from an abstract model with a custom + # manager, the class has no 'objects' field. + manager = model_class._default_manager - try: - return 1 + manager.values_list('pk', flat=True - ).order_by('-pk')[0] - except (IndexError, TypeError): - # IndexError: No instance exist yet - # TypeError: pk isn't an integer type - return 1 + if cls._meta.database != DEFAULT_DB_ALIAS: + manager = manager.using(cls._meta.database) + return manager @classmethod def _get_or_create(cls, model_class, *args, **kwargs): @@ -160,24 +181,22 @@ class DjangoModelFactory(base.Factory): obj.save() -class FileField(declarations.PostGenerationDeclaration): +class FileField(declarations.ParameteredAttribute): """Helper to fill in django.db.models.FileField from a Factory.""" DEFAULT_FILENAME = 'example.dat' + EXTEND_CONTAINERS = True def __init__(self, **defaults): require_django() - self.defaults = defaults - super(FileField, self).__init__() + super(FileField, self).__init__(**defaults) def _make_data(self, params): """Create data for the field.""" return params.get('data', b'') - def _make_content(self, extraction_context): + def _make_content(self, params): path = '' - params = dict(self.defaults) - params.update(extraction_context.extra) if params.get('from_path') and params.get('from_file'): raise ValueError( @@ -185,12 +204,7 @@ class FileField(declarations.PostGenerationDeclaration): "be non-empty when calling factory.django.FileField." ) - if extraction_context.did_extract: - # Should be a django.core.files.File - content = extraction_context.value - path = content.name - - elif params.get('from_path'): + if params.get('from_path'): path = params['from_path'] f = open(path, 'rb') content = django_files.File(f, name=path) @@ -212,19 +226,13 @@ class FileField(declarations.PostGenerationDeclaration): filename = params.get('filename', default_filename) return filename, content - def call(self, obj, create, extraction_context): + def generate(self, sequence, obj, create, params): """Fill in the field.""" - if extraction_context.did_extract and extraction_context.value is None: - # User passed an empty value, don't fill - return - filename, content = self._make_content(extraction_context) - field_file = getattr(obj, extraction_context.for_field) - try: - field_file.save(filename, content, save=create) - finally: - content.file.close() - return field_file + params.setdefault('__sequence', sequence) + params = base.DictFactory.simple_generate(create, **params) + filename, content = self._make_content(params) + return django_files.File(content.file, filename) class ImageField(FileField): @@ -278,6 +286,9 @@ class mute_signals(object): logger.debug('mute_signals: Disabling signal handlers %r', signal.receivers) + # Note that we're using implementation details of + # django.signals, since arguments to signal.connect() + # are lost in signal.receivers self.paused[signal] = signal.receivers signal.receivers = [] @@ -287,8 +298,17 @@ class mute_signals(object): receivers) signal.receivers = receivers + if django.VERSION[:2] >= (1, 6): + with signal.lock: + # Django uses some caching for its signals. + # Since we're bypassing signal.connect and signal.disconnect, + # we have to keep messing with django's internals. + signal.sender_receivers_cache.clear() self.paused = {} + def copy(self): + return mute_signals(*self.signals) + def __call__(self, callable_obj): if isinstance(callable_obj, base.FactoryMetaClass): # Retrieve __func__, the *actual* callable object. @@ -297,7 +317,8 @@ class mute_signals(object): @classmethod @functools.wraps(generate_method) def wrapped_generate(*args, **kwargs): - with self: + # A mute_signals() object is not reentrant; use a copy everytime. + with self.copy(): return generate_method(*args, **kwargs) callable_obj._generate = wrapped_generate @@ -306,7 +327,8 @@ class mute_signals(object): else: @functools.wraps(callable_obj) def wrapper(*args, **kwargs): - with self: + # A mute_signals() object is not reentrant; use a copy everytime. + with self.copy(): return callable_obj(*args, **kwargs) return wrapper |