summaryrefslogtreecommitdiff
path: root/factory/declarations.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/declarations.py')
-rw-r--r--factory/declarations.py60
1 files changed, 37 insertions, 23 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 552ddf2..e7c8ed5 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -417,6 +417,21 @@ class List(SubFactory):
**params)
+class ExtractionContext(object):
+ """Private class holding all required context from extraction to postgen."""
+ def __init__(self, value=None, did_extract=False, extra=None):
+ self.value = value
+ self.did_extract = did_extract
+ self.extra = extra or {}
+
+ def __repr__(self):
+ return 'ExtractionContext(%r, %r, %r)' % (
+ self.value,
+ self.did_extract,
+ self.extra,
+ )
+
+
class PostGenerationDeclaration(object):
"""Declarations to be called once the target object has been generated."""
@@ -441,19 +456,16 @@ class PostGenerationDeclaration(object):
did_extract = False
kwargs = utils.extract_dict(name, attrs)
- return did_extract, extracted, kwargs
+ return ExtractionContext(extracted, did_extract, kwargs)
- def call(self, obj, create,
- extracted=None, factory_extracted=False, **kwargs): # pragma: no cover
+ def call(self, obj, create, extraction_context): # pragma: no cover
"""Call this hook; no return value is expected.
Args:
obj (object): the newly generated object
create (bool): whether the object was 'built' or 'created'
- extracted (object): the value given for <name> in the
- object definition, or None if not provided.
- kwargs (dict): declarations extracted from the object
- definition for this hook
+ extraction_context: An ExtractionContext containing values
+ extracted from the containing factory's declaration
"""
raise NotImplementedError()
@@ -464,14 +476,17 @@ class PostGeneration(PostGenerationDeclaration):
super(PostGeneration, self).__init__()
self.function = function
- def call(self, obj, create,
- extracted=None, factory_extracted=False, **kwargs):
+ def call(self, obj, create, extraction_context):
logger.debug('PostGeneration: Calling %s.%s(%s)',
self.function.__module__,
self.function.__name__,
- utils.log_pprint((obj, create, extracted), kwargs),
+ utils.log_pprint(
+ (obj, create, extraction_context.value),
+ extraction_context.extra,
+ ),
)
- return self.function(obj, create, extracted, **kwargs)
+ return self.function(obj, create,
+ extraction_context.value, **extraction_context.extra)
class RelatedFactory(PostGenerationDeclaration):
@@ -503,20 +518,20 @@ class RelatedFactory(PostGenerationDeclaration):
"""Retrieve the wrapped factory.Factory subclass."""
return self.factory_wrapper.get()
- def call(self, obj, create,
- extracted=None, factory_extracted=False, **kwargs):
+ def call(self, obj, create, extraction_context):
factory = self.get_factory()
- if factory_extracted:
+ if extraction_context.did_extract:
# The user passed in a custom value
logger.debug('RelatedFactory: Using provided %r instead of '
'generating %s.%s.',
- extracted, factory.__module__, factory.__name__,
+ extraction_context.value,
+ factory.__module__, factory.__name__,
)
- return extracted
+ return extraction_context.value
passed_kwargs = dict(self.defaults)
- passed_kwargs.update(kwargs)
+ passed_kwargs.update(extraction_context.extra)
if self.name:
passed_kwargs[self.name] = obj
@@ -547,19 +562,18 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
self.method_args = args
self.method_kwargs = kwargs
- def call(self, obj, create,
- extracted=None, factory_extracted=False, **kwargs):
- if not factory_extracted:
+ def call(self, obj, create, extraction_context):
+ if not extraction_context.did_extract:
passed_args = self.method_args
elif len(self.method_args) <= 1:
# Max one argument expected
- passed_args = (extracted,)
+ passed_args = (extraction_context.value,)
else:
- passed_args = tuple(extracted)
+ passed_args = tuple(extraction_context.value)
passed_kwargs = dict(self.method_kwargs)
- passed_kwargs.update(kwargs)
+ passed_kwargs.update(extraction_context.extra)
method = getattr(obj, self.method_name)
logger.debug('PostGenerationMethodCall: Calling %r.%s(%s)',
obj,