diff options
-rw-r--r-- | docs/introduction.rst | 2 | ||||
-rw-r--r-- | factory/base.py | 200 | ||||
-rw-r--r-- | tests/test_base.py | 22 |
3 files changed, 117 insertions, 107 deletions
diff --git a/docs/introduction.rst b/docs/introduction.rst index d211a83..8bbb10c 100644 --- a/docs/introduction.rst +++ b/docs/introduction.rst @@ -253,6 +253,6 @@ Calling a :class:`~factory.Factory` subclass will provide an object through the <MyClass: X (saved)> -The default strategy can ba changed by setting the class-level :attr:`~factory.Factory.default_strategy` attribute. +The default strategy can ba changed by setting the class-level :attr:`~factory.Factory.FACTROY_STRATEGY` attribute. diff --git a/factory/base.py b/factory/base.py index 25d7a14..aef21d5 100644 --- a/factory/base.py +++ b/factory/base.py @@ -44,11 +44,11 @@ CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class' # Factory metaclasses def get_factory_bases(bases): - """Retrieve all BaseFactoryMetaClass-derived bases from a list.""" - return [b for b in bases if isinstance(b, BaseFactoryMetaClass)] + """Retrieve all FactoryMetaClass-derived bases from a list.""" + return [b for b in bases if issubclass(b, BaseFactory)] -class BaseFactoryMetaClass(type): +class FactoryMetaClass(type): """Factory metaclass for handling ordered declarations.""" def __call__(cls, **kwargs): @@ -57,68 +57,14 @@ class BaseFactoryMetaClass(type): Returns an instance of the associated class. """ - if cls.default_strategy == BUILD_STRATEGY: + if cls.FACTORY_STRATEGY == BUILD_STRATEGY: return cls.build(**kwargs) - elif cls.default_strategy == CREATE_STRATEGY: + elif cls.FACTORY_STRATEGY == CREATE_STRATEGY: return cls.create(**kwargs) - elif cls.default_strategy == STUB_STRATEGY: + elif cls.FACTORY_STRATEGY == STUB_STRATEGY: return cls.stub(**kwargs) else: - raise BaseFactory.UnknownStrategy('Unknown default_strategy: {0}'.format(cls.default_strategy)) - - def __new__(cls, class_name, bases, attrs, extra_attrs=None): - """Record attributes as a pattern for later instance construction. - - This is called when a new Factory subclass is defined; it will collect - attribute declaration from the class definition. - - Args: - class_name (str): the name of the class being created - bases (list of class): the parents of the class being created - attrs (str => obj dict): the attributes as defined in the class - definition - extra_attrs (str => obj dict): extra attributes that should not be - included in the factory defaults, even if public. This - argument is only provided by extensions of this metaclass. - - Returns: - A new class - """ - - parent_factories = get_factory_bases(bases) - if not parent_factories: - # If this isn't a subclass of Factory, don't do anything special. - return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) - - declarations = containers.DeclarationDict() - postgen_declarations = containers.PostGenerationDeclarationDict() - - # Add parent declarations in reverse order. - for base in reversed(parent_factories): - # Import parent PostGenerationDeclaration - postgen_declarations.update_with_public( - getattr(base, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS, {})) - # Import all 'public' attributes (avoid those starting with _) - declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) - - # Import attributes from the class definition - non_postgen_attrs = postgen_declarations.update_with_public(attrs) - # Store protected/private attributes in 'non_factory_attrs'. - non_factory_attrs = declarations.update_with_public(non_postgen_attrs) - - # Store the DeclarationDict in the attributes of the newly created class - non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations - non_factory_attrs[CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS] = postgen_declarations - - # Add extra args if provided. - if extra_attrs: - non_factory_attrs.update(extra_attrs) - - return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, non_factory_attrs) - - -class FactoryMetaClass(BaseFactoryMetaClass): - """Factory metaclass for handling class association and ordered declarations.""" + raise BaseFactory.UnknownStrategy('Unknown FACTORY_STRATEGY: {0}'.format(cls.FACTORY_STRATEGY)) @classmethod def _discover_associated_class(cls, class_name, attrs, inherited=None): @@ -126,8 +72,6 @@ class FactoryMetaClass(BaseFactoryMetaClass): In order, the following tests will be performed: - Lookup the FACTORY_CLASS_DECLARATION attribute - - If the newly created class is named 'FooBarFactory', look for a FooBar - class in its module - If an inherited associated class was provided, use it. Args: @@ -154,46 +98,103 @@ class FactoryMetaClass(BaseFactoryMetaClass): if inherited is not None: return inherited - raise Factory.AssociatedClassError( + raise AssociatedClassError( "Could not determine the class associated with %s. " "Use the FACTORY_FOR attribute to specify an associated class." % class_name) - def __new__(cls, class_name, bases, attrs): - """Determine the associated class based on the factory class name. Record the associated class - for construction of an associated class instance at a later time.""" + @classmethod + def _extract_declarations(cls, bases, attributes): + """Extract declarations from a class definition. - parent_factories = get_factory_bases(bases) - if not parent_factories or attrs.get('ABSTRACT_FACTORY', False): - # If this isn't a subclass of Factory, or specifically declared - # abstract, don't do anything special. - if 'ABSTRACT_FACTORY' in attrs: - attrs.pop('ABSTRACT_FACTORY') + Args: + bases (class list): parent Factory subclasses + attributes (dict): attributes declared in the class definition + Returns: + dict: the original attributes, where declarations have been moved to + _declarations and post-generation declarations to + _postgen_declarations. + """ + declarations = containers.DeclarationDict() + postgen_declarations = containers.PostGenerationDeclarationDict() + + # Add parent declarations in reverse order. + for base in reversed(bases): + # Import parent PostGenerationDeclaration + postgen_declarations.update_with_public( + getattr(base, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS, {})) + # Import all 'public' attributes (avoid those starting with _) + declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) + + # Import attributes from the class definition + attributes = postgen_declarations.update_with_public(attributes) + # Store protected/private attributes in 'non_factory_attrs'. + attributes = declarations.update_with_public(attributes) + + # Store the DeclarationDict in the attributes of the newly created class + attributes[CLASS_ATTRIBUTE_DECLARATIONS] = declarations + attributes[CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS] = postgen_declarations + + return attributes + + def __new__(cls, class_name, bases, attrs, extra_attrs=None): + """Record attributes as a pattern for later instance construction. + + This is called when a new Factory subclass is defined; it will collect + attribute declaration from the class definition. + + Args: + class_name (str): the name of the class being created + bases (list of class): the parents of the class being created + attrs (str => obj dict): the attributes as defined in the class + definition + extra_attrs (str => obj dict): extra attributes that should not be + included in the factory defaults, even if public. This + argument is only provided by extensions of this metaclass. + + Returns: + A new class + """ + parent_factories = get_factory_bases(bases) + if not parent_factories: return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) - base = parent_factories[0] + is_abstract = attrs.pop('ABSTRACT_FACTORY', False) + extra_attrs = {} + + if not is_abstract: + + base = parent_factories[0] - inherited_associated_class = getattr(base, - CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None) - associated_class = cls._discover_associated_class(class_name, attrs, - inherited_associated_class) + inherited_associated_class = getattr(base, + CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None) + associated_class = cls._discover_associated_class(class_name, attrs, + inherited_associated_class) - # If inheriting the factory from a parent, keep a link to it. - # This allows to use the sequence counters from the parents. - if associated_class == inherited_associated_class: - attrs['_base_factory'] = base + # If inheriting the factory from a parent, keep a link to it. + # This allows to use the sequence counters from the parents. + if associated_class == inherited_associated_class: + attrs['_base_factory'] = base - # The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into account - # when parsing the declared attributes of the new class. - extra_attrs = {CLASS_ATTRIBUTE_ASSOCIATED_CLASS: associated_class} + # The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into account + # when parsing the declared attributes of the new class. + extra_attrs = {CLASS_ATTRIBUTE_ASSOCIATED_CLASS: associated_class} - return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attrs, extra_attrs=extra_attrs) + # Extract pre- and post-generation declarations + attributes = cls._extract_declarations(parent_factories, attrs) + + # Add extra args if provided. + if extra_attrs: + attributes.update(extra_attrs) + + return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attributes) def __str__(self): return '<%s for %s>' % (self.__name__, getattr(self, CLASS_ATTRIBUTE_ASSOCIATED_CLASS).__name__) + # Factory base classes class BaseFactory(object): @@ -430,10 +431,8 @@ class BaseFactory(object): return cls.generate_batch(strategy, size, **kwargs) -class StubFactory(BaseFactory): - __metaclass__ = BaseFactoryMetaClass - - default_strategy = STUB_STRATEGY +class AssociatedClassError(RuntimeError): + pass class Factory(BaseFactory): @@ -444,10 +443,8 @@ class Factory(BaseFactory): """ __metaclass__ = FactoryMetaClass - default_strategy = CREATE_STRATEGY - - class AssociatedClassError(RuntimeError): - pass + ABSTRACT_FACTORY = True + FACTORY_STRATEGY = CREATE_STRATEGY @classmethod def _adjust_kwargs(cls, **kwargs): @@ -522,6 +519,19 @@ class Factory(BaseFactory): return cls._generate(True, attrs) +Factory.AssociatedClassError = AssociatedClassError + + +class StubFactory(BaseFactory): + __metaclass__ = FactoryMetaClass + + FACTORY_STRATEGY = STUB_STRATEGY + + FACTORY_FOR = containers.StubObject + +print StubFactory._associated_class + + class DjangoModelFactory(Factory): """Factory for Django models. @@ -643,6 +653,6 @@ def use_strategy(new_strategy): This is an alternative to setting default_strategy in the class definition. """ def wrapped_class(klass): - klass.default_strategy = new_strategy + klass.FACTORY_STRATEGY = new_strategy return klass return wrapped_class diff --git a/tests/test_base.py b/tests/test_base.py index e86eae3..e12c0ae 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -108,13 +108,13 @@ class FactoryTestCase(unittest.TestCase): class FactoryDefaultStrategyTestCase(unittest.TestCase): def setUp(self): - self.default_strategy = base.Factory.default_strategy + self.default_strategy = base.Factory.FACTORY_STRATEGY def tearDown(self): - base.Factory.default_strategy = self.default_strategy + base.Factory.FACTORY_STRATEGY = self.default_strategy def testBuildStrategy(self): - base.Factory.default_strategy = base.BUILD_STRATEGY + base.Factory.FACTORY_STRATEGY = base.BUILD_STRATEGY class TestModelFactory(base.Factory): FACTORY_FOR = TestModel @@ -126,7 +126,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase): self.assertFalse(test_model.id) def testCreateStrategy(self): - # Default default_strategy + # Default FACTORY_STRATEGY class TestModelFactory(FakeModelFactory): FACTORY_FOR = TestModel @@ -138,7 +138,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase): self.assertTrue(test_model.id) def testStubStrategy(self): - base.Factory.default_strategy = base.STUB_STRATEGY + base.Factory.FACTORY_STRATEGY = base.STUB_STRATEGY class TestModelFactory(base.Factory): FACTORY_FOR = TestModel @@ -150,7 +150,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase): self.assertFalse(hasattr(test_model, 'id')) # We should have a plain old object def testUnknownStrategy(self): - base.Factory.default_strategy = 'unknown' + base.Factory.FACTORY_STRATEGY = 'unknown' class TestModelFactory(base.Factory): FACTORY_FOR = TestModel @@ -165,11 +165,11 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase): one = 'one' - TestModelFactory.default_strategy = base.CREATE_STRATEGY + TestModelFactory.FACTORY_STRATEGY = base.CREATE_STRATEGY self.assertRaises(base.StubFactory.UnsupportedStrategy, TestModelFactory) - TestModelFactory.default_strategy = base.BUILD_STRATEGY + TestModelFactory.FACTORY_STRATEGY = base.BUILD_STRATEGY self.assertRaises(base.StubFactory.UnsupportedStrategy, TestModelFactory) def test_change_strategy(self): @@ -179,7 +179,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase): one = 'one' - self.assertEqual(base.CREATE_STRATEGY, TestModelFactory.default_strategy) + self.assertEqual(base.CREATE_STRATEGY, TestModelFactory.FACTORY_STRATEGY) class FactoryCreationTestCase(unittest.TestCase): @@ -193,7 +193,7 @@ class FactoryCreationTestCase(unittest.TestCase): class TestFactory(base.StubFactory): pass - self.assertEqual(TestFactory.default_strategy, base.STUB_STRATEGY) + self.assertEqual(TestFactory.FACTORY_STRATEGY, base.STUB_STRATEGY) def testInheritanceWithStub(self): class TestObjectFactory(base.StubFactory): @@ -204,7 +204,7 @@ class FactoryCreationTestCase(unittest.TestCase): class TestFactory(TestObjectFactory): pass - self.assertEqual(TestFactory.default_strategy, base.STUB_STRATEGY) + self.assertEqual(TestFactory.FACTORY_STRATEGY, base.STUB_STRATEGY) def testCustomCreation(self): class TestModelFactory(FakeModelFactory): |