From 1ba20b0ed7b920fa2d161df94a0dda3d93b1e14b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Barrois?= Date: Fri, 14 Jun 2013 23:36:25 +0200 Subject: Properly handle passed-in None in RelatedFactory (Closes #62). Thanks to @Dhekke for the help! --- factory/base.py | 4 ++-- factory/declarations.py | 39 +++++++++++++++++++++++++++++---------- 2 files changed, 31 insertions(+), 12 deletions(-) (limited to 'factory') diff --git a/factory/base.py b/factory/base.py index 60aa218..76d3d4a 100644 --- a/factory/base.py +++ b/factory/base.py @@ -380,9 +380,9 @@ class BaseFactory(object): # Handle post-generation attributes results = {} for name, decl in sorted(postgen_declarations.items()): - extracted, extracted_kwargs = postgen_attributes[name] + did_extract, extracted, extracted_kwargs = postgen_attributes[name] results[name] = decl.call(obj, create, extracted, - **extracted_kwargs) + factory_extracted=did_extract, **extracted_kwargs) cls._after_postgeneration(obj, create, results) diff --git a/factory/declarations.py b/factory/declarations.py index dbade97..f068c0d 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -430,11 +430,18 @@ class PostGenerationDeclaration(object): (object, dict): a tuple containing the attribute at 'name' (if provided) and a dict of extracted attributes """ - extracted = attrs.pop(name, None) + try: + extracted = attrs.pop(name) + did_extract = True + except KeyError: + extracted = None + did_extract = False + kwargs = utils.extract_dict(name, attrs) - return extracted, kwargs + return did_extract, extracted, kwargs - def call(self, obj, create, extracted=None, **kwargs): # pragma: no cover + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): # pragma: no cover """Call this hook; no return value is expected. Args: @@ -454,7 +461,8 @@ class PostGeneration(PostGenerationDeclaration): super(PostGeneration, self).__init__() self.function = function - def call(self, obj, create, extracted=None, **kwargs): + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): logger.debug('PostGeneration: Calling %s.%s(%s)', self.function.__module__, self.function.__name__, @@ -492,19 +500,29 @@ class RelatedFactory(PostGenerationDeclaration): """Retrieve the wrapped factory.Factory subclass.""" return self.factory_wrapper.get() - def call(self, obj, create, extracted=None, **kwargs): + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): + factory = self.get_factory() + + if factory_extracted: + # The user passed in a custom value + logger.debug('RelatedFactory: Using provided %r instead of ' + 'generating %s.%s.', + extracted, factory.__module__, factory.__name__, + ) + return extracted + passed_kwargs = dict(self.defaults) passed_kwargs.update(kwargs) if self.name: passed_kwargs[self.name] = obj - factory = self.get_factory() logger.debug('RelatedFactory: Generating %s.%s(%s)', factory.__module__, factory.__name__, utils.log_pprint((create,), passed_kwargs), ) - factory.simple_generate(create, **passed_kwargs) + return factory.simple_generate(create, **passed_kwargs) class PostGenerationMethodCall(PostGenerationDeclaration): @@ -526,8 +544,9 @@ class PostGenerationMethodCall(PostGenerationDeclaration): self.method_args = args self.method_kwargs = kwargs - def call(self, obj, create, extracted=None, **kwargs): - if extracted is None: + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): + if not factory_extracted: passed_args = self.method_args elif len(self.method_args) <= 1: @@ -544,4 +563,4 @@ class PostGenerationMethodCall(PostGenerationDeclaration): self.method_name, utils.log_pprint(passed_args, passed_kwargs), ) - method(*passed_args, **passed_kwargs) + return method(*passed_args, **passed_kwargs) -- cgit v1.2.3