summaryrefslogtreecommitdiff
path: root/factory/declarations.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/declarations.py')
-rw-r--r--factory/declarations.py39
1 files changed, 29 insertions, 10 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index dbade97..f068c0d 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -430,11 +430,18 @@ class PostGenerationDeclaration(object):
(object, dict): a tuple containing the attribute at 'name' (if
provided) and a dict of extracted attributes
"""
- extracted = attrs.pop(name, None)
+ try:
+ extracted = attrs.pop(name)
+ did_extract = True
+ except KeyError:
+ extracted = None
+ did_extract = False
+
kwargs = utils.extract_dict(name, attrs)
- return extracted, kwargs
+ return did_extract, extracted, kwargs
- def call(self, obj, create, extracted=None, **kwargs): # pragma: no cover
+ def call(self, obj, create,
+ extracted=None, factory_extracted=False, **kwargs): # pragma: no cover
"""Call this hook; no return value is expected.
Args:
@@ -454,7 +461,8 @@ class PostGeneration(PostGenerationDeclaration):
super(PostGeneration, self).__init__()
self.function = function
- def call(self, obj, create, extracted=None, **kwargs):
+ def call(self, obj, create,
+ extracted=None, factory_extracted=False, **kwargs):
logger.debug('PostGeneration: Calling %s.%s(%s)',
self.function.__module__,
self.function.__name__,
@@ -492,19 +500,29 @@ class RelatedFactory(PostGenerationDeclaration):
"""Retrieve the wrapped factory.Factory subclass."""
return self.factory_wrapper.get()
- def call(self, obj, create, extracted=None, **kwargs):
+ def call(self, obj, create,
+ extracted=None, factory_extracted=False, **kwargs):
+ factory = self.get_factory()
+
+ if factory_extracted:
+ # The user passed in a custom value
+ logger.debug('RelatedFactory: Using provided %r instead of '
+ 'generating %s.%s.',
+ extracted, factory.__module__, factory.__name__,
+ )
+ return extracted
+
passed_kwargs = dict(self.defaults)
passed_kwargs.update(kwargs)
if self.name:
passed_kwargs[self.name] = obj
- factory = self.get_factory()
logger.debug('RelatedFactory: Generating %s.%s(%s)',
factory.__module__,
factory.__name__,
utils.log_pprint((create,), passed_kwargs),
)
- factory.simple_generate(create, **passed_kwargs)
+ return factory.simple_generate(create, **passed_kwargs)
class PostGenerationMethodCall(PostGenerationDeclaration):
@@ -526,8 +544,9 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
self.method_args = args
self.method_kwargs = kwargs
- def call(self, obj, create, extracted=None, **kwargs):
- if extracted is None:
+ def call(self, obj, create,
+ extracted=None, factory_extracted=False, **kwargs):
+ if not factory_extracted:
passed_args = self.method_args
elif len(self.method_args) <= 1:
@@ -544,4 +563,4 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
self.method_name,
utils.log_pprint(passed_args, passed_kwargs),
)
- method(*passed_args, **passed_kwargs)
+ return method(*passed_args, **passed_kwargs)