summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polyconseil.fr>2013-04-16 11:32:36 +0200
committerRaphaël Barrois <raphael.barrois@polyconseil.fr>2013-04-16 11:32:36 +0200
commitef5011d7d28cc7034e07dc75a6661a0253c0fe1d (patch)
tree740e892080ac6faef32b8ab9c1c6792a4185cc17 /tests
parent68b5872e8cbd33f5f59ea8d859e326eb0ff0c6eb (diff)
downloadfactory-boy-ef5011d7d28cc7034e07dc75a6661a0253c0fe1d.tar
factory-boy-ef5011d7d28cc7034e07dc75a6661a0253c0fe1d.tar.gz
Don't use objects.get_or_create() unless required.
Diffstat (limited to 'tests')
-rw-r--r--tests/test_using.py27
1 files changed, 26 insertions, 1 deletions
diff --git a/tests/test_using.py b/tests/test_using.py
index d366c8c..821fad3 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -57,6 +57,12 @@ class FakeModel(object):
instance._defaults = defaults
return instance, True
+ def create(self, **kwargs):
+ instance = FakeModel.create(**kwargs)
+ instance.id = 2
+ instance._defaults = None
+ return instance
+
def values_list(self, *args, **kwargs):
return self
@@ -1787,7 +1793,7 @@ class DjangoModelFactoryTestCase(unittest.TestCase):
a = factory.Sequence(lambda n: 'foo_%s' % n)
o = TestModelFactory()
- self.assertEqual({}, o._defaults)
+ self.assertEqual(None, o._defaults)
self.assertEqual('foo_2', o.a)
self.assertEqual(2, o.id)
@@ -1809,6 +1815,25 @@ class DjangoModelFactoryTestCase(unittest.TestCase):
self.assertEqual(4, o.d)
self.assertEqual(2, o.id)
+ def test_full_get_or_create(self):
+ """Test a DjangoModelFactory with all fields in get_or_create."""
+ class TestModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = TestModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('a', 'b', 'c', 'd')
+
+ a = factory.Sequence(lambda n: 'foo_%s' % n)
+ b = 2
+ c = 3
+ d = 4
+
+ o = TestModelFactory()
+ self.assertEqual({}, o._defaults)
+ self.assertEqual('foo_2', o.a)
+ self.assertEqual(2, o.b)
+ self.assertEqual(3, o.c)
+ self.assertEqual(4, o.d)
+ self.assertEqual(2, o.id)
+
if __name__ == '__main__':
unittest.main()