diff options
author | Raphaël Barrois <raphael.barrois@polyconseil.fr> | 2012-04-13 19:42:19 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polyconseil.fr> | 2012-04-15 10:06:39 +0200 |
commit | 0e7fed312bf2de6d628a61b116bd91e04bf0a9ff (patch) | |
tree | 103408b5bee8339eee6e510300a374b0bcea748f | |
parent | b590e5014351a79d66d2f4816b1a6aa83908f395 (diff) | |
download | factory-boy-0e7fed312bf2de6d628a61b116bd91e04bf0a9ff.tar factory-boy-0e7fed312bf2de6d628a61b116bd91e04bf0a9ff.tar.gz |
Handle the PostGeneration declarations.
Signed-off-by: Raphaël Barrois <raphael.barrois@polyconseil.fr>
-rw-r--r-- | factory/__init__.py | 1 | ||||
-rw-r--r-- | factory/base.py | 44 | ||||
-rw-r--r-- | factory/declarations.py | 6 | ||||
-rw-r--r-- | tests/test_using.py | 33 |
4 files changed, 72 insertions, 12 deletions
diff --git a/factory/__init__.py b/factory/__init__.py index dd91343..1bf8968 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -67,5 +67,6 @@ from declarations import ( sequence, lazy_attribute_sequence, container_attribute, + post_declaration, ) diff --git a/factory/base.py b/factory/base.py index 98029e3..c1dbd98 100644 --- a/factory/base.py +++ b/factory/base.py @@ -44,6 +44,7 @@ FACTORY_CLASS_DECLARATION = 'FACTORY_FOR' # Factory class attributes CLASS_ATTRIBUTE_DECLARATIONS = '_declarations' +CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS = '_postgen_declarations' CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class' @@ -97,18 +98,24 @@ class BaseFactoryMetaClass(type): return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) declarations = containers.DeclarationDict() + postgen_declarations = containers.PostGenerationDeclarationDict() # Add parent declarations in reverse order. for base in reversed(parent_factories): + # Import parent PostGenerationDeclaration + postgen_declarations.update_with_public( + getattr(base, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS, {})) # Import all 'public' attributes (avoid those starting with _) declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) - # Import attributes from the class definition, storing protected/private - # attributes in 'non_factory_attrs'. - non_factory_attrs = declarations.update_with_public(attrs) + # Import attributes from the class definition + non_postgen_attrs = postgen_declarations.update_with_public(attrs) + # Store protected/private attributes in 'non_factory_attrs'. + non_factory_attrs = declarations.update_with_public(non_postgen_attrs) # Store the DeclarationDict in the attributes of the newly created class non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations + non_factory_attrs[CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS] = postgen_declarations # Add extra args if provided. if extra_attrs: @@ -521,20 +528,37 @@ class Factory(BaseFactory): return cls.get_building_function()(getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS), **kwargs) @classmethod - def _build(cls, **kwargs): - return cls._prepare(create=False, **kwargs) + def _generate(cls, create, attrs): + """generate the object. - @classmethod - def _create(cls, **kwargs): - return cls._prepare(create=True, **kwargs) + Args: + create (bool): whether to 'build' or 'create' the object + attrs (dict): attributes to use for generating the object + """ + # Extract declarations used for post-generation + postgen_declarations = getattr(cls, CLASS_ATTRIBUTE_POSTGEN_DECLARATIONS) + postgen_attributes = {} + for name, decl in sorted(postgen_declarations.items()): + postgen_attributes[name] = decl.extract(name, attrs) + + # Generate the object + obj = cls._prepare(create, **attrs) + + # Handle post-generation attributes + for name, decl in sorted(postgen_declarations.items()): + extracted, extracted_kwargs = postgen_attributes[name] + decl.call(obj, create, extracted, **extracted_kwargs) + return obj @classmethod def build(cls, **kwargs): - return cls._build(**cls.attributes(create=False, extra=kwargs)) + attrs = cls.attributes(create=False, extra=kwargs) + return cls._generate(False, attrs) @classmethod def create(cls, **kwargs): - return cls._create(**cls.attributes(create=True, extra=kwargs)) + attrs = cls.attributes(create=True, extra=kwargs) + return cls._generate(True, attrs) class DjangoModelFactory(Factory): diff --git a/factory/declarations.py b/factory/declarations.py index d9d560a..ddc6c78 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -23,6 +23,8 @@ import itertools +from factory import utils + class OrderedDeclaration(object): """A factory declaration. @@ -275,12 +277,12 @@ class PostGenerationDeclaration(object): (object, dict): a tuple containing the attribute at 'name' (if provided) and a dict of extracted attributes """ - extracted = attrs.get(name) + extracted = attrs.pop(name, None) if self.extract_prefix: extract_prefix = self.extract_prefix else: extract_prefix = name - kwargs = utils.extract_dict(extract_prefix) + kwargs = utils.extract_dict(extract_prefix, attrs) return extracted, kwargs def call(self, obj, create, extracted=None, **kwargs): diff --git a/tests/test_using.py b/tests/test_using.py index a3cf89c..54106a9 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -835,6 +835,39 @@ class IteratorTestCase(unittest.TestCase): self.assertEqual(i % 5, obj.one) +class PostDeclarationHookTestCase(unittest.TestCase): + def test_post_declaration(self): + class TestObjectFactory(factory.Factory): + one = 1 + + @factory.post_declaration() + def incr_one(self, _create, _increment): + self.one += 1 + + obj = TestObjectFactory.build() + self.assertEqual(2, obj.one) + self.assertFalse(hasattr(obj, 'incr_one')) + + obj = TestObjectFactory.build(one=2) + self.assertEqual(3, obj.one) + self.assertFalse(hasattr(obj, 'incr_one')) + + def test_post_declaration_extraction(self): + class TestObjectFactory(factory.Factory): + one = 1 + + @factory.post_declaration() + def incr_one(self, _create, increment=1): + self.one += increment + + obj = TestObjectFactory.build(incr_one=2) + self.assertEqual(3, obj.one) + self.assertFalse(hasattr(obj, 'incr_one')) + + obj = TestObjectFactory.build(one=2, incr_one=2) + self.assertEqual(4, obj.one) + self.assertFalse(hasattr(obj, 'incr_one')) + if __name__ == '__main__': unittest.main() |