diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-03-15 01:33:56 +0100 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-03-15 01:33:56 +0100 |
commit | 6e9bf5af909e1e164a294fd5589edc4fada06731 (patch) | |
tree | 7063b11f41d681de1e7498294a0cdda6ef234e14 /factory/base.py | |
parent | d63821daba2002b8c455777748007f7198d3d3bc (diff) | |
download | factory-boy-6e9bf5af909e1e164a294fd5589edc4fada06731.tar factory-boy-6e9bf5af909e1e164a294fd5589edc4fada06731.tar.gz |
Merge BaseFactoryMetaClass into FactoryMetaClass.
Also fix FACTORY_STRATEGY.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
Diffstat (limited to 'factory/base.py')
-rw-r--r-- | factory/base.py | 200 |
1 files changed, 105 insertions, 95 deletions
diff --git a/factory/base.py b/factory/base.py index 25d7a14..aef21d5 100644 --- a/factory/base.py +++ b/factory/base.py @@ -44,11 +44,11 @@ CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class' # Factory metaclasses def get_factory_bases(bases): - """Retrieve all BaseFactoryMetaClass-derived bases from a list.""" - return [b for b in bases if isinstance(b, BaseFactoryMetaClass)] + """Retrieve all FactoryMetaClass-derived bases from a list.""" + return [b for b in bases if issubclass(b, BaseFactory)] -class BaseFactoryMetaClass(type): +class FactoryMetaClass(type): """Factory metaclass for handling ordered declarations.""" def __call__(cls, **kwargs): @@ -57,68 +57,14 @@ class BaseFactoryMetaClass(type): Returns an instance of the associated class. """ - if cls.default_strategy == BUILD_STRATEGY: + if cls.FACTORY_STRATEGY == BUILD_STRATEGY: return cls.build(**kwargs) - elif cls.default_strategy == CREATE_STRATEGY: + elif cls.FACTORY_STRATEGY == CREATE_STRATEGY: return cls.create(**kwargs) - elif cls.default_strategy == STUB_STRATEGY: + elif cls.FACTORY_STRATEGY == STUB_STRATEGY: return cls.stub(**kwargs) else: - raise BaseFactory.UnknownStrategy('Unknown default_strategy: {0}'.format(cls.default_strategy)) - - def __new__(cls, class_name, bases, attrs, extra_attrs=None): - """Record attributes as a pattern for later instance construction. - - This is called when a new Factory subclass is defined; it will collect - attribute declaration from the class definition. - - Args: - class_name (str): the name of the class being created - bases (list of class): the parents of the class being created - attrs (str => obj dict): the attributes as defined in the class - definition - extra_attrs (str => obj dict): extra attributes that should not be - included in the factory defaults, even if public. This - argument is only provided by extensions of this metaclass. - - Returns: - A new class - """ - - parent_factories = get_factory_bases(bases) - if not parent_factories: - # If this isn't a subclass of Factory, don't do anything special. - return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) - - declarations = containers.DeclarationDict() - postgen_declarations = containers.PostGenerationDeclarationDict() - - # Add parent declarations in reverse order. - for base in reversed(parent_factories): - # Import parent PostGenerationDeclaration - postgen_declarations.update_with_public( - getattr(base, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS, {})) - # Import all 'public' attributes (avoid those starting with _) - declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) - - # Import attributes from the class definition - non_postgen_attrs = postgen_declarations.update_with_public(attrs) - # Store protected/private attributes in 'non_factory_attrs'. - non_factory_attrs = declarations.update_with_public(non_postgen_attrs) - - # Store the DeclarationDict in the attributes of the newly created class - non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations - non_factory_attrs[CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS] = postgen_declarations - - # Add extra args if provided. - if extra_attrs: - non_factory_attrs.update(extra_attrs) - - return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, non_factory_attrs) - - -class FactoryMetaClass(BaseFactoryMetaClass): - """Factory metaclass for handling class association and ordered declarations.""" + raise BaseFactory.UnknownStrategy('Unknown FACTORY_STRATEGY: {0}'.format(cls.FACTORY_STRATEGY)) @classmethod def _discover_associated_class(cls, class_name, attrs, inherited=None): @@ -126,8 +72,6 @@ class FactoryMetaClass(BaseFactoryMetaClass): In order, the following tests will be performed: - Lookup the FACTORY_CLASS_DECLARATION attribute - - If the newly created class is named 'FooBarFactory', look for a FooBar - class in its module - If an inherited associated class was provided, use it. Args: @@ -154,46 +98,103 @@ class FactoryMetaClass(BaseFactoryMetaClass): if inherited is not None: return inherited - raise Factory.AssociatedClassError( + raise AssociatedClassError( "Could not determine the class associated with %s. " "Use the FACTORY_FOR attribute to specify an associated class." % class_name) - def __new__(cls, class_name, bases, attrs): - """Determine the associated class based on the factory class name. Record the associated class - for construction of an associated class instance at a later time.""" + @classmethod + def _extract_declarations(cls, bases, attributes): + """Extract declarations from a class definition. - parent_factories = get_factory_bases(bases) - if not parent_factories or attrs.get('ABSTRACT_FACTORY', False): - # If this isn't a subclass of Factory, or specifically declared - # abstract, don't do anything special. - if 'ABSTRACT_FACTORY' in attrs: - attrs.pop('ABSTRACT_FACTORY') + Args: + bases (class list): parent Factory subclasses + attributes (dict): attributes declared in the class definition + Returns: + dict: the original attributes, where declarations have been moved to + _declarations and post-generation declarations to + _postgen_declarations. + """ + declarations = containers.DeclarationDict() + postgen_declarations = containers.PostGenerationDeclarationDict() + + # Add parent declarations in reverse order. + for base in reversed(bases): + # Import parent PostGenerationDeclaration + postgen_declarations.update_with_public( + getattr(base, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS, {})) + # Import all 'public' attributes (avoid those starting with _) + declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) + + # Import attributes from the class definition + attributes = postgen_declarations.update_with_public(attributes) + # Store protected/private attributes in 'non_factory_attrs'. + attributes = declarations.update_with_public(attributes) + + # Store the DeclarationDict in the attributes of the newly created class + attributes[CLASS_ATTRIBUTE_DECLARATIONS] = declarations + attributes[CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS] = postgen_declarations + + return attributes + + def __new__(cls, class_name, bases, attrs, extra_attrs=None): + """Record attributes as a pattern for later instance construction. + + This is called when a new Factory subclass is defined; it will collect + attribute declaration from the class definition. + + Args: + class_name (str): the name of the class being created + bases (list of class): the parents of the class being created + attrs (str => obj dict): the attributes as defined in the class + definition + extra_attrs (str => obj dict): extra attributes that should not be + included in the factory defaults, even if public. This + argument is only provided by extensions of this metaclass. + + Returns: + A new class + """ + parent_factories = get_factory_bases(bases) + if not parent_factories: return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) - base = parent_factories[0] + is_abstract = attrs.pop('ABSTRACT_FACTORY', False) + extra_attrs = {} + + if not is_abstract: + + base = parent_factories[0] - inherited_associated_class = getattr(base, - CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None) - associated_class = cls._discover_associated_class(class_name, attrs, - inherited_associated_class) + inherited_associated_class = getattr(base, + CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None) + associated_class = cls._discover_associated_class(class_name, attrs, + inherited_associated_class) - # If inheriting the factory from a parent, keep a link to it. - # This allows to use the sequence counters from the parents. - if associated_class == inherited_associated_class: - attrs['_base_factory'] = base + # If inheriting the factory from a parent, keep a link to it. + # This allows to use the sequence counters from the parents. + if associated_class == inherited_associated_class: + attrs['_base_factory'] = base - # The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into account - # when parsing the declared attributes of the new class. - extra_attrs = {CLASS_ATTRIBUTE_ASSOCIATED_CLASS: associated_class} + # The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into account + # when parsing the declared attributes of the new class. + extra_attrs = {CLASS_ATTRIBUTE_ASSOCIATED_CLASS: associated_class} - return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attrs, extra_attrs=extra_attrs) + # Extract pre- and post-generation declarations + attributes = cls._extract_declarations(parent_factories, attrs) + + # Add extra args if provided. + if extra_attrs: + attributes.update(extra_attrs) + + return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attributes) def __str__(self): return '<%s for %s>' % (self.__name__, getattr(self, CLASS_ATTRIBUTE_ASSOCIATED_CLASS).__name__) + # Factory base classes class BaseFactory(object): @@ -430,10 +431,8 @@ class BaseFactory(object): return cls.generate_batch(strategy, size, **kwargs) -class StubFactory(BaseFactory): - __metaclass__ = BaseFactoryMetaClass - - default_strategy = STUB_STRATEGY +class AssociatedClassError(RuntimeError): + pass class Factory(BaseFactory): @@ -444,10 +443,8 @@ class Factory(BaseFactory): """ __metaclass__ = FactoryMetaClass - default_strategy = CREATE_STRATEGY - - class AssociatedClassError(RuntimeError): - pass + ABSTRACT_FACTORY = True + FACTORY_STRATEGY = CREATE_STRATEGY @classmethod def _adjust_kwargs(cls, **kwargs): @@ -522,6 +519,19 @@ class Factory(BaseFactory): return cls._generate(True, attrs) +Factory.AssociatedClassError = AssociatedClassError + + +class StubFactory(BaseFactory): + __metaclass__ = FactoryMetaClass + + FACTORY_STRATEGY = STUB_STRATEGY + + FACTORY_FOR = containers.StubObject + +print StubFactory._associated_class + + class DjangoModelFactory(Factory): """Factory for Django models. @@ -643,6 +653,6 @@ def use_strategy(new_strategy): This is an alternative to setting default_strategy in the class definition. """ def wrapped_class(klass): - klass.default_strategy = new_strategy + klass.FACTORY_STRATEGY = new_strategy return klass return wrapped_class |