summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polyconseil.fr>2012-04-13 19:42:19 +0200
committerRaphaël Barrois <raphael.barrois@polyconseil.fr>2012-04-15 10:06:39 +0200
commit0e7fed312bf2de6d628a61b116bd91e04bf0a9ff (patch)
tree103408b5bee8339eee6e510300a374b0bcea748f
parentb590e5014351a79d66d2f4816b1a6aa83908f395 (diff)
downloadfactory-boy-0e7fed312bf2de6d628a61b116bd91e04bf0a9ff.tar
factory-boy-0e7fed312bf2de6d628a61b116bd91e04bf0a9ff.tar.gz
Handle the PostGeneration declarations.
Signed-off-by: Raphaël Barrois <raphael.barrois@polyconseil.fr>
-rw-r--r--factory/__init__.py1
-rw-r--r--factory/base.py44
-rw-r--r--factory/declarations.py6
-rw-r--r--tests/test_using.py33
4 files changed, 72 insertions, 12 deletions
diff --git a/factory/__init__.py b/factory/__init__.py
index dd91343..1bf8968 100644
--- a/factory/__init__.py
+++ b/factory/__init__.py
@@ -67,5 +67,6 @@ from declarations import (
sequence,
lazy_attribute_sequence,
container_attribute,
+ post_declaration,
)
diff --git a/factory/base.py b/factory/base.py
index 98029e3..c1dbd98 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -44,6 +44,7 @@ FACTORY_CLASS_DECLARATION = 'FACTORY_FOR'
# Factory class attributes
CLASS_ATTRIBUTE_DECLARATIONS = '_declarations'
+CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS = '_postgen_declarations'
CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class'
@@ -97,18 +98,24 @@ class BaseFactoryMetaClass(type):
return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs)
declarations = containers.DeclarationDict()
+ postgen_declarations = containers.PostGenerationDeclarationDict()
# Add parent declarations in reverse order.
for base in reversed(parent_factories):
+ # Import parent PostGenerationDeclaration
+ postgen_declarations.update_with_public(
+ getattr(base, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS, {}))
# Import all 'public' attributes (avoid those starting with _)
declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {}))
- # Import attributes from the class definition, storing protected/private
- # attributes in 'non_factory_attrs'.
- non_factory_attrs = declarations.update_with_public(attrs)
+ # Import attributes from the class definition
+ non_postgen_attrs = postgen_declarations.update_with_public(attrs)
+ # Store protected/private attributes in 'non_factory_attrs'.
+ non_factory_attrs = declarations.update_with_public(non_postgen_attrs)
# Store the DeclarationDict in the attributes of the newly created class
non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations
+ non_factory_attrs[CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS] = postgen_declarations
# Add extra args if provided.
if extra_attrs:
@@ -521,20 +528,37 @@ class Factory(BaseFactory):
return cls.get_building_function()(getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS), **kwargs)
@classmethod
- def _build(cls, **kwargs):
- return cls._prepare(create=False, **kwargs)
+ def _generate(cls, create, attrs):
+ """generate the object.
- @classmethod
- def _create(cls, **kwargs):
- return cls._prepare(create=True, **kwargs)
+ Args:
+ create (bool): whether to 'build' or 'create' the object
+ attrs (dict): attributes to use for generating the object
+ """
+ # Extract declarations used for post-generation
+ postgen_declarations = getattr(cls, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS)
+ postgen_attributes = {}
+ for name, decl in sorted(postgen_declarations.items()):
+ postgen_attributes[name] = decl.extract(name, attrs)
+
+ # Generate the object
+ obj = cls._prepare(create, **attrs)
+
+ # Handle post-generation attributes
+ for name, decl in sorted(postgen_declarations.items()):
+ extracted, extracted_kwargs = postgen_attributes[name]
+ decl.call(obj, create, extracted, **extracted_kwargs)
+ return obj
@classmethod
def build(cls, **kwargs):
- return cls._build(**cls.attributes(create=False, extra=kwargs))
+ attrs = cls.attributes(create=False, extra=kwargs)
+ return cls._generate(False, attrs)
@classmethod
def create(cls, **kwargs):
- return cls._create(**cls.attributes(create=True, extra=kwargs))
+ attrs = cls.attributes(create=True, extra=kwargs)
+ return cls._generate(True, attrs)
class DjangoModelFactory(Factory):
diff --git a/factory/declarations.py b/factory/declarations.py
index d9d560a..ddc6c78 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -23,6 +23,8 @@
import itertools
+from factory import utils
+
class OrderedDeclaration(object):
"""A factory declaration.
@@ -275,12 +277,12 @@ class PostGenerationDeclaration(object):
(object, dict): a tuple containing the attribute at 'name' (if
provided) and a dict of extracted attributes
"""
- extracted = attrs.get(name)
+ extracted = attrs.pop(name, None)
if self.extract_prefix:
extract_prefix = self.extract_prefix
else:
extract_prefix = name
- kwargs = utils.extract_dict(extract_prefix)
+ kwargs = utils.extract_dict(extract_prefix, attrs)
return extracted, kwargs
def call(self, obj, create, extracted=None, **kwargs):
diff --git a/tests/test_using.py b/tests/test_using.py
index a3cf89c..54106a9 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -835,6 +835,39 @@ class IteratorTestCase(unittest.TestCase):
self.assertEqual(i % 5, obj.one)
+class PostDeclarationHookTestCase(unittest.TestCase):
+ def test_post_declaration(self):
+ class TestObjectFactory(factory.Factory):
+ one = 1
+
+ @factory.post_declaration()
+ def incr_one(self, _create, _increment):
+ self.one += 1
+
+ obj = TestObjectFactory.build()
+ self.assertEqual(2, obj.one)
+ self.assertFalse(hasattr(obj, 'incr_one'))
+
+ obj = TestObjectFactory.build(one=2)
+ self.assertEqual(3, obj.one)
+ self.assertFalse(hasattr(obj, 'incr_one'))
+
+ def test_post_declaration_extraction(self):
+ class TestObjectFactory(factory.Factory):
+ one = 1
+
+ @factory.post_declaration()
+ def incr_one(self, _create, increment=1):
+ self.one += increment
+
+ obj = TestObjectFactory.build(incr_one=2)
+ self.assertEqual(3, obj.one)
+ self.assertFalse(hasattr(obj, 'incr_one'))
+
+ obj = TestObjectFactory.build(one=2, incr_one=2)
+ self.assertEqual(4, obj.one)
+ self.assertFalse(hasattr(obj, 'incr_one'))
+
if __name__ == '__main__':
unittest.main()