summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2013-11-25 00:51:31 +0100
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2013-11-25 00:51:31 +0100
commitccd6a9927483b8789dee13c2d68a4407e2d37f94 (patch)
treef4824ea76bdf370f71be124dad2f61b32f666728
parent689b06c59788dedfce0af444760a3e5966761016 (diff)
downloadfactory-boy-ccd6a9927483b8789dee13c2d68a4407e2d37f94.tar
factory-boy-ccd6a9927483b8789dee13c2d68a4407e2d37f94.tar.gz
django: Fix lazy loading of 'son' factories (Closes #109).
-rw-r--r--factory/base.py18
-rw-r--r--factory/django.py19
-rw-r--r--tests/djapp/models.py4
-rw-r--r--tests/test_django.py43
4 files changed, 60 insertions, 24 deletions
diff --git a/factory/base.py b/factory/base.py
index 8183649..3c6571c 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -175,11 +175,13 @@ class FactoryMetaClass(type):
is_abstract = attrs.pop('ABSTRACT_FACTORY', False)
base = parent_factories[0]
- inherited_associated_class = getattr(base,
- CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None)
+ inherited_associated_class = base._get_target_class()
associated_class = mcs._discover_associated_class(class_name, attrs,
inherited_associated_class)
+ # Invoke 'lazy-loading' hooks.
+ associated_class = base._load_target_class(associated_class)
+
if associated_class is None:
is_abstract = True
@@ -379,13 +381,19 @@ class BaseFactory(object):
return kwargs
@classmethod
- def _load_target_class(cls):
+ def _load_target_class(cls, class_definition):
"""Extension point for loading target classes.
This can be overridden in framework-specific subclasses to hook into
existing model repositories, for instance.
"""
- return getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS)
+ return class_definition
+
+ @classmethod
+ def _get_target_class(cls):
+ """Retrieve the actual, associated target class."""
+ definition = getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None)
+ return cls._load_target_class(definition)
@classmethod
def _prepare(cls, create, **kwargs):
@@ -395,7 +403,7 @@ class BaseFactory(object):
create: bool, whether to create or to build the object
**kwargs: arguments to pass to the creation function
"""
- target_class = cls._load_target_class()
+ target_class = cls._get_target_class()
kwargs = cls._adjust_kwargs(**kwargs)
# Remove 'hidden' arguments.
diff --git a/factory/django.py b/factory/django.py
index 016586d..fee8e52 100644
--- a/factory/django.py
+++ b/factory/django.py
@@ -58,20 +58,15 @@ class DjangoModelFactory(base.Factory):
ABSTRACT_FACTORY = True # Optional, but explicit.
FACTORY_DJANGO_GET_OR_CREATE = ()
- _associated_model = None
-
@classmethod
- def _load_target_class(cls):
- associated_class = super(DjangoModelFactory, cls)._load_target_class()
+ def _load_target_class(cls, definition):
- if is_string(associated_class) and '.' in associated_class:
- app, model = associated_class.split('.', 1)
- if cls._associated_model is None:
- from django.db.models import loading as django_loading
- cls._associated_model = django_loading.get_model(app, model)
- return cls._associated_model
+ if is_string(definition) and '.' in definition:
+ app, model = definition.split('.', 1)
+ from django.db.models import loading as django_loading
+ return django_loading.get_model(app, model)
- return associated_class
+ return definition
@classmethod
def _get_manager(cls, target_class):
@@ -84,7 +79,7 @@ class DjangoModelFactory(base.Factory):
def _setup_next_sequence(cls):
"""Compute the next available PK, based on the 'pk' database field."""
- model = cls._load_target_class() # pylint: disable=E1101
+ model = cls._get_target_class() # pylint: disable=E1101
manager = cls._get_manager(model)
try:
diff --git a/tests/djapp/models.py b/tests/djapp/models.py
index 3f25fbb..e98279d 100644
--- a/tests/djapp/models.py
+++ b/tests/djapp/models.py
@@ -55,6 +55,10 @@ class ConcreteSon(AbstractBase):
pass
+class StandardSon(StandardModel):
+ pass
+
+
WITHFILE_UPLOAD_TO = 'django'
WITHFILE_UPLOAD_DIR = os.path.join(settings.MEDIA_ROOT, WITHFILE_UPLOAD_TO)
diff --git a/tests/test_django.py b/tests/test_django.py
index 94101e9..e4bbc2b 100644
--- a/tests/test_django.py
+++ b/tests/test_django.py
@@ -64,6 +64,7 @@ else: # pragma: no cover
models = Fake()
models.StandardModel = Fake
+ models.StandardSon = None
models.AbstractBase = Fake
models.ConcreteSon = Fake
models.NonIntegerPk = Fake
@@ -211,7 +212,7 @@ class DjangoModelLoadingTestCase(django_test.TestCase):
class ExampleFactory(factory.DjangoModelFactory):
FACTORY_FOR = 'djapp.StandardModel'
- self.assertEqual(models.StandardModel, ExampleFactory._load_target_class())
+ self.assertEqual(models.StandardModel, ExampleFactory._get_target_class())
def test_building(self):
class ExampleFactory(factory.DjangoModelFactory):
@@ -220,16 +221,44 @@ class DjangoModelLoadingTestCase(django_test.TestCase):
e = ExampleFactory.build()
self.assertEqual(models.StandardModel, e.__class__)
- def test_cache(self):
+ def test_inherited_loading(self):
+ """Proper loading of a model within 'child' factories.
+
+ See https://github.com/rbarrois/factory_boy/issues/109.
+ """
+ class ExampleFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = 'djapp.StandardModel'
+
+ class Example2Factory(ExampleFactory):
+ pass
+
+ e = Example2Factory.build()
+ self.assertEqual(models.StandardModel, e.__class__)
+
+ def test_inherited_loading_and_sequence(self):
+ """Proper loading of a model within 'child' factories.
+
+ See https://github.com/rbarrois/factory_boy/issues/109.
+ """
class ExampleFactory(factory.DjangoModelFactory):
FACTORY_FOR = 'djapp.StandardModel'
- self.assertEqual('djapp.StandardModel', ExampleFactory._associated_class)
- self.assertIsNone(ExampleFactory._associated_model)
+ foo = factory.Sequence(lambda n: n)
+
+ class Example2Factory(ExampleFactory):
+ FACTORY_FOR = 'djapp.StandardSon'
+
+ self.assertEqual(models.StandardSon, Example2Factory._get_target_class())
- self.assertEqual(models.StandardModel, ExampleFactory._load_target_class())
- self.assertEqual('djapp.StandardModel', ExampleFactory._associated_class)
- self.assertEqual(models.StandardModel, ExampleFactory._associated_model)
+ e1 = ExampleFactory.build()
+ e2 = Example2Factory.build()
+ e3 = ExampleFactory.build()
+ self.assertEqual(models.StandardModel, e1.__class__)
+ self.assertEqual(models.StandardSon, e2.__class__)
+ self.assertEqual(models.StandardModel, e3.__class__)
+ self.assertEqual(1, e1.foo)
+ self.assertEqual(2, e2.foo)
+ self.assertEqual(3, e3.foo)
@unittest.skipIf(django is None, "Django not installed.")