diff options
-rw-r--r-- | docs/changelog.rst | 3 | ||||
-rw-r--r-- | factory/django.py | 13 | ||||
-rw-r--r-- | tests/djapp/models.py | 17 | ||||
-rw-r--r-- | tests/test_django.py | 12 |
4 files changed, 38 insertions, 7 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst index a7ff050..cc4a1dc 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,9 @@ ChangeLog 2.5.1 (master) -------------- +*Bugfix:* + + - Respect custom managers in :class:`~factory.django.DjangoModelFactory` (see :issue:`192`) .. _v2.5.0: diff --git a/factory/django.py b/factory/django.py index eb07bfb..ba81f13 100644 --- a/factory/django.py +++ b/factory/django.py @@ -45,6 +45,8 @@ from .compat import BytesIO, is_string logger = logging.getLogger('factory.generate') +DEFAULT_DB_ALIAS = 'default' # Same as django.db.DEFAULT_DB_ALIAS + def require_django(): """Simple helper to ensure Django is available.""" @@ -56,7 +58,7 @@ class DjangoOptions(base.FactoryOptions): def _build_default_options(self): return super(DjangoOptions, self)._build_default_options() + [ base.OptionDefault('django_get_or_create', (), inherit=True), - base.OptionDefault('database', 'default', inherit=True), + base.OptionDefault('database', DEFAULT_DB_ALIAS, inherit=True), ] def _get_counter_reference(self): @@ -100,12 +102,9 @@ class DjangoModelFactory(base.Factory): if model_class is None: raise base.AssociatedClassError("No model set on %s.%s.Meta" % (cls.__module__, cls.__name__)) - try: - manager = model_class._default_manager # pylint: disable=W0212 - except AttributeError: - manager = model_class.objects - - manager = manager.using(cls._meta.database) + manager = model_class.objects + if cls._meta.database != DEFAULT_DB_ALIAS: + manager = manager.using(cls._meta.database) return manager @classmethod diff --git a/tests/djapp/models.py b/tests/djapp/models.py index 35c765f..96ee5cf 100644 --- a/tests/djapp/models.py +++ b/tests/djapp/models.py @@ -87,3 +87,20 @@ else: class WithSignals(models.Model): foo = models.CharField(max_length=20) + + +class CustomQuerySet(models.QuerySet): + pass + + +class CustomManager(models.Manager): + + def create(self, arg=None, **kwargs): + return super(CustomManager, self).create(**kwargs) + + +class WithCustomManager(models.Model): + + foo = models.CharField(max_length=20) + + objects = CustomManager.from_queryset(CustomQuerySet)() diff --git a/tests/test_django.py b/tests/test_django.py index 2744032..9ac8f5c 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -157,6 +157,13 @@ if django is not None: model = models.WithSignals + class WithCustomManagerFactory(factory.django.DjangoModelFactory): + class Meta: + model = models.WithCustomManager + + foo = factory.Sequence(lambda n: "foo%d" % n) + + @unittest.skipIf(django is None, "Django not installed.") class ModelTests(django_test.TestCase): def test_unset_model(self): @@ -706,5 +713,10 @@ class PreventSignalsTestCase(unittest.TestCase): self.assertSignalsReactivated() +class DjangoCustomManagerTestCase(django_test.TestCase): + + def test_extra_args(self): + model = WithCustomManagerFactory(arg='foo') + if __name__ == '__main__': # pragma: no cover unittest.main() |