aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/base.py5
-rw-r--r--factory/declarations.py60
-rw-r--r--tests/test_declarations.py49
3 files changed, 66 insertions, 48 deletions
diff --git a/factory/base.py b/factory/base.py
index 76d3d4a..0429231 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -380,9 +380,8 @@ class BaseFactory(object):
# Handle post-generation attributes
results = {}
for name, decl in sorted(postgen_declarations.items()):
- did_extract, extracted, extracted_kwargs = postgen_attributes[name]
- results[name] = decl.call(obj, create, extracted,
- factory_extracted=did_extract, **extracted_kwargs)
+ extraction_context = postgen_attributes[name]
+ results[name] = decl.call(obj, create, extraction_context)
cls._after_postgeneration(obj, create, results)
diff --git a/factory/declarations.py b/factory/declarations.py
index 552ddf2..e7c8ed5 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -417,6 +417,21 @@ class List(SubFactory):
**params)
+class ExtractionContext(object):
+ """Private class holding all required context from extraction to postgen."""
+ def __init__(self, value=None, did_extract=False, extra=None):
+ self.value = value
+ self.did_extract = did_extract
+ self.extra = extra or {}
+
+ def __repr__(self):
+ return 'ExtractionContext(%r, %r, %r)' % (
+ self.value,
+ self.did_extract,
+ self.extra,
+ )
+
+
class PostGenerationDeclaration(object):
"""Declarations to be called once the target object has been generated."""
@@ -441,19 +456,16 @@ class PostGenerationDeclaration(object):
did_extract = False
kwargs = utils.extract_dict(name, attrs)
- return did_extract, extracted, kwargs
+ return ExtractionContext(extracted, did_extract, kwargs)
- def call(self, obj, create,
- extracted=None, factory_extracted=False, **kwargs): # pragma: no cover
+ def call(self, obj, create, extraction_context): # pragma: no cover
"""Call this hook; no return value is expected.
Args:
obj (object): the newly generated object
create (bool): whether the object was 'built' or 'created'
- extracted (object): the value given for <name> in the
- object definition, or None if not provided.
- kwargs (dict): declarations extracted from the object
- definition for this hook
+ extraction_context: An ExtractionContext containing values
+ extracted from the containing factory's declaration
"""
raise NotImplementedError()
@@ -464,14 +476,17 @@ class PostGeneration(PostGenerationDeclaration):
super(PostGeneration, self).__init__()
self.function = function
- def call(self, obj, create,
- extracted=None, factory_extracted=False, **kwargs):
+ def call(self, obj, create, extraction_context):
logger.debug('PostGeneration: Calling %s.%s(%s)',
self.function.__module__,
self.function.__name__,
- utils.log_pprint((obj, create, extracted), kwargs),
+ utils.log_pprint(
+ (obj, create, extraction_context.value),
+ extraction_context.extra,
+ ),
)
- return self.function(obj, create, extracted, **kwargs)
+ return self.function(obj, create,
+ extraction_context.value, **extraction_context.extra)
class RelatedFactory(PostGenerationDeclaration):
@@ -503,20 +518,20 @@ class RelatedFactory(PostGenerationDeclaration):
"""Retrieve the wrapped factory.Factory subclass."""
return self.factory_wrapper.get()
- def call(self, obj, create,
- extracted=None, factory_extracted=False, **kwargs):
+ def call(self, obj, create, extraction_context):
factory = self.get_factory()
- if factory_extracted:
+ if extraction_context.did_extract:
# The user passed in a custom value
logger.debug('RelatedFactory: Using provided %r instead of '
'generating %s.%s.',
- extracted, factory.__module__, factory.__name__,
+ extraction_context.value,
+ factory.__module__, factory.__name__,
)
- return extracted
+ return extraction_context.value
passed_kwargs = dict(self.defaults)
- passed_kwargs.update(kwargs)
+ passed_kwargs.update(extraction_context.extra)
if self.name:
passed_kwargs[self.name] = obj
@@ -547,19 +562,18 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
self.method_args = args
self.method_kwargs = kwargs
- def call(self, obj, create,
- extracted=None, factory_extracted=False, **kwargs):
- if not factory_extracted:
+ def call(self, obj, create, extraction_context):
+ if not extraction_context.did_extract:
passed_args = self.method_args
elif len(self.method_args) <= 1:
# Max one argument expected
- passed_args = (extracted,)
+ passed_args = (extraction_context.value,)
else:
- passed_args = tuple(extracted)
+ passed_args = tuple(extraction_context.value)
passed_kwargs = dict(self.method_kwargs)
- passed_kwargs.update(kwargs)
+ passed_kwargs.update(extraction_context.extra)
method = getattr(obj, self.method_name)
logger.debug('PostGenerationMethodCall: Calling %r.%s(%s)',
obj,
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index 9d54c59..7e7c2fb 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -140,11 +140,11 @@ class PostGenerationDeclarationTestCase(unittest.TestCase):
def test_extract_no_prefix(self):
decl = declarations.PostGenerationDeclaration()
- did_extract, extracted, kwargs = decl.extract('foo',
+ context = decl.extract('foo',
{'foo': 13, 'foo__bar': 42})
- self.assertTrue(did_extract)
- self.assertEqual(extracted, 13)
- self.assertEqual(kwargs, {'bar': 42})
+ self.assertTrue(context.did_extract)
+ self.assertEqual(context.value, 13)
+ self.assertEqual(context.extra, {'bar': 42})
def test_decorator_simple(self):
call_params = []
@@ -153,14 +153,14 @@ class PostGenerationDeclarationTestCase(unittest.TestCase):
call_params.append(args)
call_params.append(kwargs)
- did_extract, extracted, kwargs = foo.extract('foo',
+ context = foo.extract('foo',
{'foo': 13, 'foo__bar': 42, 'blah': 42, 'blah__baz': 1})
- self.assertTrue(did_extract)
- self.assertEqual(13, extracted)
- self.assertEqual({'bar': 42}, kwargs)
+ self.assertTrue(context.did_extract)
+ self.assertEqual(13, context.value)
+ self.assertEqual({'bar': 42}, context.extra)
# No value returned.
- foo.call(None, False, extracted, **kwargs)
+ foo.call(None, False, context)
self.assertEqual(2, len(call_params))
self.assertEqual((None, False, 13), call_params[0])
self.assertEqual({'bar': 42}, call_params[1])
@@ -225,68 +225,73 @@ class PostGenerationMethodCallTestCase(unittest.TestCase):
def setUp(self):
self.obj = mock.MagicMock()
+ def ctx(self, value=None, force_value=False, extra=None):
+ return declarations.ExtractionContext(
+ value,
+ bool(value) or force_value,
+ extra,
+ )
+
def test_simplest_setup_and_call(self):
decl = declarations.PostGenerationMethodCall('method')
- decl.call(self.obj, False)
+ decl.call(self.obj, False, self.ctx())
self.obj.method.assert_called_once_with()
def test_call_with_method_args(self):
decl = declarations.PostGenerationMethodCall(
'method', 'data')
- decl.call(self.obj, False)
+ decl.call(self.obj, False, self.ctx())
self.obj.method.assert_called_once_with('data')
def test_call_with_passed_extracted_string(self):
decl = declarations.PostGenerationMethodCall(
'method')
- decl.call(self.obj, False, 'data', factory_extracted=True)
+ decl.call(self.obj, False, self.ctx('data'))
self.obj.method.assert_called_once_with('data')
def test_call_with_passed_extracted_int(self):
decl = declarations.PostGenerationMethodCall('method')
- decl.call(self.obj, False, 1, factory_extracted=True)
+ decl.call(self.obj, False, self.ctx(1))
self.obj.method.assert_called_once_with(1)
def test_call_with_passed_extracted_iterable(self):
decl = declarations.PostGenerationMethodCall('method')
- decl.call(self.obj, False, (1, 2, 3), factory_extracted=True)
+ decl.call(self.obj, False, self.ctx((1, 2, 3)))
self.obj.method.assert_called_once_with((1, 2, 3))
def test_call_with_method_kwargs(self):
decl = declarations.PostGenerationMethodCall(
'method', data='data')
- decl.call(self.obj, False)
+ decl.call(self.obj, False, self.ctx())
self.obj.method.assert_called_once_with(data='data')
def test_call_with_passed_kwargs(self):
decl = declarations.PostGenerationMethodCall('method')
- decl.call(self.obj, False, data='other')
+ decl.call(self.obj, False, self.ctx(extra={'data': 'other'}))
self.obj.method.assert_called_once_with(data='other')
def test_multi_call_with_multi_method_args(self):
decl = declarations.PostGenerationMethodCall(
'method', 'arg1', 'arg2')
- decl.call(self.obj, False)
+ decl.call(self.obj, False, self.ctx())
self.obj.method.assert_called_once_with('arg1', 'arg2')
def test_multi_call_with_passed_multiple_args(self):
decl = declarations.PostGenerationMethodCall(
'method', 'arg1', 'arg2')
- decl.call(self.obj, False, ('param1', 'param2', 'param3'),
- factory_extracted=True)
+ decl.call(self.obj, False, self.ctx(('param1', 'param2', 'param3')))
self.obj.method.assert_called_once_with('param1', 'param2', 'param3')
def test_multi_call_with_passed_tuple(self):
decl = declarations.PostGenerationMethodCall(
'method', 'arg1', 'arg2')
- decl.call(self.obj, False, (('param1', 'param2'),),
- factory_extracted=True)
+ decl.call(self.obj, False, self.ctx((('param1', 'param2'),)))
self.obj.method.assert_called_once_with(('param1', 'param2'))
def test_multi_call_with_kwargs(self):
decl = declarations.PostGenerationMethodCall(
'method', 'arg1', 'arg2')
- decl.call(self.obj, False, x=2)
+ decl.call(self.obj, False, self.ctx(extra={'x': 2}))
self.obj.method.assert_called_once_with('arg1', 'arg2', x=2)