diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-06-14 23:36:25 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-06-14 23:36:25 +0200 |
commit | 1ba20b0ed7b920fa2d161df94a0dda3d93b1e14b (patch) | |
tree | bcf4d4291b0f451dce533ff244f87fb088b90a34 | |
parent | 251ae29b4beedd7e9af721ceabb82a03f2d55bab (diff) | |
download | factory-boy-1ba20b0ed7b920fa2d161df94a0dda3d93b1e14b.tar factory-boy-1ba20b0ed7b920fa2d161df94a0dda3d93b1e14b.tar.gz |
Properly handle passed-in None in RelatedFactory (Closes #62).
Thanks to @Dhekke for the help!
-rw-r--r-- | docs/reference.rst | 16 | ||||
-rw-r--r-- | factory/base.py | 4 | ||||
-rw-r--r-- | factory/declarations.py | 39 | ||||
-rw-r--r-- | tests/test_declarations.py | 19 |
4 files changed, 59 insertions, 19 deletions
diff --git a/docs/reference.rst b/docs/reference.rst index a2d6c9a..74f2dbd 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -1157,6 +1157,22 @@ Extra kwargs may be passed to the related factory, through the usual ``ATTR__SUB >>> City.objects.get(capital_of=england) <City: London> +If a value if passed for the :class:`RelatedFactory` attribute, this disables +:class:`RelatedFactory` generation: + +.. code-block:: pycon + + >>> france = CountryFactory() + >>> paris = City.objects.get() + >>> paris + <City: Paris> + >>> reunion = CountryFactory(capital_city=paris) + >>> City.objects.count() # No new capital_city generated + 1 + >>> guyane = CountryFactory(capital_city=paris, capital_city__name='Kourou') + >>> City.objects.count() # No new capital_city generated, ``name`` ignored. + 1 + PostGeneration """""""""""""" diff --git a/factory/base.py b/factory/base.py index 60aa218..76d3d4a 100644 --- a/factory/base.py +++ b/factory/base.py @@ -380,9 +380,9 @@ class BaseFactory(object): # Handle post-generation attributes results = {} for name, decl in sorted(postgen_declarations.items()): - extracted, extracted_kwargs = postgen_attributes[name] + did_extract, extracted, extracted_kwargs = postgen_attributes[name] results[name] = decl.call(obj, create, extracted, - **extracted_kwargs) + factory_extracted=did_extract, **extracted_kwargs) cls._after_postgeneration(obj, create, results) diff --git a/factory/declarations.py b/factory/declarations.py index dbade97..f068c0d 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -430,11 +430,18 @@ class PostGenerationDeclaration(object): (object, dict): a tuple containing the attribute at 'name' (if provided) and a dict of extracted attributes """ - extracted = attrs.pop(name, None) + try: + extracted = attrs.pop(name) + did_extract = True + except KeyError: + extracted = None + did_extract = False + kwargs = utils.extract_dict(name, attrs) - return extracted, kwargs + return did_extract, extracted, kwargs - def call(self, obj, create, extracted=None, **kwargs): # pragma: no cover + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): # pragma: no cover """Call this hook; no return value is expected. Args: @@ -454,7 +461,8 @@ class PostGeneration(PostGenerationDeclaration): super(PostGeneration, self).__init__() self.function = function - def call(self, obj, create, extracted=None, **kwargs): + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): logger.debug('PostGeneration: Calling %s.%s(%s)', self.function.__module__, self.function.__name__, @@ -492,19 +500,29 @@ class RelatedFactory(PostGenerationDeclaration): """Retrieve the wrapped factory.Factory subclass.""" return self.factory_wrapper.get() - def call(self, obj, create, extracted=None, **kwargs): + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): + factory = self.get_factory() + + if factory_extracted: + # The user passed in a custom value + logger.debug('RelatedFactory: Using provided %r instead of ' + 'generating %s.%s.', + extracted, factory.__module__, factory.__name__, + ) + return extracted + passed_kwargs = dict(self.defaults) passed_kwargs.update(kwargs) if self.name: passed_kwargs[self.name] = obj - factory = self.get_factory() logger.debug('RelatedFactory: Generating %s.%s(%s)', factory.__module__, factory.__name__, utils.log_pprint((create,), passed_kwargs), ) - factory.simple_generate(create, **passed_kwargs) + return factory.simple_generate(create, **passed_kwargs) class PostGenerationMethodCall(PostGenerationDeclaration): @@ -526,8 +544,9 @@ class PostGenerationMethodCall(PostGenerationDeclaration): self.method_args = args self.method_kwargs = kwargs - def call(self, obj, create, extracted=None, **kwargs): - if extracted is None: + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): + if not factory_extracted: passed_args = self.method_args elif len(self.method_args) <= 1: @@ -544,4 +563,4 @@ class PostGenerationMethodCall(PostGenerationDeclaration): self.method_name, utils.log_pprint(passed_args, passed_kwargs), ) - method(*passed_args, **passed_kwargs) + return method(*passed_args, **passed_kwargs) diff --git a/tests/test_declarations.py b/tests/test_declarations.py index e0b2513..cd38dd2 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -119,7 +119,9 @@ class PostGenerationDeclarationTestCase(unittest.TestCase): def test_extract_no_prefix(self): decl = declarations.PostGenerationDeclaration() - extracted, kwargs = decl.extract('foo', {'foo': 13, 'foo__bar': 42}) + did_extract, extracted, kwargs = decl.extract('foo', + {'foo': 13, 'foo__bar': 42}) + self.assertTrue(did_extract) self.assertEqual(extracted, 13) self.assertEqual(kwargs, {'bar': 42}) @@ -130,8 +132,9 @@ class PostGenerationDeclarationTestCase(unittest.TestCase): call_params.append(args) call_params.append(kwargs) - extracted, kwargs = foo.extract('foo', + did_extract, extracted, kwargs = foo.extract('foo', {'foo': 13, 'foo__bar': 42, 'blah': 42, 'blah__baz': 1}) + self.assertTrue(did_extract) self.assertEqual(13, extracted) self.assertEqual({'bar': 42}, kwargs) @@ -215,17 +218,17 @@ class PostGenerationMethodCallTestCase(unittest.TestCase): def test_call_with_passed_extracted_string(self): decl = declarations.PostGenerationMethodCall( 'method') - decl.call(self.obj, False, 'data') + decl.call(self.obj, False, 'data', factory_extracted=True) self.obj.method.assert_called_once_with('data') def test_call_with_passed_extracted_int(self): decl = declarations.PostGenerationMethodCall('method') - decl.call(self.obj, False, 1) + decl.call(self.obj, False, 1, factory_extracted=True) self.obj.method.assert_called_once_with(1) def test_call_with_passed_extracted_iterable(self): decl = declarations.PostGenerationMethodCall('method') - decl.call(self.obj, False, (1, 2, 3)) + decl.call(self.obj, False, (1, 2, 3), factory_extracted=True) self.obj.method.assert_called_once_with((1, 2, 3)) def test_call_with_method_kwargs(self): @@ -248,13 +251,15 @@ class PostGenerationMethodCallTestCase(unittest.TestCase): def test_multi_call_with_passed_multiple_args(self): decl = declarations.PostGenerationMethodCall( 'method', 'arg1', 'arg2') - decl.call(self.obj, False, ('param1', 'param2', 'param3')) + decl.call(self.obj, False, ('param1', 'param2', 'param3'), + factory_extracted=True) self.obj.method.assert_called_once_with('param1', 'param2', 'param3') def test_multi_call_with_passed_tuple(self): decl = declarations.PostGenerationMethodCall( 'method', 'arg1', 'arg2') - decl.call(self.obj, False, (('param1', 'param2'),)) + decl.call(self.obj, False, (('param1', 'param2'),), + factory_extracted=True) self.obj.method.assert_called_once_with(('param1', 'param2')) def test_multi_call_with_kwargs(self): |