summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2013-08-28 01:30:15 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2013-08-28 01:30:15 +0200
commit297a111cc918c6451f1b66e3fe3572a9f3fc6b8f (patch)
tree1b7d68414a71b8e072b285d1a142da2ad3fc75af
parent7fc3e4cbdae050dcde49ea3101636ddf57d6c96d (diff)
downloadfactory-boy-297a111cc918c6451f1b66e3fe3572a9f3fc6b8f.tar
factory-boy-297a111cc918c6451f1b66e3fe3572a9f3fc6b8f.tar.gz
Allow FACTORY_FOR = 'app.Model' for Django (Closes #66).
-rw-r--r--docs/orms.rst4
-rw-r--r--factory/base.py11
-rw-r--r--factory/django.py28
-rw-r--r--tests/test_django.py29
4 files changed, 66 insertions, 6 deletions
diff --git a/docs/orms.rst b/docs/orms.rst
index 611a9ae..8215fe6 100644
--- a/docs/orms.rst
+++ b/docs/orms.rst
@@ -32,6 +32,8 @@ All factories for a Django :class:`~django.db.models.Model` should use the
This class provides the following features:
+ * The :attr:`~factory.Factory.FACTORY_FOR` attribute also supports the ``'app.Model'``
+ syntax
* :func:`~factory.Factory.create()` uses :meth:`Model.objects.create() <django.db.models.query.QuerySet.create>`
* :func:`~factory.Factory._setup_next_sequence()` selects the next unused primary key value
* When using :class:`~factory.RelatedFactory` or :class:`~factory.PostGeneration`
@@ -47,7 +49,7 @@ All factories for a Django :class:`~django.db.models.Model` should use the
.. code-block:: python
class UserFactory(factory.django.DjangoModelFactory):
- FACTORY_FOR = models.User
+ FACTORY_FOR = 'myapp.User' # Equivalent to ``FACTORY_FOR = myapp.models.User``
FACTORY_DJANGO_GET_OR_CREATE = ('username',)
username = 'john'
diff --git a/factory/base.py b/factory/base.py
index ac906de..1b9fa0d 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -331,6 +331,15 @@ class BaseFactory(object):
return kwargs
@classmethod
+ def _load_target_class(cls):
+ """Extension point for loading target classes.
+
+ This can be overridden in framework-specific subclasses to hook into
+ existing model repositories, for instance.
+ """
+ return getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS)
+
+ @classmethod
def _prepare(cls, create, **kwargs):
"""Prepare an object for this factory.
@@ -338,7 +347,7 @@ class BaseFactory(object):
create: bool, whether to create or to build the object
**kwargs: arguments to pass to the creation function
"""
- target_class = getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS)
+ target_class = cls._load_target_class()
kwargs = cls._adjust_kwargs(**kwargs)
# Remove 'hidden' arguments.
diff --git a/factory/django.py b/factory/django.py
index e3e8829..016586d 100644
--- a/factory/django.py
+++ b/factory/django.py
@@ -37,7 +37,13 @@ except ImportError as e: # pragma: no cover
from . import base
from . import declarations
-from .compat import BytesIO
+from .compat import BytesIO, is_string
+
+
+def require_django():
+ """Simple helper to ensure Django is available."""
+ if django_files is None: # pragma: no cover
+ raise import_failure
class DjangoModelFactory(base.Factory):
@@ -52,6 +58,21 @@ class DjangoModelFactory(base.Factory):
ABSTRACT_FACTORY = True # Optional, but explicit.
FACTORY_DJANGO_GET_OR_CREATE = ()
+ _associated_model = None
+
+ @classmethod
+ def _load_target_class(cls):
+ associated_class = super(DjangoModelFactory, cls)._load_target_class()
+
+ if is_string(associated_class) and '.' in associated_class:
+ app, model = associated_class.split('.', 1)
+ if cls._associated_model is None:
+ from django.db.models import loading as django_loading
+ cls._associated_model = django_loading.get_model(app, model)
+ return cls._associated_model
+
+ return associated_class
+
@classmethod
def _get_manager(cls, target_class):
try:
@@ -63,7 +84,7 @@ class DjangoModelFactory(base.Factory):
def _setup_next_sequence(cls):
"""Compute the next available PK, based on the 'pk' database field."""
- model = cls._associated_class # pylint: disable=E1101
+ model = cls._load_target_class() # pylint: disable=E1101
manager = cls._get_manager(model)
try:
@@ -116,8 +137,7 @@ class FileField(declarations.PostGenerationDeclaration):
DEFAULT_FILENAME = 'example.dat'
def __init__(self, **defaults):
- if django_files is None: # pragma: no cover
- raise import_failure
+ require_django()
self.defaults = defaults
super(FileField, self).__init__()
diff --git a/tests/test_django.py b/tests/test_django.py
index 9d02131..b27562c 100644
--- a/tests/test_django.py
+++ b/tests/test_django.py
@@ -158,6 +158,35 @@ class DjangoPkSequenceTestCase(django_test.TestCase):
@unittest.skipIf(django is None, "Django not installed.")
+class DjangoModelLoadingTestCase(django_test.TestCase):
+ """Tests FACTORY_FOR = 'app.Model' pattern."""
+
+ def test_loading(self):
+ class ExampleFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = 'djapp.StandardModel'
+
+ self.assertEqual(models.StandardModel, ExampleFactory._load_target_class())
+
+ def test_building(self):
+ class ExampleFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = 'djapp.StandardModel'
+
+ e = ExampleFactory.build()
+ self.assertEqual(models.StandardModel, e.__class__)
+
+ def test_cache(self):
+ class ExampleFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = 'djapp.StandardModel'
+
+ self.assertEqual('djapp.StandardModel', ExampleFactory._associated_class)
+ self.assertIsNone(ExampleFactory._associated_model)
+
+ self.assertEqual(models.StandardModel, ExampleFactory._load_target_class())
+ self.assertEqual('djapp.StandardModel', ExampleFactory._associated_class)
+ self.assertEqual(models.StandardModel, ExampleFactory._associated_model)
+
+
+@unittest.skipIf(django is None, "Django not installed.")
class DjangoNonIntegerPkTestCase(django_test.TestCase):
def setUp(self):
super(DjangoNonIntegerPkTestCase, self).setUp()