summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/base.py31
-rw-r--r--factory/test_base.py31
2 files changed, 34 insertions, 28 deletions
diff --git a/factory/base.py b/factory/base.py
index b17938e..858f20e 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -46,12 +46,8 @@ CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class'
# Factory metaclasses
def get_factory_base(bases):
- parents = [b for b in bases if isinstance(b, BaseFactoryMetaClass)]
- if not parents:
- return None
- if len(parents) > 1:
- raise RuntimeError('You can only inherit from one Factory')
- return parents[0]
+ return [b for b in bases if isinstance(b, BaseFactoryMetaClass)]
+
class BaseFactoryMetaClass(type):
'''Factory metaclass for handling ordered declarations.'''
@@ -72,19 +68,22 @@ class BaseFactoryMetaClass(type):
'''Record attributes (unordered declarations) and ordered declarations for construction of
an associated class instance at a later time.'''
- base = get_factory_base(bases)
- if not base or attrs.get('ABSTRACT_FACTORY', False):
+ parent_factories = get_factory_base(bases)
+ if not parent_factories or attrs.get('ABSTRACT_FACTORY', False):
# If this isn't a subclass of Factory, don't do anything special.
return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs)
- declarations = DeclarationsHolder(defaults=getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {}))
- attrs = declarations.update_base(attrs)
+ declarations = DeclarationsHolder()
+ for base in parent_factories:
+ declarations.update_base(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {}))
+
+ non_factory_attrs = declarations.update_base(attrs)
- attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations
+ non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations
- attrs.update(extra_attrs)
+ non_factory_attrs.update(extra_attrs)
- return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs)
+ return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, non_factory_attrs)
class FactoryMetaClass(BaseFactoryMetaClass):
'''Factory metaclass for handling class association and ordered declarations.'''
@@ -99,11 +98,13 @@ class FactoryMetaClass(BaseFactoryMetaClass):
'''Determine the associated class based on the factory class name. Record the associated class
for construction of an associated class instance at a later time.'''
- base = get_factory_base(bases)
- if not base or attrs.get('ABSTRACT_FACTORY', False):
+ parent_factories = get_factory_base(bases)
+ if not parent_factories or attrs.get('ABSTRACT_FACTORY', False):
# If this isn't a subclass of Factory, don't do anything special.
return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attrs)
+ base = parent_factories[0]
+
inherited_associated_class = getattr(base, CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None)
own_associated_class = None
used_auto_discovery = False
diff --git a/factory/test_base.py b/factory/test_base.py
index c8d464b..8772f8b 100644
--- a/factory/test_base.py
+++ b/factory/test_base.py
@@ -234,6 +234,24 @@ class FactoryTestCase(unittest.TestCase):
ones = set([x.one for x in (parent, alt_parent, sub, alt_sub)])
self.assertEqual(4, len(ones))
+ def testDualInheritance(self):
+ class TestObjectFactory(Factory):
+ one = 'one'
+
+ class TestOtherFactory(Factory):
+ FACTORY_FOR = TestObject
+ two = 'two'
+ four = 'four'
+
+ class TestFactory(TestObjectFactory, TestOtherFactory):
+ three = 'three'
+
+ obj = TestFactory.build(two=2)
+ self.assertEqual('one', obj.one)
+ self.assertEqual(2, obj.two)
+ self.assertEqual('three', obj.three)
+ self.assertEqual('four', obj.four)
+
def testSetCreationFunction(self):
def creation_function(class_to_create, **kwargs):
return "This doesn't even return an instance of {0}".format(class_to_create.__name__)
@@ -400,19 +418,6 @@ class FactoryCreationTestCase(unittest.TestCase):
except Factory.AssociatedClassError as e:
self.assertTrue('autodiscovery' not in str(e))
- def testInheritanceFromMoreThanOneFactory(self):
- class TestObjectFactory(StubFactory):
- pass
-
- class TestModelFactory(TestObjectFactory):
- pass
-
- try:
- class TestFactory(TestObjectFactory, TestModelFactory):
- pass
- self.fail()
- except RuntimeError as e:
- self.assertTrue('one Factory' in str(e))
if __name__ == '__main__':
unittest.main()