diff options
-rw-r--r-- | factory/base.py | 31 | ||||
-rw-r--r-- | factory/test_base.py | 31 |
2 files changed, 34 insertions, 28 deletions
diff --git a/factory/base.py b/factory/base.py index b17938e..858f20e 100644 --- a/factory/base.py +++ b/factory/base.py @@ -46,12 +46,8 @@ CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class' # Factory metaclasses def get_factory_base(bases): - parents = [b for b in bases if isinstance(b, BaseFactoryMetaClass)] - if not parents: - return None - if len(parents) > 1: - raise RuntimeError('You can only inherit from one Factory') - return parents[0] + return [b for b in bases if isinstance(b, BaseFactoryMetaClass)] + class BaseFactoryMetaClass(type): '''Factory metaclass for handling ordered declarations.''' @@ -72,19 +68,22 @@ class BaseFactoryMetaClass(type): '''Record attributes (unordered declarations) and ordered declarations for construction of an associated class instance at a later time.''' - base = get_factory_base(bases) - if not base or attrs.get('ABSTRACT_FACTORY', False): + parent_factories = get_factory_base(bases) + if not parent_factories or attrs.get('ABSTRACT_FACTORY', False): # If this isn't a subclass of Factory, don't do anything special. return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) - declarations = DeclarationsHolder(defaults=getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) - attrs = declarations.update_base(attrs) + declarations = DeclarationsHolder() + for base in parent_factories: + declarations.update_base(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) + + non_factory_attrs = declarations.update_base(attrs) - attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations + non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations - attrs.update(extra_attrs) + non_factory_attrs.update(extra_attrs) - return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) + return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, non_factory_attrs) class FactoryMetaClass(BaseFactoryMetaClass): '''Factory metaclass for handling class association and ordered declarations.''' @@ -99,11 +98,13 @@ class FactoryMetaClass(BaseFactoryMetaClass): '''Determine the associated class based on the factory class name. Record the associated class for construction of an associated class instance at a later time.''' - base = get_factory_base(bases) - if not base or attrs.get('ABSTRACT_FACTORY', False): + parent_factories = get_factory_base(bases) + if not parent_factories or attrs.get('ABSTRACT_FACTORY', False): # If this isn't a subclass of Factory, don't do anything special. return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) + base = parent_factories[0] + inherited_associated_class = getattr(base, CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None) own_associated_class = None used_auto_discovery = False diff --git a/factory/test_base.py b/factory/test_base.py index c8d464b..8772f8b 100644 --- a/factory/test_base.py +++ b/factory/test_base.py @@ -234,6 +234,24 @@ class FactoryTestCase(unittest.TestCase): ones = set([x.one for x in (parent, alt_parent, sub, alt_sub)]) self.assertEqual(4, len(ones)) + def testDualInheritance(self): + class TestObjectFactory(Factory): + one = 'one' + + class TestOtherFactory(Factory): + FACTORY_FOR = TestObject + two = 'two' + four = 'four' + + class TestFactory(TestObjectFactory, TestOtherFactory): + three = 'three' + + obj = TestFactory.build(two=2) + self.assertEqual('one', obj.one) + self.assertEqual(2, obj.two) + self.assertEqual('three', obj.three) + self.assertEqual('four', obj.four) + def testSetCreationFunction(self): def creation_function(class_to_create, **kwargs): return "This doesn't even return an instance of {0}".format(class_to_create.__name__) @@ -400,19 +418,6 @@ class FactoryCreationTestCase(unittest.TestCase): except Factory.AssociatedClassError as e: self.assertTrue('autodiscovery' not in str(e)) - def testInheritanceFromMoreThanOneFactory(self): - class TestObjectFactory(StubFactory): - pass - - class TestModelFactory(TestObjectFactory): - pass - - try: - class TestFactory(TestObjectFactory, TestModelFactory): - pass - self.fail() - except RuntimeError as e: - self.assertTrue('one Factory' in str(e)) if __name__ == '__main__': unittest.main() |