From 6983f7fcbb9bad0ea825fb3de8682f178fab5647 Mon Sep 17 00:00:00 2001 From: Raphaël Barrois Date: Mon, 22 Aug 2011 15:53:40 +0200 Subject: Allow inheriting from more than one factory. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Raphaël Barrois --- factory/base.py | 31 ++++++++++++++++--------------- factory/test_base.py | 31 ++++++++++++++++++------------- 2 files changed, 34 insertions(+), 28 deletions(-) (limited to 'factory') diff --git a/factory/base.py b/factory/base.py index b17938e..858f20e 100644 --- a/factory/base.py +++ b/factory/base.py @@ -46,12 +46,8 @@ CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class' # Factory metaclasses def get_factory_base(bases): - parents = [b for b in bases if isinstance(b, BaseFactoryMetaClass)] - if not parents: - return None - if len(parents) > 1: - raise RuntimeError('You can only inherit from one Factory') - return parents[0] + return [b for b in bases if isinstance(b, BaseFactoryMetaClass)] + class BaseFactoryMetaClass(type): '''Factory metaclass for handling ordered declarations.''' @@ -72,19 +68,22 @@ class BaseFactoryMetaClass(type): '''Record attributes (unordered declarations) and ordered declarations for construction of an associated class instance at a later time.''' - base = get_factory_base(bases) - if not base or attrs.get('ABSTRACT_FACTORY', False): + parent_factories = get_factory_base(bases) + if not parent_factories or attrs.get('ABSTRACT_FACTORY', False): # If this isn't a subclass of Factory, don't do anything special. return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) - declarations = DeclarationsHolder(defaults=getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) - attrs = declarations.update_base(attrs) + declarations = DeclarationsHolder() + for base in parent_factories: + declarations.update_base(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) + + non_factory_attrs = declarations.update_base(attrs) - attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations + non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations - attrs.update(extra_attrs) + non_factory_attrs.update(extra_attrs) - return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) + return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, non_factory_attrs) class FactoryMetaClass(BaseFactoryMetaClass): '''Factory metaclass for handling class association and ordered declarations.''' @@ -99,11 +98,13 @@ class FactoryMetaClass(BaseFactoryMetaClass): '''Determine the associated class based on the factory class name. Record the associated class for construction of an associated class instance at a later time.''' - base = get_factory_base(bases) - if not base or attrs.get('ABSTRACT_FACTORY', False): + parent_factories = get_factory_base(bases) + if not parent_factories or attrs.get('ABSTRACT_FACTORY', False): # If this isn't a subclass of Factory, don't do anything special. return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) + base = parent_factories[0] + inherited_associated_class = getattr(base, CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None) own_associated_class = None used_auto_discovery = False diff --git a/factory/test_base.py b/factory/test_base.py index c8d464b..8772f8b 100644 --- a/factory/test_base.py +++ b/factory/test_base.py @@ -234,6 +234,24 @@ class FactoryTestCase(unittest.TestCase): ones = set([x.one for x in (parent, alt_parent, sub, alt_sub)]) self.assertEqual(4, len(ones)) + def testDualInheritance(self): + class TestObjectFactory(Factory): + one = 'one' + + class TestOtherFactory(Factory): + FACTORY_FOR = TestObject + two = 'two' + four = 'four' + + class TestFactory(TestObjectFactory, TestOtherFactory): + three = 'three' + + obj = TestFactory.build(two=2) + self.assertEqual('one', obj.one) + self.assertEqual(2, obj.two) + self.assertEqual('three', obj.three) + self.assertEqual('four', obj.four) + def testSetCreationFunction(self): def creation_function(class_to_create, **kwargs): return "This doesn't even return an instance of {0}".format(class_to_create.__name__) @@ -400,19 +418,6 @@ class FactoryCreationTestCase(unittest.TestCase): except Factory.AssociatedClassError as e: self.assertTrue('autodiscovery' not in str(e)) - def testInheritanceFromMoreThanOneFactory(self): - class TestObjectFactory(StubFactory): - pass - - class TestModelFactory(TestObjectFactory): - pass - - try: - class TestFactory(TestObjectFactory, TestModelFactory): - pass - self.fail() - except RuntimeError as e: - self.assertTrue('one Factory' in str(e)) if __name__ == '__main__': unittest.main() -- cgit v1.2.3