diff options
Diffstat (limited to 'factory/django.py')
-rw-r--r-- | factory/django.py | 166 |
1 files changed, 108 insertions, 58 deletions
diff --git a/factory/django.py b/factory/django.py index a3dfdfc..b3c508c 100644 --- a/factory/django.py +++ b/factory/django.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright (c) 2010 Mark Sandstrom -# Copyright (c) 2011-2013 Raphaël Barrois +# Copyright (c) 2011-2015 Raphaël Barrois # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -32,8 +32,10 @@ import functools """factory_boy extensions for use with the Django framework.""" try: + import django from django.core import files as django_files except ImportError as e: # pragma: no cover + django = None django_files = None import_failure = e @@ -45,6 +47,8 @@ from .compat import BytesIO, is_string logger = logging.getLogger('factory.generate') +DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS + def require_django(): """Simple helper to ensure Django is available.""" @@ -52,6 +56,56 @@ def require_django(): raise import_failure +_LAZY_LOADS = {} + +def get_model(app, model): + """Wrapper around django's get_model.""" + if 'get_model' not in _LAZY_LOADS: + _lazy_load_get_model() + + _get_model = _LAZY_LOADS['get_model'] + return _get_model(app, model) + + +def _lazy_load_get_model(): + """Lazy loading of get_model. + + get_model loads django.conf.settings, which may fail if + the settings haven't been configured yet. + """ + if django is None: + def get_model(app, model): + raise import_failure + + elif django.VERSION[:2] < (1, 7): + from django.db.models.loading import get_model + + else: + from django import apps as django_apps + get_model = django_apps.apps.get_model + _LAZY_LOADS['get_model'] = get_model + + +class DjangoOptions(base.FactoryOptions): + def _build_default_options(self): + return super(DjangoOptions, self)._build_default_options() + [ + base.OptionDefault('django_get_or_create', (), inherit=True), + base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True), + ] + + def _get_counter_reference(self): + counter_reference = super(DjangoOptions, self)._get_counter_reference() + if (counter_reference == self.base_factory + and self.base_factory._meta.model is not None + and self.base_factory._meta.model._meta.abstract + and self.model is not None + and not self.model._meta.abstract): + # Target factory is for an abstract model, yet we're for another, + # concrete subclass => don't reuse the counter. + return self.factory + return counter_reference + + class DjangoModelFactory(base.Factory): """Factory for Django models. @@ -61,53 +115,48 @@ class DjangoModelFactory(base.Factory): handle those for non-numerical primary keys. """ - ABSTRACT_FACTORY = True # Optional, but explicit. - FACTORY_DJANGO_GET_OR_CREATE = () + _options_class = DjangoOptions + class Meta: + abstract = True # Optional, but explicit. @classmethod - def _load_target_class(cls, definition): + def _load_model_class(cls, definition): if is_string(definition) and '.' in definition: app, model = definition.split('.', 1) - from django.db.models import loading as django_loading - return django_loading.get_model(app, model) + return get_model(app, model) return definition @classmethod - def _get_manager(cls, target_class): + def _get_manager(cls, model_class): + if model_class is None: + raise base.AssociatedClassError("No model set on %s.%s.Meta" + % (cls.__module__, cls.__name__)) + try: - return target_class._default_manager # pylint: disable=W0212 + manager = model_class.objects except AttributeError: - return target_class.objects - - @classmethod - def _setup_next_sequence(cls): - """Compute the next available PK, based on the 'pk' database field.""" + # When inheriting from an abstract model with a custom + # manager, the class has no 'objects' field. + manager = model_class._default_manager - model = cls._get_target_class() # pylint: disable=E1101 - manager = cls._get_manager(model) - - try: - return 1 + manager.values_list('pk', flat=True - ).order_by('-pk')[0] - except (IndexError, TypeError): - # IndexError: No instance exist yet - # TypeError: pk isn't an integer type - return 1 + if cls._meta.database != DEFAULT_DB_ALIAS: + manager = manager.using(cls._meta.database) + return manager @classmethod - def _get_or_create(cls, target_class, *args, **kwargs): + def _get_or_create(cls, model_class, *args, **kwargs): """Create an instance of the model through objects.get_or_create.""" - manager = cls._get_manager(target_class) + manager = cls._get_manager(model_class) - assert 'defaults' not in cls.FACTORY_DJANGO_GET_OR_CREATE, ( + assert 'defaults' not in cls._meta.django_get_or_create, ( "'defaults' is a reserved keyword for get_or_create " - "(in %s.FACTORY_DJANGO_GET_OR_CREATE=%r)" - % (cls, cls.FACTORY_DJANGO_GET_OR_CREATE)) + "(in %s._meta.django_get_or_create=%r)" + % (cls, cls._meta.django_get_or_create)) key_fields = {} - for field in cls.FACTORY_DJANGO_GET_OR_CREATE: + for field in cls._meta.django_get_or_create: key_fields[field] = kwargs.pop(field) key_fields['defaults'] = kwargs @@ -115,12 +164,12 @@ class DjangoModelFactory(base.Factory): return obj @classmethod - def _create(cls, target_class, *args, **kwargs): + def _create(cls, model_class, *args, **kwargs): """Create an instance of the model, and save it to the database.""" - manager = cls._get_manager(target_class) + manager = cls._get_manager(model_class) - if cls.FACTORY_DJANGO_GET_OR_CREATE: - return cls._get_or_create(target_class, *args, **kwargs) + if cls._meta.django_get_or_create: + return cls._get_or_create(model_class, *args, **kwargs) return manager.create(*args, **kwargs) @@ -132,24 +181,22 @@ class DjangoModelFactory(base.Factory): obj.save() -class FileField(declarations.PostGenerationDeclaration): +class FileField(declarations.ParameteredAttribute): """Helper to fill in django.db.models.FileField from a Factory.""" DEFAULT_FILENAME = 'example.dat' + EXTEND_CONTAINERS = True def __init__(self, **defaults): require_django() - self.defaults = defaults - super(FileField, self).__init__() + super(FileField, self).__init__(**defaults) def _make_data(self, params): """Create data for the field.""" return params.get('data', b'') - def _make_content(self, extraction_context): + def _make_content(self, params): path = '' - params = dict(self.defaults) - params.update(extraction_context.extra) if params.get('from_path') and params.get('from_file'): raise ValueError( @@ -157,12 +204,7 @@ class FileField(declarations.PostGenerationDeclaration): "be non-empty when calling factory.django.FileField." ) - if extraction_context.did_extract: - # Should be a django.core.files.File - content = extraction_context.value - path = content.name - - elif params.get('from_path'): + if params.get('from_path'): path = params['from_path'] f = open(path, 'rb') content = django_files.File(f, name=path) @@ -184,19 +226,13 @@ class FileField(declarations.PostGenerationDeclaration): filename = params.get('filename', default_filename) return filename, content - def call(self, obj, create, extraction_context): + def generate(self, sequence, obj, create, params): """Fill in the field.""" - if extraction_context.did_extract and extraction_context.value is None: - # User passed an empty value, don't fill - return - filename, content = self._make_content(extraction_context) - field_file = getattr(obj, extraction_context.for_field) - try: - field_file.save(filename, content, save=create) - finally: - content.file.close() - return field_file + params.setdefault('__sequence', sequence) + params = base.DictFactory.simple_generate(create, **params) + filename, content = self._make_content(params) + return django_files.File(content.file, filename) class ImageField(FileField): @@ -250,6 +286,9 @@ class mute_signals(object): logger.debug('mute_signals: Disabling signal handlers %r', signal.receivers) + # Note that we're using implementation details of + # django.signals, since arguments to signal.connect() + # are lost in signal.receivers self.paused[signal] = signal.receivers signal.receivers = [] @@ -259,8 +298,17 @@ class mute_signals(object): receivers) signal.receivers = receivers + if django.VERSION[:2] >= (1, 6): + with signal.lock: + # Django uses some caching for its signals. + # Since we're bypassing signal.connect and signal.disconnect, + # we have to keep messing with django's internals. + signal.sender_receivers_cache.clear() self.paused = {} + def copy(self): + return mute_signals(*self.signals) + def __call__(self, callable_obj): if isinstance(callable_obj, base.FactoryMetaClass): # Retrieve __func__, the *actual* callable object. @@ -269,7 +317,8 @@ class mute_signals(object): @classmethod @functools.wraps(generate_method) def wrapped_generate(*args, **kwargs): - with self: + # A mute_signals() object is not reentrant; use a copy everytime. + with self.copy(): return generate_method(*args, **kwargs) callable_obj._generate = wrapped_generate @@ -278,7 +327,8 @@ class mute_signals(object): else: @functools.wraps(callable_obj) def wrapper(*args, **kwargs): - with self: + # A mute_signals() object is not reentrant; use a copy everytime. + with self.copy(): return callable_obj(*args, **kwargs) return wrapper |