diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-08-28 01:30:15 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-08-28 01:30:15 +0200 |
commit | 297a111cc918c6451f1b66e3fe3572a9f3fc6b8f (patch) | |
tree | 1b7d68414a71b8e072b285d1a142da2ad3fc75af | |
parent | 7fc3e4cbdae050dcde49ea3101636ddf57d6c96d (diff) | |
download | factory-boy-297a111cc918c6451f1b66e3fe3572a9f3fc6b8f.tar factory-boy-297a111cc918c6451f1b66e3fe3572a9f3fc6b8f.tar.gz |
Allow FACTORY_FOR = 'app.Model' for Django (Closes #66).
-rw-r--r-- | docs/orms.rst | 4 | ||||
-rw-r--r-- | factory/base.py | 11 | ||||
-rw-r--r-- | factory/django.py | 28 | ||||
-rw-r--r-- | tests/test_django.py | 29 |
4 files changed, 66 insertions, 6 deletions
diff --git a/docs/orms.rst b/docs/orms.rst index 611a9ae..8215fe6 100644 --- a/docs/orms.rst +++ b/docs/orms.rst @@ -32,6 +32,8 @@ All factories for a Django :class:`~django.db.models.Model` should use the This class provides the following features: + * The :attr:`~factory.Factory.FACTORY_FOR` attribute also supports the ``'app.Model'`` + syntax * :func:`~factory.Factory.create()` uses :meth:`Model.objects.create() <django.db.models.query.QuerySet.create>` * :func:`~factory.Factory._setup_next_sequence()` selects the next unused primary key value * When using :class:`~factory.RelatedFactory` or :class:`~factory.PostGeneration` @@ -47,7 +49,7 @@ All factories for a Django :class:`~django.db.models.Model` should use the .. code-block:: python class UserFactory(factory.django.DjangoModelFactory): - FACTORY_FOR = models.User + FACTORY_FOR = 'myapp.User' # Equivalent to ``FACTORY_FOR = myapp.models.User`` FACTORY_DJANGO_GET_OR_CREATE = ('username',) username = 'john' diff --git a/factory/base.py b/factory/base.py index ac906de..1b9fa0d 100644 --- a/factory/base.py +++ b/factory/base.py @@ -331,6 +331,15 @@ class BaseFactory(object): return kwargs @classmethod + def _load_target_class(cls): + """Extension point for loading target classes. + + This can be overridden in framework-specific subclasses to hook into + existing model repositories, for instance. + """ + return getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS) + + @classmethod def _prepare(cls, create, **kwargs): """Prepare an object for this factory. @@ -338,7 +347,7 @@ class BaseFactory(object): create: bool, whether to create or to build the object **kwargs: arguments to pass to the creation function """ - target_class = getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS) + target_class = cls._load_target_class() kwargs = cls._adjust_kwargs(**kwargs) # Remove 'hidden' arguments. diff --git a/factory/django.py b/factory/django.py index e3e8829..016586d 100644 --- a/factory/django.py +++ b/factory/django.py @@ -37,7 +37,13 @@ except ImportError as e: # pragma: no cover from . import base from . import declarations -from .compat import BytesIO +from .compat import BytesIO, is_string + + +def require_django(): + """Simple helper to ensure Django is available.""" + if django_files is None: # pragma: no cover + raise import_failure class DjangoModelFactory(base.Factory): @@ -52,6 +58,21 @@ class DjangoModelFactory(base.Factory): ABSTRACT_FACTORY = True # Optional, but explicit. FACTORY_DJANGO_GET_OR_CREATE = () + _associated_model = None + + @classmethod + def _load_target_class(cls): + associated_class = super(DjangoModelFactory, cls)._load_target_class() + + if is_string(associated_class) and '.' in associated_class: + app, model = associated_class.split('.', 1) + if cls._associated_model is None: + from django.db.models import loading as django_loading + cls._associated_model = django_loading.get_model(app, model) + return cls._associated_model + + return associated_class + @classmethod def _get_manager(cls, target_class): try: @@ -63,7 +84,7 @@ class DjangoModelFactory(base.Factory): def _setup_next_sequence(cls): """Compute the next available PK, based on the 'pk' database field.""" - model = cls._associated_class # pylint: disable=E1101 + model = cls._load_target_class() # pylint: disable=E1101 manager = cls._get_manager(model) try: @@ -116,8 +137,7 @@ class FileField(declarations.PostGenerationDeclaration): DEFAULT_FILENAME = 'example.dat' def __init__(self, **defaults): - if django_files is None: # pragma: no cover - raise import_failure + require_django() self.defaults = defaults super(FileField, self).__init__() diff --git a/tests/test_django.py b/tests/test_django.py index 9d02131..b27562c 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -158,6 +158,35 @@ class DjangoPkSequenceTestCase(django_test.TestCase): @unittest.skipIf(django is None, "Django not installed.") +class DjangoModelLoadingTestCase(django_test.TestCase): + """Tests FACTORY_FOR = 'app.Model' pattern.""" + + def test_loading(self): + class ExampleFactory(factory.DjangoModelFactory): + FACTORY_FOR = 'djapp.StandardModel' + + self.assertEqual(models.StandardModel, ExampleFactory._load_target_class()) + + def test_building(self): + class ExampleFactory(factory.DjangoModelFactory): + FACTORY_FOR = 'djapp.StandardModel' + + e = ExampleFactory.build() + self.assertEqual(models.StandardModel, e.__class__) + + def test_cache(self): + class ExampleFactory(factory.DjangoModelFactory): + FACTORY_FOR = 'djapp.StandardModel' + + self.assertEqual('djapp.StandardModel', ExampleFactory._associated_class) + self.assertIsNone(ExampleFactory._associated_model) + + self.assertEqual(models.StandardModel, ExampleFactory._load_target_class()) + self.assertEqual('djapp.StandardModel', ExampleFactory._associated_class) + self.assertEqual(models.StandardModel, ExampleFactory._associated_model) + + +@unittest.skipIf(django is None, "Django not installed.") class DjangoNonIntegerPkTestCase(django_test.TestCase): def setUp(self): super(DjangoNonIntegerPkTestCase, self).setUp() |