From 6532f25058a13e81b1365bb353848510821f571f Mon Sep 17 00:00:00 2001 From: Raphaƫl Barrois Date: Tue, 2 Apr 2013 23:34:28 +0200 Subject: Add support for get_or_create in DjangoModelFactory. --- docs/changelog.rst | 4 +- docs/orms.rst | 33 +++++++++++++ factory/base.py | 15 +++++- tests/test_using.py | 133 +++++++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 180 insertions(+), 5 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9a4a64e..4b24797 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -13,8 +13,8 @@ ChangeLog - The default :attr:`~factory.Sequence.type` for :class:`~factory.Sequence` is now :obj:`int` - Fields listed in :attr:`~factory.Factory.FACTORY_HIDDEN_ARGS` won't be passed to the associated class' constructor - - - Add support for ``get_or_create`` in :class:`~factory.DjangoModelFactory` + - Add support for ``get_or_create`` in :class:`~factory.DjangoModelFactory`, + through :attr:`~factory.DjangoModelFactory.FACTORY_DJANGO_GET_OR_CREATE`. *Removed:* diff --git a/docs/orms.rst b/docs/orms.rst index d6ff3c3..8e5b6f6 100644 --- a/docs/orms.rst +++ b/docs/orms.rst @@ -35,3 +35,36 @@ All factories for a Django :class:`~django.db.models.Model` should use the * When using :class:`~factory.RelatedFactory` or :class:`~factory.PostGeneration` attributes, the base object will be :meth:`saved ` once all post-generation hooks have run. + + .. attribute:: FACTORY_DJANGO_GET_OR_CREATE + + Fields whose name are passed in this list will be used to perform a + :meth:`Model.objects.get_or_create() ` + instead of the usual :meth:`Model.objects.create() `: + + .. code-block:: python + + class UserFactory(factory.DjangoModelFactory): + FACTORY_FOR = models.User + FACTORY_DJANGO_GET_OR_CREATE = ('username',) + + username = 'john' + + .. code-block:: pycon + + >>> User.objects.all() + [] + >>> UserFactory() # Creates a new user + + >>> User.objects.all() + [] + + >>> UserFactory() # Fetches the existing user + + >>> User.objects.all() # No new user! + [] + + >>> UserFactory(username='jack') # Creates another user + + >>> User.objects.all() + [, ] diff --git a/factory/base.py b/factory/base.py index 58cd50b..71c2eb1 100644 --- a/factory/base.py +++ b/factory/base.py @@ -554,6 +554,7 @@ class DjangoModelFactory(Factory): """ ABSTRACT_FACTORY = True + FACTORY_DJANGO_GET_OR_CREATE = () @classmethod def _get_manager(cls, target_class): @@ -578,7 +579,19 @@ class DjangoModelFactory(Factory): def _create(cls, target_class, *args, **kwargs): """Create an instance of the model, and save it to the database.""" manager = cls._get_manager(target_class) - return manager.create(*args, **kwargs) + + assert 'defaults' not in cls.FACTORY_DJANGO_GET_OR_CREATE, ( + "'defaults' is a reserved keyword for get_or_create " + "(in %s.FACTORY_DJANGO_GET_OR_CREATE=%r)" + % (cls, cls.FACTORY_DJANGO_GET_OR_CREATE)) + + key_fields = {} + for field in cls.FACTORY_DJANGO_GET_OR_CREATE: + key_fields[field] = kwargs.pop(field) + key_fields['defaults'] = kwargs + + obj, created = manager.get_or_create(*args, **key_fields) + return obj @classmethod def _after_postgeneration(cls, obj, create, results=None): diff --git a/tests/test_using.py b/tests/test_using.py index 41b666f..65cb7a5 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -49,10 +49,11 @@ class FakeModel(object): return instance class FakeModelManager(object): - def create(self, **kwargs): + def get_or_create(self, **kwargs): + kwargs.update(kwargs.pop('defaults', {})) instance = FakeModel.create(**kwargs) instance.id = 2 - return instance + return instance, True def values_list(self, *args, **kwargs): return self @@ -1237,6 +1238,134 @@ class IteratorTestCase(unittest.TestCase): self.assertEqual(i + 10, obj.one) +class BetterFakeModelManager(object): + def __init__(self, keys, instance): + self.keys = keys + self.instance = instance + + def get_or_create(self, **kwargs): + defaults = kwargs.pop('defaults', {}) + if kwargs == self.keys: + return self.instance, False + kwargs.update(defaults) + instance = FakeModel.create(**kwargs) + instance.id = 2 + return instance, True + + def values_list(self, *args, **kwargs): + return self + + def order_by(self, *args, **kwargs): + return [1] + + +class BetterFakeModel(object): + @classmethod + def create(cls, **kwargs): + instance = cls(**kwargs) + instance.id = 1 + return instance + + def __init__(self, **kwargs): + for name, value in kwargs.items(): + setattr(self, name, value) + self.id = None + + +class DjangoModelFactoryTestCase(unittest.TestCase): + def test_simple(self): + class FakeModelFactory(factory.DjangoModelFactory): + FACTORY_FOR = FakeModel + + obj = FakeModelFactory(one=1) + self.assertEqual(1, obj.one) + self.assertEqual(2, obj.id) + + def test_existing_instance(self): + prev = BetterFakeModel.create(x=1, y=2, z=3) + prev.id = 42 + + class MyFakeModel(BetterFakeModel): + objects = BetterFakeModelManager({'x': 1}, prev) + + class MyFakeModelFactory(factory.DjangoModelFactory): + FACTORY_FOR = MyFakeModel + FACTORY_DJANGO_GET_OR_CREATE = ('x',) + x = 1 + y = 4 + z = 6 + + obj = MyFakeModelFactory() + self.assertEqual(prev, obj) + self.assertEqual(1, obj.x) + self.assertEqual(2, obj.y) + self.assertEqual(3, obj.z) + self.assertEqual(42, obj.id) + + def test_existing_instance_complex_key(self): + prev = BetterFakeModel.create(x=1, y=2, z=3) + prev.id = 42 + + class MyFakeModel(BetterFakeModel): + objects = BetterFakeModelManager({'x': 1, 'y': 2, 'z': 3}, prev) + + class MyFakeModelFactory(factory.DjangoModelFactory): + FACTORY_FOR = MyFakeModel + FACTORY_DJANGO_GET_OR_CREATE = ('x', 'y', 'z') + x = 1 + y = 4 + z = 6 + + obj = MyFakeModelFactory(y=2, z=3) + self.assertEqual(prev, obj) + self.assertEqual(1, obj.x) + self.assertEqual(2, obj.y) + self.assertEqual(3, obj.z) + self.assertEqual(42, obj.id) + + def test_new_instance(self): + prev = BetterFakeModel.create(x=1, y=2, z=3) + prev.id = 42 + + class MyFakeModel(BetterFakeModel): + objects = BetterFakeModelManager({'x': 1}, prev) + + class MyFakeModelFactory(factory.DjangoModelFactory): + FACTORY_FOR = MyFakeModel + FACTORY_DJANGO_GET_OR_CREATE = ('x',) + x = 1 + y = 4 + z = 6 + + obj = MyFakeModelFactory(x=2) + self.assertNotEqual(prev, obj) + self.assertEqual(2, obj.x) + self.assertEqual(4, obj.y) + self.assertEqual(6, obj.z) + self.assertEqual(2, obj.id) + + def test_new_instance_complex_key(self): + prev = BetterFakeModel.create(x=1, y=2, z=3) + prev.id = 42 + + class MyFakeModel(BetterFakeModel): + objects = BetterFakeModelManager({'x': 1, 'y': 2, 'z': 3}, prev) + + class MyFakeModelFactory(factory.DjangoModelFactory): + FACTORY_FOR = MyFakeModel + FACTORY_DJANGO_GET_OR_CREATE = ('x', 'y', 'z') + x = 1 + y = 4 + z = 6 + + obj = MyFakeModelFactory(y=2, z=4) + self.assertNotEqual(prev, obj) + self.assertEqual(1, obj.x) + self.assertEqual(2, obj.y) + self.assertEqual(4, obj.z) + self.assertEqual(2, obj.id) + + class PostGenerationTestCase(unittest.TestCase): def test_post_generation(self): class TestObjectFactory(factory.Factory): -- cgit v1.2.3