From ccd6a9927483b8789dee13c2d68a4407e2d37f94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Barrois?= Date: Mon, 25 Nov 2013 00:51:31 +0100 Subject: django: Fix lazy loading of 'son' factories (Closes #109). --- factory/base.py | 18 +++++++++++++----- factory/django.py | 19 +++++++------------ tests/djapp/models.py | 4 ++++ tests/test_django.py | 43 ++++++++++++++++++++++++++++++++++++------- 4 files changed, 60 insertions(+), 24 deletions(-) diff --git a/factory/base.py b/factory/base.py index 8183649..3c6571c 100644 --- a/factory/base.py +++ b/factory/base.py @@ -175,11 +175,13 @@ class FactoryMetaClass(type): is_abstract = attrs.pop('ABSTRACT_FACTORY', False) base = parent_factories[0] - inherited_associated_class = getattr(base, - CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None) + inherited_associated_class = base._get_target_class() associated_class = mcs._discover_associated_class(class_name, attrs, inherited_associated_class) + # Invoke 'lazy-loading' hooks. + associated_class = base._load_target_class(associated_class) + if associated_class is None: is_abstract = True @@ -379,13 +381,19 @@ class BaseFactory(object): return kwargs @classmethod - def _load_target_class(cls): + def _load_target_class(cls, class_definition): """Extension point for loading target classes. This can be overridden in framework-specific subclasses to hook into existing model repositories, for instance. """ - return getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS) + return class_definition + + @classmethod + def _get_target_class(cls): + """Retrieve the actual, associated target class.""" + definition = getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None) + return cls._load_target_class(definition) @classmethod def _prepare(cls, create, **kwargs): @@ -395,7 +403,7 @@ class BaseFactory(object): create: bool, whether to create or to build the object **kwargs: arguments to pass to the creation function """ - target_class = cls._load_target_class() + target_class = cls._get_target_class() kwargs = cls._adjust_kwargs(**kwargs) # Remove 'hidden' arguments. diff --git a/factory/django.py b/factory/django.py index 016586d..fee8e52 100644 --- a/factory/django.py +++ b/factory/django.py @@ -58,20 +58,15 @@ class DjangoModelFactory(base.Factory): ABSTRACT_FACTORY = True # Optional, but explicit. FACTORY_DJANGO_GET_OR_CREATE = () - _associated_model = None - @classmethod - def _load_target_class(cls): - associated_class = super(DjangoModelFactory, cls)._load_target_class() + def _load_target_class(cls, definition): - if is_string(associated_class) and '.' in associated_class: - app, model = associated_class.split('.', 1) - if cls._associated_model is None: - from django.db.models import loading as django_loading - cls._associated_model = django_loading.get_model(app, model) - return cls._associated_model + if is_string(definition) and '.' in definition: + app, model = definition.split('.', 1) + from django.db.models import loading as django_loading + return django_loading.get_model(app, model) - return associated_class + return definition @classmethod def _get_manager(cls, target_class): @@ -84,7 +79,7 @@ class DjangoModelFactory(base.Factory): def _setup_next_sequence(cls): """Compute the next available PK, based on the 'pk' database field.""" - model = cls._load_target_class() # pylint: disable=E1101 + model = cls._get_target_class() # pylint: disable=E1101 manager = cls._get_manager(model) try: diff --git a/tests/djapp/models.py b/tests/djapp/models.py index 3f25fbb..e98279d 100644 --- a/tests/djapp/models.py +++ b/tests/djapp/models.py @@ -55,6 +55,10 @@ class ConcreteSon(AbstractBase): pass +class StandardSon(StandardModel): + pass + + WITHFILE_UPLOAD_TO = 'django' WITHFILE_UPLOAD_DIR = os.path.join(settings.MEDIA_ROOT, WITHFILE_UPLOAD_TO) diff --git a/tests/test_django.py b/tests/test_django.py index 94101e9..e4bbc2b 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -64,6 +64,7 @@ else: # pragma: no cover models = Fake() models.StandardModel = Fake + models.StandardSon = None models.AbstractBase = Fake models.ConcreteSon = Fake models.NonIntegerPk = Fake @@ -211,7 +212,7 @@ class DjangoModelLoadingTestCase(django_test.TestCase): class ExampleFactory(factory.DjangoModelFactory): FACTORY_FOR = 'djapp.StandardModel' - self.assertEqual(models.StandardModel, ExampleFactory._load_target_class()) + self.assertEqual(models.StandardModel, ExampleFactory._get_target_class()) def test_building(self): class ExampleFactory(factory.DjangoModelFactory): @@ -220,16 +221,44 @@ class DjangoModelLoadingTestCase(django_test.TestCase): e = ExampleFactory.build() self.assertEqual(models.StandardModel, e.__class__) - def test_cache(self): + def test_inherited_loading(self): + """Proper loading of a model within 'child' factories. + + See https://github.com/rbarrois/factory_boy/issues/109. + """ + class ExampleFactory(factory.DjangoModelFactory): + FACTORY_FOR = 'djapp.StandardModel' + + class Example2Factory(ExampleFactory): + pass + + e = Example2Factory.build() + self.assertEqual(models.StandardModel, e.__class__) + + def test_inherited_loading_and_sequence(self): + """Proper loading of a model within 'child' factories. + + See https://github.com/rbarrois/factory_boy/issues/109. + """ class ExampleFactory(factory.DjangoModelFactory): FACTORY_FOR = 'djapp.StandardModel' - self.assertEqual('djapp.StandardModel', ExampleFactory._associated_class) - self.assertIsNone(ExampleFactory._associated_model) + foo = factory.Sequence(lambda n: n) + + class Example2Factory(ExampleFactory): + FACTORY_FOR = 'djapp.StandardSon' + + self.assertEqual(models.StandardSon, Example2Factory._get_target_class()) - self.assertEqual(models.StandardModel, ExampleFactory._load_target_class()) - self.assertEqual('djapp.StandardModel', ExampleFactory._associated_class) - self.assertEqual(models.StandardModel, ExampleFactory._associated_model) + e1 = ExampleFactory.build() + e2 = Example2Factory.build() + e3 = ExampleFactory.build() + self.assertEqual(models.StandardModel, e1.__class__) + self.assertEqual(models.StandardSon, e2.__class__) + self.assertEqual(models.StandardModel, e3.__class__) + self.assertEqual(1, e1.foo) + self.assertEqual(2, e2.foo) + self.assertEqual(3, e3.foo) @unittest.skipIf(django is None, "Django not installed.") -- cgit v1.2.3