diff options
-rw-r--r-- | factory/base.py | 65 | ||||
-rw-r--r-- | tests/test_base.py | 24 | ||||
-rw-r--r-- | tests/test_using.py | 106 |
3 files changed, 136 insertions, 59 deletions
diff --git a/factory/base.py b/factory/base.py index 8ff6151..159a50c 100644 --- a/factory/base.py +++ b/factory/base.py @@ -317,6 +317,37 @@ class BaseFactory(object): return getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS).copy(extra_defs) @classmethod + def _build(cls, target_class, *args, **kwargs): + """Actually build an instance of the target_class. + + Customization point, will be called once the full set of args and kwargs + has been computed. + + Args: + target_class (type): the class for which an instance should be + built + args (tuple): arguments to use when building the class + kwargs (dict): keyword arguments to use when building the class + """ + + return target_class(*args, **kwargs) + + @classmethod + def _create(cls, target_class, *args, **kwargs): + """Actually create an instance of the target_class. + + Customization point, will be called once the full set of args and kwargs + has been computed. + + Args: + target_class (type): the class for which an instance should be + created + args (tuple): arguments to use when creating the class + kwargs (dict): keyword arguments to use when creating the class + """ + return target_class(*args, **kwargs) + + @classmethod def build(cls, **kwargs): """Build an instance of the associated class, with overriden attrs.""" raise cls.UnsupportedStrategy() @@ -462,7 +493,7 @@ class Factory(BaseFactory): # Customizing 'create' strategy, using a tuple to keep the creation function # from turning it into an instance method. - _creation_function = (DJANGO_CREATION,) + _creation_function = (None,) @classmethod def set_creation_function(cls, creation_function): @@ -485,11 +516,22 @@ class Factory(BaseFactory): an instance will be created, and keyword arguments for the value of the fields of the instance. """ - return cls._creation_function[0] + creation_function = cls._creation_function[0] + if creation_function: + return creation_function + elif cls._create.__func__ == Factory._create.__func__: + # Backwards compatibility. + # Default creation_function and default _create() behavior. + # The best "Vanilla" _create detection algorithm I found is relying + # on actual method implementation (otherwise, make_factory isn't + # detected as 'default'). + return DJANGO_CREATION + else: + return creation_function # Customizing 'build' strategy, using a tuple to keep the creation function # from turning it into an instance method. - _building_function = (NAIVE_BUILD,) + _building_function = (None,) @classmethod def set_building_function(cls, building_function): @@ -522,10 +564,19 @@ class Factory(BaseFactory): create: bool, whether to create or to build the object **kwargs: arguments to pass to the creation function """ + target_class = getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS) if create: - return cls.get_creation_function()(getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS), **kwargs) + # Backwards compatibility + creation_function = cls.get_creation_function() + if creation_function: + return creation_function(target_class, **kwargs) + return cls._create(target_class, **kwargs) else: - return cls.get_building_function()(getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS), **kwargs) + # Backwards compatibility + building_function = cls.get_building_function() + if building_function: + return building_function(target_class, **kwargs) + return cls._build(target_class, **kwargs) @classmethod def _generate(cls, create, attrs): @@ -581,6 +632,10 @@ class DjangoModelFactory(Factory): except IndexError: return 1 + def _create(cls, target_class, *args, **kwargs): + """Create an instance of the model, and save it to the database.""" + return target_class._default_manager.create(*args, **kwargs) + def make_factory(klass, **kwargs): """Create a new, simple factory for the given class.""" diff --git a/tests/test_base.py b/tests/test_base.py index 5cbb31b..7ec3d0e 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -35,19 +35,25 @@ class TestObject(object): self.four = four class FakeDjangoModel(object): - class FakeDjangoManager(object): - def create(self, **kwargs): - fake_model = FakeDjangoModel(**kwargs) - fake_model.id = 1 - return fake_model - - objects = FakeDjangoManager() + @classmethod + def create(cls, **kwargs): + instance = cls(**kwargs) + instance.id = 1 + return instance def __init__(self, **kwargs): for name, value in kwargs.iteritems(): setattr(self, name, value) self.id = None +class FakeModelFactory(base.Factory): + ABSTRACT_FACTORY = True + + @classmethod + def _create(cls, target_class, *args, **kwargs): + return target_class.create(**kwargs) + + class TestModel(FakeDjangoModel): pass @@ -114,7 +120,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase): def testCreateStrategy(self): # Default default_strategy - class TestModelFactory(base.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -215,7 +221,7 @@ class FactoryCreationTestCase(unittest.TestCase): self.assertEqual(TestFactory.default_strategy, base.STUB_STRATEGY) def testCustomCreation(self): - class TestModelFactory(base.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel @classmethod diff --git a/tests/test_using.py b/tests/test_using.py index 33dd3e3..681daa9 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -36,21 +36,37 @@ class TestObject(object): self.four = four self.five = five -class FakeDjangoModel(object): - class FakeDjangoManager(object): + +class FakeModel(object): + @classmethod + def create(cls, **kwargs): + instance = cls(**kwargs) + instance.id = 1 + return instance + + class FakeModelManager(object): def create(self, **kwargs): - fake_model = FakeDjangoModel(**kwargs) - fake_model.id = 1 - return fake_model + instance = FakeModel.create(**kwargs) + instance.id = 2 + return instance - objects = FakeDjangoManager() + objects = FakeModelManager() def __init__(self, **kwargs): for name, value in kwargs.iteritems(): setattr(self, name, value) self.id = None -class TestModel(FakeDjangoModel): + +class FakeModelFactory(factory.Factory): + ABSTRACT_FACTORY = True + + @classmethod + def _create(cls, target_class, *args, **kwargs): + return target_class.create(**kwargs) + + +class TestModel(FakeModel): pass @@ -85,18 +101,18 @@ class SimpleBuildTestCase(unittest.TestCase): self.assertEqual(obj.four, None) def test_create(self): - obj = factory.create(FakeDjangoModel, foo='bar') - self.assertEqual(obj.id, 1) + obj = factory.create(FakeModel, foo='bar') + self.assertEqual(obj.id, 2) self.assertEqual(obj.foo, 'bar') def test_create_batch(self): - objs = factory.create_batch(FakeDjangoModel, 4, foo='bar') + objs = factory.create_batch(FakeModel, 4, foo='bar') self.assertEqual(4, len(objs)) self.assertEqual(4, len(set(objs))) for obj in objs: - self.assertEqual(obj.id, 1) + self.assertEqual(obj.id, 2) self.assertEqual(obj.foo, 'bar') def test_stub(self): @@ -105,7 +121,7 @@ class SimpleBuildTestCase(unittest.TestCase): self.assertFalse(hasattr(obj, 'two')) def test_stub_batch(self): - objs = factory.stub_batch(FakeDjangoModel, 4, foo='bar') + objs = factory.stub_batch(FakeModel, 4, foo='bar') self.assertEqual(4, len(objs)) self.assertEqual(4, len(set(objs))) @@ -115,22 +131,22 @@ class SimpleBuildTestCase(unittest.TestCase): self.assertEqual(obj.foo, 'bar') def test_generate_build(self): - obj = factory.generate(FakeDjangoModel, factory.BUILD_STRATEGY, foo='bar') + obj = factory.generate(FakeModel, factory.BUILD_STRATEGY, foo='bar') self.assertEqual(obj.id, None) self.assertEqual(obj.foo, 'bar') def test_generate_create(self): - obj = factory.generate(FakeDjangoModel, factory.CREATE_STRATEGY, foo='bar') - self.assertEqual(obj.id, 1) + obj = factory.generate(FakeModel, factory.CREATE_STRATEGY, foo='bar') + self.assertEqual(obj.id, 2) self.assertEqual(obj.foo, 'bar') def test_generate_stub(self): - obj = factory.generate(FakeDjangoModel, factory.STUB_STRATEGY, foo='bar') + obj = factory.generate(FakeModel, factory.STUB_STRATEGY, foo='bar') self.assertFalse(hasattr(obj, 'id')) self.assertEqual(obj.foo, 'bar') def test_generate_batch_build(self): - objs = factory.generate_batch(FakeDjangoModel, factory.BUILD_STRATEGY, 20, foo='bar') + objs = factory.generate_batch(FakeModel, factory.BUILD_STRATEGY, 20, foo='bar') self.assertEqual(20, len(objs)) self.assertEqual(20, len(set(objs))) @@ -140,17 +156,17 @@ class SimpleBuildTestCase(unittest.TestCase): self.assertEqual(obj.foo, 'bar') def test_generate_batch_create(self): - objs = factory.generate_batch(FakeDjangoModel, factory.CREATE_STRATEGY, 20, foo='bar') + objs = factory.generate_batch(FakeModel, factory.CREATE_STRATEGY, 20, foo='bar') self.assertEqual(20, len(objs)) self.assertEqual(20, len(set(objs))) for obj in objs: - self.assertEqual(obj.id, 1) + self.assertEqual(obj.id, 2) self.assertEqual(obj.foo, 'bar') def test_generate_batch_stub(self): - objs = factory.generate_batch(FakeDjangoModel, factory.STUB_STRATEGY, 20, foo='bar') + objs = factory.generate_batch(FakeModel, factory.STUB_STRATEGY, 20, foo='bar') self.assertEqual(20, len(objs)) self.assertEqual(20, len(set(objs))) @@ -160,17 +176,17 @@ class SimpleBuildTestCase(unittest.TestCase): self.assertEqual(obj.foo, 'bar') def test_simple_generate_build(self): - obj = factory.simple_generate(FakeDjangoModel, False, foo='bar') + obj = factory.simple_generate(FakeModel, False, foo='bar') self.assertEqual(obj.id, None) self.assertEqual(obj.foo, 'bar') def test_simple_generate_create(self): - obj = factory.simple_generate(FakeDjangoModel, True, foo='bar') - self.assertEqual(obj.id, 1) + obj = factory.simple_generate(FakeModel, True, foo='bar') + self.assertEqual(obj.id, 2) self.assertEqual(obj.foo, 'bar') def test_simple_generate_batch_build(self): - objs = factory.simple_generate_batch(FakeDjangoModel, False, 20, foo='bar') + objs = factory.simple_generate_batch(FakeModel, False, 20, foo='bar') self.assertEqual(20, len(objs)) self.assertEqual(20, len(set(objs))) @@ -180,13 +196,13 @@ class SimpleBuildTestCase(unittest.TestCase): self.assertEqual(obj.foo, 'bar') def test_simple_generate_batch_create(self): - objs = factory.simple_generate_batch(FakeDjangoModel, True, 20, foo='bar') + objs = factory.simple_generate_batch(FakeModel, True, 20, foo='bar') self.assertEqual(20, len(objs)) self.assertEqual(20, len(set(objs))) for obj in objs: - self.assertEqual(obj.id, 1) + self.assertEqual(obj.id, 2) self.assertEqual(obj.foo, 'bar') def test_make_factory(self): @@ -374,7 +390,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertEqual(test_object1.two, 'two1') def testCreate(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -384,7 +400,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertTrue(test_model.id) def test_create_batch(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -400,7 +416,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertTrue(obj.id) def test_generate_build(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -410,7 +426,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertFalse(test_model.id) def test_generate_create(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -420,7 +436,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertTrue(test_model.id) def test_generate_stub(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -430,7 +446,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertFalse(hasattr(test_model, 'id')) def test_generate_batch_build(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -446,7 +462,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertFalse(obj.id) def test_generate_batch_create(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -462,7 +478,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertTrue(obj.id) def test_generate_batch_stub(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -478,7 +494,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertFalse(hasattr(obj, 'id')) def test_simple_generate_build(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -488,7 +504,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertFalse(test_model.id) def test_simple_generate_create(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -498,7 +514,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertTrue(test_model.id) def test_simple_generate_batch_build(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -514,7 +530,7 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertFalse(obj.id) def test_simple_generate_batch_create(self): - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 'one' @@ -611,7 +627,7 @@ class UsingFactoryTestCase(unittest.TestCase): def creation_function(class_to_create, **kwargs): return "This doesn't even return an instance of {0}".format(class_to_create.__name__) - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel TestModelFactory.set_creation_function(creation_function) @@ -642,14 +658,14 @@ class UsingFactoryTestCase(unittest.TestCase): class SubFactoryTestCase(unittest.TestCase): def testSubFactory(self): - class TestModel2(FakeDjangoModel): + class TestModel2(FakeModel): pass - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel one = 3 - class TestModel2Factory(factory.Factory): + class TestModel2Factory(FakeModelFactory): FACTORY_FOR = TestModel2 two = factory.SubFactory(TestModelFactory, one=1) @@ -659,13 +675,13 @@ class SubFactoryTestCase(unittest.TestCase): self.assertEqual(1, test_model.two.id) def testSubFactoryWithLazyFields(self): - class TestModel2(FakeDjangoModel): + class TestModel2(FakeModel): pass - class TestModelFactory(factory.Factory): + class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel - class TestModel2Factory(factory.Factory): + class TestModel2Factory(FakeModelFactory): FACTORY_FOR = TestModel2 two = factory.SubFactory(TestModelFactory, one=factory.Sequence(lambda n: 'x%sx' % n), |