diff options
-rw-r--r-- | factory/base.py | 5 | ||||
-rw-r--r-- | factory/declarations.py | 60 | ||||
-rw-r--r-- | tests/test_declarations.py | 49 |
3 files changed, 66 insertions, 48 deletions
diff --git a/factory/base.py b/factory/base.py index 76d3d4a..0429231 100644 --- a/factory/base.py +++ b/factory/base.py @@ -380,9 +380,8 @@ class BaseFactory(object): # Handle post-generation attributes results = {} for name, decl in sorted(postgen_declarations.items()): - did_extract, extracted, extracted_kwargs = postgen_attributes[name] - results[name] = decl.call(obj, create, extracted, - factory_extracted=did_extract, **extracted_kwargs) + extraction_context = postgen_attributes[name] + results[name] = decl.call(obj, create, extraction_context) cls._after_postgeneration(obj, create, results) diff --git a/factory/declarations.py b/factory/declarations.py index 552ddf2..e7c8ed5 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -417,6 +417,21 @@ class List(SubFactory): **params) +class ExtractionContext(object): + """Private class holding all required context from extraction to postgen.""" + def __init__(self, value=None, did_extract=False, extra=None): + self.value = value + self.did_extract = did_extract + self.extra = extra or {} + + def __repr__(self): + return 'ExtractionContext(%r, %r, %r)' % ( + self.value, + self.did_extract, + self.extra, + ) + + class PostGenerationDeclaration(object): """Declarations to be called once the target object has been generated.""" @@ -441,19 +456,16 @@ class PostGenerationDeclaration(object): did_extract = False kwargs = utils.extract_dict(name, attrs) - return did_extract, extracted, kwargs + return ExtractionContext(extracted, did_extract, kwargs) - def call(self, obj, create, - extracted=None, factory_extracted=False, **kwargs): # pragma: no cover + def call(self, obj, create, extraction_context): # pragma: no cover """Call this hook; no return value is expected. Args: obj (object): the newly generated object create (bool): whether the object was 'built' or 'created' - extracted (object): the value given for <name> in the - object definition, or None if not provided. - kwargs (dict): declarations extracted from the object - definition for this hook + extraction_context: An ExtractionContext containing values + extracted from the containing factory's declaration """ raise NotImplementedError() @@ -464,14 +476,17 @@ class PostGeneration(PostGenerationDeclaration): super(PostGeneration, self).__init__() self.function = function - def call(self, obj, create, - extracted=None, factory_extracted=False, **kwargs): + def call(self, obj, create, extraction_context): logger.debug('PostGeneration: Calling %s.%s(%s)', self.function.__module__, self.function.__name__, - utils.log_pprint((obj, create, extracted), kwargs), + utils.log_pprint( + (obj, create, extraction_context.value), + extraction_context.extra, + ), ) - return self.function(obj, create, extracted, **kwargs) + return self.function(obj, create, + extraction_context.value, **extraction_context.extra) class RelatedFactory(PostGenerationDeclaration): @@ -503,20 +518,20 @@ class RelatedFactory(PostGenerationDeclaration): """Retrieve the wrapped factory.Factory subclass.""" return self.factory_wrapper.get() - def call(self, obj, create, - extracted=None, factory_extracted=False, **kwargs): + def call(self, obj, create, extraction_context): factory = self.get_factory() - if factory_extracted: + if extraction_context.did_extract: # The user passed in a custom value logger.debug('RelatedFactory: Using provided %r instead of ' 'generating %s.%s.', - extracted, factory.__module__, factory.__name__, + extraction_context.value, + factory.__module__, factory.__name__, ) - return extracted + return extraction_context.value passed_kwargs = dict(self.defaults) - passed_kwargs.update(kwargs) + passed_kwargs.update(extraction_context.extra) if self.name: passed_kwargs[self.name] = obj @@ -547,19 +562,18 @@ class PostGenerationMethodCall(PostGenerationDeclaration): self.method_args = args self.method_kwargs = kwargs - def call(self, obj, create, - extracted=None, factory_extracted=False, **kwargs): - if not factory_extracted: + def call(self, obj, create, extraction_context): + if not extraction_context.did_extract: passed_args = self.method_args elif len(self.method_args) <= 1: # Max one argument expected - passed_args = (extracted,) + passed_args = (extraction_context.value,) else: - passed_args = tuple(extracted) + passed_args = tuple(extraction_context.value) passed_kwargs = dict(self.method_kwargs) - passed_kwargs.update(kwargs) + passed_kwargs.update(extraction_context.extra) method = getattr(obj, self.method_name) logger.debug('PostGenerationMethodCall: Calling %r.%s(%s)', obj, diff --git a/tests/test_declarations.py b/tests/test_declarations.py index 9d54c59..7e7c2fb 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -140,11 +140,11 @@ class PostGenerationDeclarationTestCase(unittest.TestCase): def test_extract_no_prefix(self): decl = declarations.PostGenerationDeclaration() - did_extract, extracted, kwargs = decl.extract('foo', + context = decl.extract('foo', {'foo': 13, 'foo__bar': 42}) - self.assertTrue(did_extract) - self.assertEqual(extracted, 13) - self.assertEqual(kwargs, {'bar': 42}) + self.assertTrue(context.did_extract) + self.assertEqual(context.value, 13) + self.assertEqual(context.extra, {'bar': 42}) def test_decorator_simple(self): call_params = [] @@ -153,14 +153,14 @@ class PostGenerationDeclarationTestCase(unittest.TestCase): call_params.append(args) call_params.append(kwargs) - did_extract, extracted, kwargs = foo.extract('foo', + context = foo.extract('foo', {'foo': 13, 'foo__bar': 42, 'blah': 42, 'blah__baz': 1}) - self.assertTrue(did_extract) - self.assertEqual(13, extracted) - self.assertEqual({'bar': 42}, kwargs) + self.assertTrue(context.did_extract) + self.assertEqual(13, context.value) + self.assertEqual({'bar': 42}, context.extra) # No value returned. - foo.call(None, False, extracted, **kwargs) + foo.call(None, False, context) self.assertEqual(2, len(call_params)) self.assertEqual((None, False, 13), call_params[0]) self.assertEqual({'bar': 42}, call_params[1]) @@ -225,68 +225,73 @@ class PostGenerationMethodCallTestCase(unittest.TestCase): def setUp(self): self.obj = mock.MagicMock() + def ctx(self, value=None, force_value=False, extra=None): + return declarations.ExtractionContext( + value, + bool(value) or force_value, + extra, + ) + def test_simplest_setup_and_call(self): decl = declarations.PostGenerationMethodCall('method') - decl.call(self.obj, False) + decl.call(self.obj, False, self.ctx()) self.obj.method.assert_called_once_with() def test_call_with_method_args(self): decl = declarations.PostGenerationMethodCall( 'method', 'data') - decl.call(self.obj, False) + decl.call(self.obj, False, self.ctx()) self.obj.method.assert_called_once_with('data') def test_call_with_passed_extracted_string(self): decl = declarations.PostGenerationMethodCall( 'method') - decl.call(self.obj, False, 'data', factory_extracted=True) + decl.call(self.obj, False, self.ctx('data')) self.obj.method.assert_called_once_with('data') def test_call_with_passed_extracted_int(self): decl = declarations.PostGenerationMethodCall('method') - decl.call(self.obj, False, 1, factory_extracted=True) + decl.call(self.obj, False, self.ctx(1)) self.obj.method.assert_called_once_with(1) def test_call_with_passed_extracted_iterable(self): decl = declarations.PostGenerationMethodCall('method') - decl.call(self.obj, False, (1, 2, 3), factory_extracted=True) + decl.call(self.obj, False, self.ctx((1, 2, 3))) self.obj.method.assert_called_once_with((1, 2, 3)) def test_call_with_method_kwargs(self): decl = declarations.PostGenerationMethodCall( 'method', data='data') - decl.call(self.obj, False) + decl.call(self.obj, False, self.ctx()) self.obj.method.assert_called_once_with(data='data') def test_call_with_passed_kwargs(self): decl = declarations.PostGenerationMethodCall('method') - decl.call(self.obj, False, data='other') + decl.call(self.obj, False, self.ctx(extra={'data': 'other'})) self.obj.method.assert_called_once_with(data='other') def test_multi_call_with_multi_method_args(self): decl = declarations.PostGenerationMethodCall( 'method', 'arg1', 'arg2') - decl.call(self.obj, False) + decl.call(self.obj, False, self.ctx()) self.obj.method.assert_called_once_with('arg1', 'arg2') def test_multi_call_with_passed_multiple_args(self): decl = declarations.PostGenerationMethodCall( 'method', 'arg1', 'arg2') - decl.call(self.obj, False, ('param1', 'param2', 'param3'), - factory_extracted=True) + decl.call(self.obj, False, self.ctx(('param1', 'param2', 'param3'))) self.obj.method.assert_called_once_with('param1', 'param2', 'param3') def test_multi_call_with_passed_tuple(self): decl = declarations.PostGenerationMethodCall( 'method', 'arg1', 'arg2') - decl.call(self.obj, False, (('param1', 'param2'),), - factory_extracted=True) + decl.call(self.obj, False, self.ctx((('param1', 'param2'),))) self.obj.method.assert_called_once_with(('param1', 'param2')) def test_multi_call_with_kwargs(self): decl = declarations.PostGenerationMethodCall( 'method', 'arg1', 'arg2') - decl.call(self.obj, False, x=2) + decl.call(self.obj, False, self.ctx(extra={'x': 2})) self.obj.method.assert_called_once_with('arg1', 'arg2', x=2) |