diff options
Diffstat (limited to 'factory/declarations.py')
-rw-r--r-- | factory/declarations.py | 39 |
1 files changed, 29 insertions, 10 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index dbade97..f068c0d 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -430,11 +430,18 @@ class PostGenerationDeclaration(object): (object, dict): a tuple containing the attribute at 'name' (if provided) and a dict of extracted attributes """ - extracted = attrs.pop(name, None) + try: + extracted = attrs.pop(name) + did_extract = True + except KeyError: + extracted = None + did_extract = False + kwargs = utils.extract_dict(name, attrs) - return extracted, kwargs + return did_extract, extracted, kwargs - def call(self, obj, create, extracted=None, **kwargs): # pragma: no cover + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): # pragma: no cover """Call this hook; no return value is expected. Args: @@ -454,7 +461,8 @@ class PostGeneration(PostGenerationDeclaration): super(PostGeneration, self).__init__() self.function = function - def call(self, obj, create, extracted=None, **kwargs): + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): logger.debug('PostGeneration: Calling %s.%s(%s)', self.function.__module__, self.function.__name__, @@ -492,19 +500,29 @@ class RelatedFactory(PostGenerationDeclaration): """Retrieve the wrapped factory.Factory subclass.""" return self.factory_wrapper.get() - def call(self, obj, create, extracted=None, **kwargs): + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): + factory = self.get_factory() + + if factory_extracted: + # The user passed in a custom value + logger.debug('RelatedFactory: Using provided %r instead of ' + 'generating %s.%s.', + extracted, factory.__module__, factory.__name__, + ) + return extracted + passed_kwargs = dict(self.defaults) passed_kwargs.update(kwargs) if self.name: passed_kwargs[self.name] = obj - factory = self.get_factory() logger.debug('RelatedFactory: Generating %s.%s(%s)', factory.__module__, factory.__name__, utils.log_pprint((create,), passed_kwargs), ) - factory.simple_generate(create, **passed_kwargs) + return factory.simple_generate(create, **passed_kwargs) class PostGenerationMethodCall(PostGenerationDeclaration): @@ -526,8 +544,9 @@ class PostGenerationMethodCall(PostGenerationDeclaration): self.method_args = args self.method_kwargs = kwargs - def call(self, obj, create, extracted=None, **kwargs): - if extracted is None: + def call(self, obj, create, + extracted=None, factory_extracted=False, **kwargs): + if not factory_extracted: passed_args = self.method_args elif len(self.method_args) <= 1: @@ -544,4 +563,4 @@ class PostGenerationMethodCall(PostGenerationDeclaration): self.method_name, utils.log_pprint(passed_args, passed_kwargs), ) - method(*passed_args, **passed_kwargs) + return method(*passed_args, **passed_kwargs) |