summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2012-08-16 20:28:21 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2012-08-16 22:38:42 +0200
commit6efa57cf38f945c55214a94e0e7c12cc7eff474f (patch)
tree2f43cb1fe7af4d0a4ee40ba54a4603ec74da215c
parentdf0b1124cbb9f244dc40f435410ec16462a8fc9b (diff)
downloadfactory-boy-6efa57cf38f945c55214a94e0e7c12cc7eff474f.tar
factory-boy-6efa57cf38f945c55214a94e0e7c12cc7eff474f.tar.gz
Refactor building_function/creation_function handling.
Rely on inheritance instead of handwritten set_creation_function and such. Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r--factory/base.py65
-rw-r--r--tests/test_base.py24
-rw-r--r--tests/test_using.py106
3 files changed, 136 insertions, 59 deletions
diff --git a/factory/base.py b/factory/base.py
index 8ff6151..159a50c 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -317,6 +317,37 @@ class BaseFactory(object):
return getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS).copy(extra_defs)
@classmethod
+ def _build(cls, target_class, *args, **kwargs):
+ """Actually build an instance of the target_class.
+
+ Customization point, will be called once the full set of args and kwargs
+ has been computed.
+
+ Args:
+ target_class (type): the class for which an instance should be
+ built
+ args (tuple): arguments to use when building the class
+ kwargs (dict): keyword arguments to use when building the class
+ """
+
+ return target_class(*args, **kwargs)
+
+ @classmethod
+ def _create(cls, target_class, *args, **kwargs):
+ """Actually create an instance of the target_class.
+
+ Customization point, will be called once the full set of args and kwargs
+ has been computed.
+
+ Args:
+ target_class (type): the class for which an instance should be
+ created
+ args (tuple): arguments to use when creating the class
+ kwargs (dict): keyword arguments to use when creating the class
+ """
+ return target_class(*args, **kwargs)
+
+ @classmethod
def build(cls, **kwargs):
"""Build an instance of the associated class, with overriden attrs."""
raise cls.UnsupportedStrategy()
@@ -462,7 +493,7 @@ class Factory(BaseFactory):
# Customizing 'create' strategy, using a tuple to keep the creation function
# from turning it into an instance method.
- _creation_function = (DJANGO_CREATION,)
+ _creation_function = (None,)
@classmethod
def set_creation_function(cls, creation_function):
@@ -485,11 +516,22 @@ class Factory(BaseFactory):
an instance will be created, and keyword arguments for the value
of the fields of the instance.
"""
- return cls._creation_function[0]
+ creation_function = cls._creation_function[0]
+ if creation_function:
+ return creation_function
+ elif cls._create.__func__ == Factory._create.__func__:
+ # Backwards compatibility.
+ # Default creation_function and default _create() behavior.
+ # The best "Vanilla" _create detection algorithm I found is relying
+ # on actual method implementation (otherwise, make_factory isn't
+ # detected as 'default').
+ return DJANGO_CREATION
+ else:
+ return creation_function
# Customizing 'build' strategy, using a tuple to keep the creation function
# from turning it into an instance method.
- _building_function = (NAIVE_BUILD,)
+ _building_function = (None,)
@classmethod
def set_building_function(cls, building_function):
@@ -522,10 +564,19 @@ class Factory(BaseFactory):
create: bool, whether to create or to build the object
**kwargs: arguments to pass to the creation function
"""
+ target_class = getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS)
if create:
- return cls.get_creation_function()(getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS), **kwargs)
+ # Backwards compatibility
+ creation_function = cls.get_creation_function()
+ if creation_function:
+ return creation_function(target_class, **kwargs)
+ return cls._create(target_class, **kwargs)
else:
- return cls.get_building_function()(getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS), **kwargs)
+ # Backwards compatibility
+ building_function = cls.get_building_function()
+ if building_function:
+ return building_function(target_class, **kwargs)
+ return cls._build(target_class, **kwargs)
@classmethod
def _generate(cls, create, attrs):
@@ -581,6 +632,10 @@ class DjangoModelFactory(Factory):
except IndexError:
return 1
+ def _create(cls, target_class, *args, **kwargs):
+ """Create an instance of the model, and save it to the database."""
+ return target_class._default_manager.create(*args, **kwargs)
+
def make_factory(klass, **kwargs):
"""Create a new, simple factory for the given class."""
diff --git a/tests/test_base.py b/tests/test_base.py
index 5cbb31b..7ec3d0e 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -35,19 +35,25 @@ class TestObject(object):
self.four = four
class FakeDjangoModel(object):
- class FakeDjangoManager(object):
- def create(self, **kwargs):
- fake_model = FakeDjangoModel(**kwargs)
- fake_model.id = 1
- return fake_model
-
- objects = FakeDjangoManager()
+ @classmethod
+ def create(cls, **kwargs):
+ instance = cls(**kwargs)
+ instance.id = 1
+ return instance
def __init__(self, **kwargs):
for name, value in kwargs.iteritems():
setattr(self, name, value)
self.id = None
+class FakeModelFactory(base.Factory):
+ ABSTRACT_FACTORY = True
+
+ @classmethod
+ def _create(cls, target_class, *args, **kwargs):
+ return target_class.create(**kwargs)
+
+
class TestModel(FakeDjangoModel):
pass
@@ -114,7 +120,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
def testCreateStrategy(self):
# Default default_strategy
- class TestModelFactory(base.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -215,7 +221,7 @@ class FactoryCreationTestCase(unittest.TestCase):
self.assertEqual(TestFactory.default_strategy, base.STUB_STRATEGY)
def testCustomCreation(self):
- class TestModelFactory(base.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
@classmethod
diff --git a/tests/test_using.py b/tests/test_using.py
index 33dd3e3..681daa9 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -36,21 +36,37 @@ class TestObject(object):
self.four = four
self.five = five
-class FakeDjangoModel(object):
- class FakeDjangoManager(object):
+
+class FakeModel(object):
+ @classmethod
+ def create(cls, **kwargs):
+ instance = cls(**kwargs)
+ instance.id = 1
+ return instance
+
+ class FakeModelManager(object):
def create(self, **kwargs):
- fake_model = FakeDjangoModel(**kwargs)
- fake_model.id = 1
- return fake_model
+ instance = FakeModel.create(**kwargs)
+ instance.id = 2
+ return instance
- objects = FakeDjangoManager()
+ objects = FakeModelManager()
def __init__(self, **kwargs):
for name, value in kwargs.iteritems():
setattr(self, name, value)
self.id = None
-class TestModel(FakeDjangoModel):
+
+class FakeModelFactory(factory.Factory):
+ ABSTRACT_FACTORY = True
+
+ @classmethod
+ def _create(cls, target_class, *args, **kwargs):
+ return target_class.create(**kwargs)
+
+
+class TestModel(FakeModel):
pass
@@ -85,18 +101,18 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(obj.four, None)
def test_create(self):
- obj = factory.create(FakeDjangoModel, foo='bar')
- self.assertEqual(obj.id, 1)
+ obj = factory.create(FakeModel, foo='bar')
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_create_batch(self):
- objs = factory.create_batch(FakeDjangoModel, 4, foo='bar')
+ objs = factory.create_batch(FakeModel, 4, foo='bar')
self.assertEqual(4, len(objs))
self.assertEqual(4, len(set(objs)))
for obj in objs:
- self.assertEqual(obj.id, 1)
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_stub(self):
@@ -105,7 +121,7 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertFalse(hasattr(obj, 'two'))
def test_stub_batch(self):
- objs = factory.stub_batch(FakeDjangoModel, 4, foo='bar')
+ objs = factory.stub_batch(FakeModel, 4, foo='bar')
self.assertEqual(4, len(objs))
self.assertEqual(4, len(set(objs)))
@@ -115,22 +131,22 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(obj.foo, 'bar')
def test_generate_build(self):
- obj = factory.generate(FakeDjangoModel, factory.BUILD_STRATEGY, foo='bar')
+ obj = factory.generate(FakeModel, factory.BUILD_STRATEGY, foo='bar')
self.assertEqual(obj.id, None)
self.assertEqual(obj.foo, 'bar')
def test_generate_create(self):
- obj = factory.generate(FakeDjangoModel, factory.CREATE_STRATEGY, foo='bar')
- self.assertEqual(obj.id, 1)
+ obj = factory.generate(FakeModel, factory.CREATE_STRATEGY, foo='bar')
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_generate_stub(self):
- obj = factory.generate(FakeDjangoModel, factory.STUB_STRATEGY, foo='bar')
+ obj = factory.generate(FakeModel, factory.STUB_STRATEGY, foo='bar')
self.assertFalse(hasattr(obj, 'id'))
self.assertEqual(obj.foo, 'bar')
def test_generate_batch_build(self):
- objs = factory.generate_batch(FakeDjangoModel, factory.BUILD_STRATEGY, 20, foo='bar')
+ objs = factory.generate_batch(FakeModel, factory.BUILD_STRATEGY, 20, foo='bar')
self.assertEqual(20, len(objs))
self.assertEqual(20, len(set(objs)))
@@ -140,17 +156,17 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(obj.foo, 'bar')
def test_generate_batch_create(self):
- objs = factory.generate_batch(FakeDjangoModel, factory.CREATE_STRATEGY, 20, foo='bar')
+ objs = factory.generate_batch(FakeModel, factory.CREATE_STRATEGY, 20, foo='bar')
self.assertEqual(20, len(objs))
self.assertEqual(20, len(set(objs)))
for obj in objs:
- self.assertEqual(obj.id, 1)
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_generate_batch_stub(self):
- objs = factory.generate_batch(FakeDjangoModel, factory.STUB_STRATEGY, 20, foo='bar')
+ objs = factory.generate_batch(FakeModel, factory.STUB_STRATEGY, 20, foo='bar')
self.assertEqual(20, len(objs))
self.assertEqual(20, len(set(objs)))
@@ -160,17 +176,17 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(obj.foo, 'bar')
def test_simple_generate_build(self):
- obj = factory.simple_generate(FakeDjangoModel, False, foo='bar')
+ obj = factory.simple_generate(FakeModel, False, foo='bar')
self.assertEqual(obj.id, None)
self.assertEqual(obj.foo, 'bar')
def test_simple_generate_create(self):
- obj = factory.simple_generate(FakeDjangoModel, True, foo='bar')
- self.assertEqual(obj.id, 1)
+ obj = factory.simple_generate(FakeModel, True, foo='bar')
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_simple_generate_batch_build(self):
- objs = factory.simple_generate_batch(FakeDjangoModel, False, 20, foo='bar')
+ objs = factory.simple_generate_batch(FakeModel, False, 20, foo='bar')
self.assertEqual(20, len(objs))
self.assertEqual(20, len(set(objs)))
@@ -180,13 +196,13 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(obj.foo, 'bar')
def test_simple_generate_batch_create(self):
- objs = factory.simple_generate_batch(FakeDjangoModel, True, 20, foo='bar')
+ objs = factory.simple_generate_batch(FakeModel, True, 20, foo='bar')
self.assertEqual(20, len(objs))
self.assertEqual(20, len(set(objs)))
for obj in objs:
- self.assertEqual(obj.id, 1)
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_make_factory(self):
@@ -374,7 +390,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(test_object1.two, 'two1')
def testCreate(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -384,7 +400,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertTrue(test_model.id)
def test_create_batch(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -400,7 +416,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertTrue(obj.id)
def test_generate_build(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -410,7 +426,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertFalse(test_model.id)
def test_generate_create(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -420,7 +436,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertTrue(test_model.id)
def test_generate_stub(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -430,7 +446,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertFalse(hasattr(test_model, 'id'))
def test_generate_batch_build(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -446,7 +462,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertFalse(obj.id)
def test_generate_batch_create(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -462,7 +478,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertTrue(obj.id)
def test_generate_batch_stub(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -478,7 +494,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertFalse(hasattr(obj, 'id'))
def test_simple_generate_build(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -488,7 +504,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertFalse(test_model.id)
def test_simple_generate_create(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -498,7 +514,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertTrue(test_model.id)
def test_simple_generate_batch_build(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -514,7 +530,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertFalse(obj.id)
def test_simple_generate_batch_create(self):
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 'one'
@@ -611,7 +627,7 @@ class UsingFactoryTestCase(unittest.TestCase):
def creation_function(class_to_create, **kwargs):
return "This doesn't even return an instance of {0}".format(class_to_create.__name__)
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
TestModelFactory.set_creation_function(creation_function)
@@ -642,14 +658,14 @@ class UsingFactoryTestCase(unittest.TestCase):
class SubFactoryTestCase(unittest.TestCase):
def testSubFactory(self):
- class TestModel2(FakeDjangoModel):
+ class TestModel2(FakeModel):
pass
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
one = 3
- class TestModel2Factory(factory.Factory):
+ class TestModel2Factory(FakeModelFactory):
FACTORY_FOR = TestModel2
two = factory.SubFactory(TestModelFactory, one=1)
@@ -659,13 +675,13 @@ class SubFactoryTestCase(unittest.TestCase):
self.assertEqual(1, test_model.two.id)
def testSubFactoryWithLazyFields(self):
- class TestModel2(FakeDjangoModel):
+ class TestModel2(FakeModel):
pass
- class TestModelFactory(factory.Factory):
+ class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
- class TestModel2Factory(factory.Factory):
+ class TestModel2Factory(FakeModelFactory):
FACTORY_FOR = TestModel2
two = factory.SubFactory(TestModelFactory,
one=factory.Sequence(lambda n: 'x%sx' % n),