summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/changelog.rst4
-rw-r--r--docs/orms.rst33
-rw-r--r--factory/base.py15
-rw-r--r--tests/test_using.py133
4 files changed, 180 insertions, 5 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 9a4a64e..4b24797 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -13,8 +13,8 @@ ChangeLog
- The default :attr:`~factory.Sequence.type` for :class:`~factory.Sequence` is now :obj:`int`
- Fields listed in :attr:`~factory.Factory.FACTORY_HIDDEN_ARGS` won't be passed to
the associated class' constructor
-
- - Add support for ``get_or_create`` in :class:`~factory.DjangoModelFactory`
+ - Add support for ``get_or_create`` in :class:`~factory.DjangoModelFactory`,
+ through :attr:`~factory.DjangoModelFactory.FACTORY_DJANGO_GET_OR_CREATE`.
*Removed:*
diff --git a/docs/orms.rst b/docs/orms.rst
index d6ff3c3..8e5b6f6 100644
--- a/docs/orms.rst
+++ b/docs/orms.rst
@@ -35,3 +35,36 @@ All factories for a Django :class:`~django.db.models.Model` should use the
* When using :class:`~factory.RelatedFactory` or :class:`~factory.PostGeneration`
attributes, the base object will be :meth:`saved <django.db.models.Model.save>`
once all post-generation hooks have run.
+
+ .. attribute:: FACTORY_DJANGO_GET_OR_CREATE
+
+ Fields whose name are passed in this list will be used to perform a
+ :meth:`Model.objects.get_or_create() <django.db.models.query.QuerySet.get_or_create>`
+ instead of the usual :meth:`Model.objects.create() <django.db.models.query.QuerySet.create>`:
+
+ .. code-block:: python
+
+ class UserFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = models.User
+ FACTORY_DJANGO_GET_OR_CREATE = ('username',)
+
+ username = 'john'
+
+ .. code-block:: pycon
+
+ >>> User.objects.all()
+ []
+ >>> UserFactory() # Creates a new user
+ <User: john>
+ >>> User.objects.all()
+ [<User: john>]
+
+ >>> UserFactory() # Fetches the existing user
+ <User: john>
+ >>> User.objects.all() # No new user!
+ [<User: john>]
+
+ >>> UserFactory(username='jack') # Creates another user
+ <User: jack>
+ >>> User.objects.all()
+ [<User: john>, <User: jack>]
diff --git a/factory/base.py b/factory/base.py
index 58cd50b..71c2eb1 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -554,6 +554,7 @@ class DjangoModelFactory(Factory):
"""
ABSTRACT_FACTORY = True
+ FACTORY_DJANGO_GET_OR_CREATE = ()
@classmethod
def _get_manager(cls, target_class):
@@ -578,7 +579,19 @@ class DjangoModelFactory(Factory):
def _create(cls, target_class, *args, **kwargs):
"""Create an instance of the model, and save it to the database."""
manager = cls._get_manager(target_class)
- return manager.create(*args, **kwargs)
+
+ assert 'defaults' not in cls.FACTORY_DJANGO_GET_OR_CREATE, (
+ "'defaults' is a reserved keyword for get_or_create "
+ "(in %s.FACTORY_DJANGO_GET_OR_CREATE=%r)"
+ % (cls, cls.FACTORY_DJANGO_GET_OR_CREATE))
+
+ key_fields = {}
+ for field in cls.FACTORY_DJANGO_GET_OR_CREATE:
+ key_fields[field] = kwargs.pop(field)
+ key_fields['defaults'] = kwargs
+
+ obj, created = manager.get_or_create(*args, **key_fields)
+ return obj
@classmethod
def _after_postgeneration(cls, obj, create, results=None):
diff --git a/tests/test_using.py b/tests/test_using.py
index 41b666f..65cb7a5 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -49,10 +49,11 @@ class FakeModel(object):
return instance
class FakeModelManager(object):
- def create(self, **kwargs):
+ def get_or_create(self, **kwargs):
+ kwargs.update(kwargs.pop('defaults', {}))
instance = FakeModel.create(**kwargs)
instance.id = 2
- return instance
+ return instance, True
def values_list(self, *args, **kwargs):
return self
@@ -1237,6 +1238,134 @@ class IteratorTestCase(unittest.TestCase):
self.assertEqual(i + 10, obj.one)
+class BetterFakeModelManager(object):
+ def __init__(self, keys, instance):
+ self.keys = keys
+ self.instance = instance
+
+ def get_or_create(self, **kwargs):
+ defaults = kwargs.pop('defaults', {})
+ if kwargs == self.keys:
+ return self.instance, False
+ kwargs.update(defaults)
+ instance = FakeModel.create(**kwargs)
+ instance.id = 2
+ return instance, True
+
+ def values_list(self, *args, **kwargs):
+ return self
+
+ def order_by(self, *args, **kwargs):
+ return [1]
+
+
+class BetterFakeModel(object):
+ @classmethod
+ def create(cls, **kwargs):
+ instance = cls(**kwargs)
+ instance.id = 1
+ return instance
+
+ def __init__(self, **kwargs):
+ for name, value in kwargs.items():
+ setattr(self, name, value)
+ self.id = None
+
+
+class DjangoModelFactoryTestCase(unittest.TestCase):
+ def test_simple(self):
+ class FakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = FakeModel
+
+ obj = FakeModelFactory(one=1)
+ self.assertEqual(1, obj.one)
+ self.assertEqual(2, obj.id)
+
+ def test_existing_instance(self):
+ prev = BetterFakeModel.create(x=1, y=2, z=3)
+ prev.id = 42
+
+ class MyFakeModel(BetterFakeModel):
+ objects = BetterFakeModelManager({'x': 1}, prev)
+
+ class MyFakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = MyFakeModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('x',)
+ x = 1
+ y = 4
+ z = 6
+
+ obj = MyFakeModelFactory()
+ self.assertEqual(prev, obj)
+ self.assertEqual(1, obj.x)
+ self.assertEqual(2, obj.y)
+ self.assertEqual(3, obj.z)
+ self.assertEqual(42, obj.id)
+
+ def test_existing_instance_complex_key(self):
+ prev = BetterFakeModel.create(x=1, y=2, z=3)
+ prev.id = 42
+
+ class MyFakeModel(BetterFakeModel):
+ objects = BetterFakeModelManager({'x': 1, 'y': 2, 'z': 3}, prev)
+
+ class MyFakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = MyFakeModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('x', 'y', 'z')
+ x = 1
+ y = 4
+ z = 6
+
+ obj = MyFakeModelFactory(y=2, z=3)
+ self.assertEqual(prev, obj)
+ self.assertEqual(1, obj.x)
+ self.assertEqual(2, obj.y)
+ self.assertEqual(3, obj.z)
+ self.assertEqual(42, obj.id)
+
+ def test_new_instance(self):
+ prev = BetterFakeModel.create(x=1, y=2, z=3)
+ prev.id = 42
+
+ class MyFakeModel(BetterFakeModel):
+ objects = BetterFakeModelManager({'x': 1}, prev)
+
+ class MyFakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = MyFakeModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('x',)
+ x = 1
+ y = 4
+ z = 6
+
+ obj = MyFakeModelFactory(x=2)
+ self.assertNotEqual(prev, obj)
+ self.assertEqual(2, obj.x)
+ self.assertEqual(4, obj.y)
+ self.assertEqual(6, obj.z)
+ self.assertEqual(2, obj.id)
+
+ def test_new_instance_complex_key(self):
+ prev = BetterFakeModel.create(x=1, y=2, z=3)
+ prev.id = 42
+
+ class MyFakeModel(BetterFakeModel):
+ objects = BetterFakeModelManager({'x': 1, 'y': 2, 'z': 3}, prev)
+
+ class MyFakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = MyFakeModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('x', 'y', 'z')
+ x = 1
+ y = 4
+ z = 6
+
+ obj = MyFakeModelFactory(y=2, z=4)
+ self.assertNotEqual(prev, obj)
+ self.assertEqual(1, obj.x)
+ self.assertEqual(2, obj.y)
+ self.assertEqual(4, obj.z)
+ self.assertEqual(2, obj.id)
+
+
class PostGenerationTestCase(unittest.TestCase):
def test_post_generation(self):
class TestObjectFactory(factory.Factory):