From 94d7defa820b69152fb5aeadb3f5ccc3611158fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Barrois?= Date: Sat, 15 Jun 2013 15:09:38 +0200 Subject: Cleanup PostGenerationDeclaration extraction context. --- factory/declarations.py | 60 ++++++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 23 deletions(-) (limited to 'factory/declarations.py') diff --git a/factory/declarations.py b/factory/declarations.py index 552ddf2..e7c8ed5 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -417,6 +417,21 @@ class List(SubFactory): **params) +class ExtractionContext(object): + """Private class holding all required context from extraction to postgen.""" + def __init__(self, value=None, did_extract=False, extra=None): + self.value = value + self.did_extract = did_extract + self.extra = extra or {} + + def __repr__(self): + return 'ExtractionContext(%r, %r, %r)' % ( + self.value, + self.did_extract, + self.extra, + ) + + class PostGenerationDeclaration(object): """Declarations to be called once the target object has been generated.""" @@ -441,19 +456,16 @@ class PostGenerationDeclaration(object): did_extract = False kwargs = utils.extract_dict(name, attrs) - return did_extract, extracted, kwargs + return ExtractionContext(extracted, did_extract, kwargs) - def call(self, obj, create, - extracted=None, factory_extracted=False, **kwargs): # pragma: no cover + def call(self, obj, create, extraction_context): # pragma: no cover """Call this hook; no return value is expected. Args: obj (object): the newly generated object create (bool): whether the object was 'built' or 'created' - extracted (object): the value given for in the - object definition, or None if not provided. - kwargs (dict): declarations extracted from the object - definition for this hook + extraction_context: An ExtractionContext containing values + extracted from the containing factory's declaration """ raise NotImplementedError() @@ -464,14 +476,17 @@ class PostGeneration(PostGenerationDeclaration): super(PostGeneration, self).__init__() self.function = function - def call(self, obj, create, - extracted=None, factory_extracted=False, **kwargs): + def call(self, obj, create, extraction_context): logger.debug('PostGeneration: Calling %s.%s(%s)', self.function.__module__, self.function.__name__, - utils.log_pprint((obj, create, extracted), kwargs), + utils.log_pprint( + (obj, create, extraction_context.value), + extraction_context.extra, + ), ) - return self.function(obj, create, extracted, **kwargs) + return self.function(obj, create, + extraction_context.value, **extraction_context.extra) class RelatedFactory(PostGenerationDeclaration): @@ -503,20 +518,20 @@ class RelatedFactory(PostGenerationDeclaration): """Retrieve the wrapped factory.Factory subclass.""" return self.factory_wrapper.get() - def call(self, obj, create, - extracted=None, factory_extracted=False, **kwargs): + def call(self, obj, create, extraction_context): factory = self.get_factory() - if factory_extracted: + if extraction_context.did_extract: # The user passed in a custom value logger.debug('RelatedFactory: Using provided %r instead of ' 'generating %s.%s.', - extracted, factory.__module__, factory.__name__, + extraction_context.value, + factory.__module__, factory.__name__, ) - return extracted + return extraction_context.value passed_kwargs = dict(self.defaults) - passed_kwargs.update(kwargs) + passed_kwargs.update(extraction_context.extra) if self.name: passed_kwargs[self.name] = obj @@ -547,19 +562,18 @@ class PostGenerationMethodCall(PostGenerationDeclaration): self.method_args = args self.method_kwargs = kwargs - def call(self, obj, create, - extracted=None, factory_extracted=False, **kwargs): - if not factory_extracted: + def call(self, obj, create, extraction_context): + if not extraction_context.did_extract: passed_args = self.method_args elif len(self.method_args) <= 1: # Max one argument expected - passed_args = (extracted,) + passed_args = (extraction_context.value,) else: - passed_args = tuple(extracted) + passed_args = tuple(extraction_context.value) passed_kwargs = dict(self.method_kwargs) - passed_kwargs.update(kwargs) + passed_kwargs.update(extraction_context.extra) method = getattr(obj, self.method_name) logger.debug('PostGenerationMethodCall: Calling %r.%s(%s)', obj, -- cgit v1.2.3