diff options
-rw-r--r-- | factory/declarations.py | 10 | ||||
-rw-r--r-- | tests/test_declarations.py | 47 |
2 files changed, 56 insertions, 1 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index 366c2c8..1f1d2af 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -21,6 +21,7 @@ # THE SOFTWARE. +import collections import itertools import warnings @@ -498,10 +499,17 @@ class PostGenerationMethodCall(PostGenerationDeclaration): self.method_kwargs = kwargs def call(self, obj, create, extracted=None, **kwargs): + if extracted is not None: + passed_args = extracted + if isinstance(passed_args, basestring) or ( + not isinstance(passed_args, collections.Iterable)): + passed_args = (passed_args,) + else: + passed_args = self.method_args passed_kwargs = dict(self.method_kwargs) passed_kwargs.update(kwargs) method = getattr(obj, self.method_name) - method(*self.method_args, **passed_kwargs) + method(*passed_args, **passed_kwargs) # Decorators... in case lambdas don't cut it diff --git a/tests/test_declarations.py b/tests/test_declarations.py index cc921d4..59a3955 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -24,6 +24,8 @@ import datetime import itertools import warnings +from mock import MagicMock + from factory import declarations from .compat import unittest @@ -295,6 +297,51 @@ class RelatedFactoryTestCase(unittest.TestCase): datetime.date = orig_date +class PostGenerationMethodCallTestCase(unittest.TestCase): + def setUp(self): + self.obj = MagicMock() + + def test_simplest_setup_and_call(self): + decl = declarations.PostGenerationMethodCall('method') + decl.call(self.obj, False) + self.obj.method.assert_called_once_with() + + def test_call_with_method_args(self): + decl = declarations.PostGenerationMethodCall( + 'method', None, 'data') + decl.call(self.obj, False) + self.obj.method.assert_called_once_with('data') + + def test_call_with_passed_extracted_string(self): + decl = declarations.PostGenerationMethodCall( + 'method', None) + decl.call(self.obj, False, 'data') + self.obj.method.assert_called_once_with('data') + + def test_call_with_passed_extracted_int(self): + decl = declarations.PostGenerationMethodCall('method') + decl.call(self.obj, False, 1) + self.obj.method.assert_called_once_with(1) + + def test_call_with_passed_extracted_iterable(self): + decl = declarations.PostGenerationMethodCall('method') + decl.call(self.obj, False, (1, 2, 3)) + self.obj.method.assert_called_once_with(1, 2, 3) + + def test_call_with_method_kwargs(self): + decl = declarations.PostGenerationMethodCall( + 'method', None, data='data') + decl.call(self.obj, False) + self.obj.method.assert_called_once_with(data='data') + + def test_call_with_passed_kwargs(self): + decl = declarations.PostGenerationMethodCall('method') + decl.call(self.obj, False, data='other') + self.obj.method.assert_called_once_with(data='other') + + + + class CircularSubFactoryTestCase(unittest.TestCase): def test_circularsubfactory_deprecated(self): |