summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2013-03-15 01:33:56 +0100
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2013-03-15 01:33:56 +0100
commit6e9bf5af909e1e164a294fd5589edc4fada06731 (patch)
tree7063b11f41d681de1e7498294a0cdda6ef234e14
parentd63821daba2002b8c455777748007f7198d3d3bc (diff)
downloadfactory-boy-6e9bf5af909e1e164a294fd5589edc4fada06731.tar
factory-boy-6e9bf5af909e1e164a294fd5589edc4fada06731.tar.gz
Merge BaseFactoryMetaClass into FactoryMetaClass.
Also fix FACTORY_STRATEGY. Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r--docs/introduction.rst2
-rw-r--r--factory/base.py200
-rw-r--r--tests/test_base.py22
3 files changed, 117 insertions, 107 deletions
diff --git a/docs/introduction.rst b/docs/introduction.rst
index d211a83..8bbb10c 100644
--- a/docs/introduction.rst
+++ b/docs/introduction.rst
@@ -253,6 +253,6 @@ Calling a :class:`~factory.Factory` subclass will provide an object through the
<MyClass: X (saved)>
-The default strategy can ba changed by setting the class-level :attr:`~factory.Factory.default_strategy` attribute.
+The default strategy can ba changed by setting the class-level :attr:`~factory.Factory.FACTROY_STRATEGY` attribute.
diff --git a/factory/base.py b/factory/base.py
index 25d7a14..aef21d5 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -44,11 +44,11 @@ CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class'
# Factory metaclasses
def get_factory_bases(bases):
- """Retrieve all BaseFactoryMetaClass-derived bases from a list."""
- return [b for b in bases if isinstance(b, BaseFactoryMetaClass)]
+ """Retrieve all FactoryMetaClass-derived bases from a list."""
+ return [b for b in bases if issubclass(b, BaseFactory)]
-class BaseFactoryMetaClass(type):
+class FactoryMetaClass(type):
"""Factory metaclass for handling ordered declarations."""
def __call__(cls, **kwargs):
@@ -57,68 +57,14 @@ class BaseFactoryMetaClass(type):
Returns an instance of the associated class.
"""
- if cls.default_strategy == BUILD_STRATEGY:
+ if cls.FACTORY_STRATEGY == BUILD_STRATEGY:
return cls.build(**kwargs)
- elif cls.default_strategy == CREATE_STRATEGY:
+ elif cls.FACTORY_STRATEGY == CREATE_STRATEGY:
return cls.create(**kwargs)
- elif cls.default_strategy == STUB_STRATEGY:
+ elif cls.FACTORY_STRATEGY == STUB_STRATEGY:
return cls.stub(**kwargs)
else:
- raise BaseFactory.UnknownStrategy('Unknown default_strategy: {0}'.format(cls.default_strategy))
-
- def __new__(cls, class_name, bases, attrs, extra_attrs=None):
- """Record attributes as a pattern for later instance construction.
-
- This is called when a new Factory subclass is defined; it will collect
- attribute declaration from the class definition.
-
- Args:
- class_name (str): the name of the class being created
- bases (list of class): the parents of the class being created
- attrs (str => obj dict): the attributes as defined in the class
- definition
- extra_attrs (str => obj dict): extra attributes that should not be
- included in the factory defaults, even if public. This
- argument is only provided by extensions of this metaclass.
-
- Returns:
- A new class
- """
-
- parent_factories = get_factory_bases(bases)
- if not parent_factories:
- # If this isn't a subclass of Factory, don't do anything special.
- return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs)
-
- declarations = containers.DeclarationDict()
- postgen_declarations = containers.PostGenerationDeclarationDict()
-
- # Add parent declarations in reverse order.
- for base in reversed(parent_factories):
- # Import parent PostGenerationDeclaration
- postgen_declarations.update_with_public(
- getattr(base, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS, {}))
- # Import all 'public' attributes (avoid those starting with _)
- declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {}))
-
- # Import attributes from the class definition
- non_postgen_attrs = postgen_declarations.update_with_public(attrs)
- # Store protected/private attributes in 'non_factory_attrs'.
- non_factory_attrs = declarations.update_with_public(non_postgen_attrs)
-
- # Store the DeclarationDict in the attributes of the newly created class
- non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations
- non_factory_attrs[CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS] = postgen_declarations
-
- # Add extra args if provided.
- if extra_attrs:
- non_factory_attrs.update(extra_attrs)
-
- return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, non_factory_attrs)
-
-
-class FactoryMetaClass(BaseFactoryMetaClass):
- """Factory metaclass for handling class association and ordered declarations."""
+ raise BaseFactory.UnknownStrategy('Unknown FACTORY_STRATEGY: {0}'.format(cls.FACTORY_STRATEGY))
@classmethod
def _discover_associated_class(cls, class_name, attrs, inherited=None):
@@ -126,8 +72,6 @@ class FactoryMetaClass(BaseFactoryMetaClass):
In order, the following tests will be performed:
- Lookup the FACTORY_CLASS_DECLARATION attribute
- - If the newly created class is named 'FooBarFactory', look for a FooBar
- class in its module
- If an inherited associated class was provided, use it.
Args:
@@ -154,46 +98,103 @@ class FactoryMetaClass(BaseFactoryMetaClass):
if inherited is not None:
return inherited
- raise Factory.AssociatedClassError(
+ raise AssociatedClassError(
"Could not determine the class associated with %s. "
"Use the FACTORY_FOR attribute to specify an associated class." %
class_name)
- def __new__(cls, class_name, bases, attrs):
- """Determine the associated class based on the factory class name. Record the associated class
- for construction of an associated class instance at a later time."""
+ @classmethod
+ def _extract_declarations(cls, bases, attributes):
+ """Extract declarations from a class definition.
- parent_factories = get_factory_bases(bases)
- if not parent_factories or attrs.get('ABSTRACT_FACTORY', False):
- # If this isn't a subclass of Factory, or specifically declared
- # abstract, don't do anything special.
- if 'ABSTRACT_FACTORY' in attrs:
- attrs.pop('ABSTRACT_FACTORY')
+ Args:
+ bases (class list): parent Factory subclasses
+ attributes (dict): attributes declared in the class definition
+ Returns:
+ dict: the original attributes, where declarations have been moved to
+ _declarations and post-generation declarations to
+ _postgen_declarations.
+ """
+ declarations = containers.DeclarationDict()
+ postgen_declarations = containers.PostGenerationDeclarationDict()
+
+ # Add parent declarations in reverse order.
+ for base in reversed(bases):
+ # Import parent PostGenerationDeclaration
+ postgen_declarations.update_with_public(
+ getattr(base, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS, {}))
+ # Import all 'public' attributes (avoid those starting with _)
+ declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {}))
+
+ # Import attributes from the class definition
+ attributes = postgen_declarations.update_with_public(attributes)
+ # Store protected/private attributes in 'non_factory_attrs'.
+ attributes = declarations.update_with_public(attributes)
+
+ # Store the DeclarationDict in the attributes of the newly created class
+ attributes[CLASS_ATTRIBUTE_DECLARATIONS] = declarations
+ attributes[CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS] = postgen_declarations
+
+ return attributes
+
+ def __new__(cls, class_name, bases, attrs, extra_attrs=None):
+ """Record attributes as a pattern for later instance construction.
+
+ This is called when a new Factory subclass is defined; it will collect
+ attribute declaration from the class definition.
+
+ Args:
+ class_name (str): the name of the class being created
+ bases (list of class): the parents of the class being created
+ attrs (str => obj dict): the attributes as defined in the class
+ definition
+ extra_attrs (str => obj dict): extra attributes that should not be
+ included in the factory defaults, even if public. This
+ argument is only provided by extensions of this metaclass.
+
+ Returns:
+ A new class
+ """
+ parent_factories = get_factory_bases(bases)
+ if not parent_factories:
return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attrs)
- base = parent_factories[0]
+ is_abstract = attrs.pop('ABSTRACT_FACTORY', False)
+ extra_attrs = {}
+
+ if not is_abstract:
+
+ base = parent_factories[0]
- inherited_associated_class = getattr(base,
- CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None)
- associated_class = cls._discover_associated_class(class_name, attrs,
- inherited_associated_class)
+ inherited_associated_class = getattr(base,
+ CLASS_ATTRIBUTE_ASSOCIATED_CLASS, None)
+ associated_class = cls._discover_associated_class(class_name, attrs,
+ inherited_associated_class)
- # If inheriting the factory from a parent, keep a link to it.
- # This allows to use the sequence counters from the parents.
- if associated_class == inherited_associated_class:
- attrs['_base_factory'] = base
+ # If inheriting the factory from a parent, keep a link to it.
+ # This allows to use the sequence counters from the parents.
+ if associated_class == inherited_associated_class:
+ attrs['_base_factory'] = base
- # The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into account
- # when parsing the declared attributes of the new class.
- extra_attrs = {CLASS_ATTRIBUTE_ASSOCIATED_CLASS: associated_class}
+ # The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into account
+ # when parsing the declared attributes of the new class.
+ extra_attrs = {CLASS_ATTRIBUTE_ASSOCIATED_CLASS: associated_class}
- return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attrs, extra_attrs=extra_attrs)
+ # Extract pre- and post-generation declarations
+ attributes = cls._extract_declarations(parent_factories, attrs)
+
+ # Add extra args if provided.
+ if extra_attrs:
+ attributes.update(extra_attrs)
+
+ return super(FactoryMetaClass, cls).__new__(cls, class_name, bases, attributes)
def __str__(self):
return '<%s for %s>' % (self.__name__,
getattr(self, CLASS_ATTRIBUTE_ASSOCIATED_CLASS).__name__)
+
# Factory base classes
class BaseFactory(object):
@@ -430,10 +431,8 @@ class BaseFactory(object):
return cls.generate_batch(strategy, size, **kwargs)
-class StubFactory(BaseFactory):
- __metaclass__ = BaseFactoryMetaClass
-
- default_strategy = STUB_STRATEGY
+class AssociatedClassError(RuntimeError):
+ pass
class Factory(BaseFactory):
@@ -444,10 +443,8 @@ class Factory(BaseFactory):
"""
__metaclass__ = FactoryMetaClass
- default_strategy = CREATE_STRATEGY
-
- class AssociatedClassError(RuntimeError):
- pass
+ ABSTRACT_FACTORY = True
+ FACTORY_STRATEGY = CREATE_STRATEGY
@classmethod
def _adjust_kwargs(cls, **kwargs):
@@ -522,6 +519,19 @@ class Factory(BaseFactory):
return cls._generate(True, attrs)
+Factory.AssociatedClassError = AssociatedClassError
+
+
+class StubFactory(BaseFactory):
+ __metaclass__ = FactoryMetaClass
+
+ FACTORY_STRATEGY = STUB_STRATEGY
+
+ FACTORY_FOR = containers.StubObject
+
+print StubFactory._associated_class
+
+
class DjangoModelFactory(Factory):
"""Factory for Django models.
@@ -643,6 +653,6 @@ def use_strategy(new_strategy):
This is an alternative to setting default_strategy in the class definition.
"""
def wrapped_class(klass):
- klass.default_strategy = new_strategy
+ klass.FACTORY_STRATEGY = new_strategy
return klass
return wrapped_class
diff --git a/tests/test_base.py b/tests/test_base.py
index e86eae3..e12c0ae 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -108,13 +108,13 @@ class FactoryTestCase(unittest.TestCase):
class FactoryDefaultStrategyTestCase(unittest.TestCase):
def setUp(self):
- self.default_strategy = base.Factory.default_strategy
+ self.default_strategy = base.Factory.FACTORY_STRATEGY
def tearDown(self):
- base.Factory.default_strategy = self.default_strategy
+ base.Factory.FACTORY_STRATEGY = self.default_strategy
def testBuildStrategy(self):
- base.Factory.default_strategy = base.BUILD_STRATEGY
+ base.Factory.FACTORY_STRATEGY = base.BUILD_STRATEGY
class TestModelFactory(base.Factory):
FACTORY_FOR = TestModel
@@ -126,7 +126,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
self.assertFalse(test_model.id)
def testCreateStrategy(self):
- # Default default_strategy
+ # Default FACTORY_STRATEGY
class TestModelFactory(FakeModelFactory):
FACTORY_FOR = TestModel
@@ -138,7 +138,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
self.assertTrue(test_model.id)
def testStubStrategy(self):
- base.Factory.default_strategy = base.STUB_STRATEGY
+ base.Factory.FACTORY_STRATEGY = base.STUB_STRATEGY
class TestModelFactory(base.Factory):
FACTORY_FOR = TestModel
@@ -150,7 +150,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
self.assertFalse(hasattr(test_model, 'id')) # We should have a plain old object
def testUnknownStrategy(self):
- base.Factory.default_strategy = 'unknown'
+ base.Factory.FACTORY_STRATEGY = 'unknown'
class TestModelFactory(base.Factory):
FACTORY_FOR = TestModel
@@ -165,11 +165,11 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
one = 'one'
- TestModelFactory.default_strategy = base.CREATE_STRATEGY
+ TestModelFactory.FACTORY_STRATEGY = base.CREATE_STRATEGY
self.assertRaises(base.StubFactory.UnsupportedStrategy, TestModelFactory)
- TestModelFactory.default_strategy = base.BUILD_STRATEGY
+ TestModelFactory.FACTORY_STRATEGY = base.BUILD_STRATEGY
self.assertRaises(base.StubFactory.UnsupportedStrategy, TestModelFactory)
def test_change_strategy(self):
@@ -179,7 +179,7 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
one = 'one'
- self.assertEqual(base.CREATE_STRATEGY, TestModelFactory.default_strategy)
+ self.assertEqual(base.CREATE_STRATEGY, TestModelFactory.FACTORY_STRATEGY)
class FactoryCreationTestCase(unittest.TestCase):
@@ -193,7 +193,7 @@ class FactoryCreationTestCase(unittest.TestCase):
class TestFactory(base.StubFactory):
pass
- self.assertEqual(TestFactory.default_strategy, base.STUB_STRATEGY)
+ self.assertEqual(TestFactory.FACTORY_STRATEGY, base.STUB_STRATEGY)
def testInheritanceWithStub(self):
class TestObjectFactory(base.StubFactory):
@@ -204,7 +204,7 @@ class FactoryCreationTestCase(unittest.TestCase):
class TestFactory(TestObjectFactory):
pass
- self.assertEqual(TestFactory.default_strategy, base.STUB_STRATEGY)
+ self.assertEqual(TestFactory.FACTORY_STRATEGY, base.STUB_STRATEGY)
def testCustomCreation(self):
class TestModelFactory(FakeModelFactory):