diff options
-rw-r--r-- | factory/base.py | 25 | ||||
-rw-r--r-- | tests/test_using.py | 27 |
2 files changed, 42 insertions, 10 deletions
diff --git a/factory/base.py b/factory/base.py index 0d03838..2ff2944 100644 --- a/factory/base.py +++ b/factory/base.py @@ -587,8 +587,8 @@ class DjangoModelFactory(Factory): return 1 @classmethod - def _create(cls, target_class, *args, **kwargs): - """Create an instance of the model, and save it to the database.""" + def _get_or_create(cls, target_class, *args, **kwargs): + """Create an instance of the model through objects.get_or_create.""" manager = cls._get_manager(target_class) assert 'defaults' not in cls.FACTORY_DJANGO_GET_OR_CREATE, ( @@ -596,18 +596,25 @@ class DjangoModelFactory(Factory): "(in %s.FACTORY_DJANGO_GET_OR_CREATE=%r)" % (cls, cls.FACTORY_DJANGO_GET_OR_CREATE)) - if cls.FACTORY_DJANGO_GET_OR_CREATE: - key_fields = {} - for field in cls.FACTORY_DJANGO_GET_OR_CREATE: - key_fields[field] = kwargs.pop(field) - key_fields['defaults'] = kwargs - else: - key_fields = kwargs + key_fields = {} + for field in cls.FACTORY_DJANGO_GET_OR_CREATE: + key_fields[field] = kwargs.pop(field) + key_fields['defaults'] = kwargs obj, _created = manager.get_or_create(*args, **key_fields) return obj @classmethod + def _create(cls, target_class, *args, **kwargs): + """Create an instance of the model, and save it to the database.""" + manager = cls._get_manager(target_class) + + if cls.FACTORY_DJANGO_GET_OR_CREATE: + return cls._get_or_create(target_class, *args, **kwargs) + + return manager.create(*args, **kwargs) + + @classmethod def _after_postgeneration(cls, obj, create, results=None): """Save again the instance if creating and at least one hook ran.""" if create and results: diff --git a/tests/test_using.py b/tests/test_using.py index d366c8c..821fad3 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -57,6 +57,12 @@ class FakeModel(object): instance._defaults = defaults return instance, True + def create(self, **kwargs): + instance = FakeModel.create(**kwargs) + instance.id = 2 + instance._defaults = None + return instance + def values_list(self, *args, **kwargs): return self @@ -1787,7 +1793,7 @@ class DjangoModelFactoryTestCase(unittest.TestCase): a = factory.Sequence(lambda n: 'foo_%s' % n) o = TestModelFactory() - self.assertEqual({}, o._defaults) + self.assertEqual(None, o._defaults) self.assertEqual('foo_2', o.a) self.assertEqual(2, o.id) @@ -1809,6 +1815,25 @@ class DjangoModelFactoryTestCase(unittest.TestCase): self.assertEqual(4, o.d) self.assertEqual(2, o.id) + def test_full_get_or_create(self): + """Test a DjangoModelFactory with all fields in get_or_create.""" + class TestModelFactory(factory.DjangoModelFactory): + FACTORY_FOR = TestModel + FACTORY_DJANGO_GET_OR_CREATE = ('a', 'b', 'c', 'd') + + a = factory.Sequence(lambda n: 'foo_%s' % n) + b = 2 + c = 3 + d = 4 + + o = TestModelFactory() + self.assertEqual({}, o._defaults) + self.assertEqual('foo_2', o.a) + self.assertEqual(2, o.b) + self.assertEqual(3, o.c) + self.assertEqual(4, o.d) + self.assertEqual(2, o.id) + if __name__ == '__main__': unittest.main() |